ot_markov_distances.discounted_wl module

This module contains the implementation of the discounted WL distance, with its forward and backward passes (implemented as a torch.autograd.Function).

The depth-\(\infty\) version can be computed with the function discounted_wl_infty(). The depth-\(k\) version can be computed with the function discounted_wl_k().

ot_markov_distances.discounted_wl.discounted_wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, delta: Tensor | float = 0.4, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {}, return_differences: bool = False)

Computes the discounted WL distance.

Computes the WL-delta distance between two Markov transition matrices (represented as torch tensors).

This function does not implement the backward pass described in the paper, because that formula is only valid for the case \(k=\infty\).

Batched over the first dimension (b).
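To make the depth-\(k\) recursion concrete, here is a minimal unbatched sketch in plain Python. It assumes (the text above does not confirm this) that the update is \(D^{(t+1)}(x,y) = (1-\delta)\,C(x,y) + \delta\,W_{D^{(t)}}(M_X(x,\cdot), M_Y(y,\cdot))\) starting from \(D^{(0)} = C\), where \(W\) is the entropic (Sinkhorn) transport cost with regularization reg, and that the returned value is \(W_{D^{(k)}}(\mu_X, \mu_Y)\). The helpers sinkhorn_cost and discounted_wl_k_sketch are hypothetical names, not part of the library; delta=0.4 and reg=0.1 mirror the defaults in the signature, while k=3 is an arbitrary choice.

```python
import math


def sinkhorn_cost(a, b, C, reg, n_iter=200):
    """Entropic OT between histograms a (len n) and b (len m) with cost C (n x m).

    Returns the transport cost <P, C> of the Sinkhorn plan P.
    """
    n, m = len(a), len(b)
    K = [[math.exp(-C[i][j] / reg) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        # Alternate scalings so that P = diag(u) K diag(v) matches a (rows) and b (columns).
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return sum(u[i] * K[i][j] * v[j] * C[i][j] for i in range(n) for j in range(m))


def discounted_wl_k_sketch(MX, MY, C, muX, muY, delta=0.4, k=3, reg=0.1):
    """Hypothetical, unbatched depth-k discounted WL distance (assumed recursion)."""
    n, m = len(MX), len(MY)
    D = [row[:] for row in C]  # D^(0) = C
    for _ in range(k):
        # Every entry of the new D is computed from the previous D.
        D = [[(1 - delta) * C[x][y] + delta * sinkhorn_cost(MX[x], MY[y], D, reg)
              for y in range(m)] for x in range(n)]
    # The final distance transports the initial distributions under D^(k).
    return sinkhorn_cost(muX, muY, D, reg)


# Toy example: two 2-state chains with uniform transitions and scalar labels.
M = [[0.5, 0.5], [0.5, 0.5]]
mu = [0.5, 0.5]
labels_x = [0.0, 1.0]
C_same = [[abs(p - q) for q in [0.0, 1.0]] for p in labels_x]
C_shift = [[abs(p - q) for q in [1.0, 2.0]] for p in labels_x]
d_same = discounted_wl_k_sketch(M, M, C_same, mu, mu)
d_shift = discounted_wl_k_sketch(M, M, C_shift, mu, mu)
```

With identical labels the result is small but not exactly zero, because the entropic regularization blurs the transport plan; shifting the second chain's labels by 1 gives a distance of about 1.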

Parameters:
  • MX – (b, n, n) first transition tensor

  • MY – (b, m, m) second transition tensor

  • l1 – (b, n) label values for the first space

  • l2 – (b, m) label values for the second space

  • cost_matrix – (b, n, m) cost matrix to use instead of one computed from the labels l1 and l2

  • delta – discount factor

  • k – number of steps (depth k of the WL distance)

  • muX – distribution for MX (if omitted, the stationary distribution will be used)

  • muY – distribution for MY (if omitted, the stationary distribution will be used)

  • reg – regularization parameter for the Sinkhorn algorithm
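When muX or muY is omitted, a stationary distribution \(\pi\) with \(\pi M = \pi\) is used. A minimal way to approximate one is power iteration, sketched below in plain Python under the assumption that the chain is irreducible and aperiodic (otherwise the iteration need not converge). The helper name is hypothetical, and the library may compute the stationary distribution differently.

```python
def stationary_distribution(M, n_iter=1000, tol=1e-12):
    """Power iteration for pi with pi = pi @ M, for a row-stochastic matrix M (lists).

    Assumes the chain is irreducible and aperiodic, so the iteration converges.
    """
    n = len(M)
    pi = [1.0 / n] * n  # start from the uniform distribution
    for _ in range(n_iter):
        # One step of the chain: new[j] = sum_i pi[i] * M[i][j]
        new = [sum(pi[i] * M[i][j] for i in range(n)) for j in range(n)]
        if max(abs(p - q) for p, q in zip(new, pi)) < tol:
            return new
        pi = new
    return pi


MX = [[0.9, 0.1], [0.5, 0.5]]
pi = stationary_distribution(MX)
```

For this chain the balance condition \(0.1\,\pi_1 = 0.5\,\pi_2\) gives \(\pi = (5/6, 1/6)\), which the iteration approaches geometrically.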

ot_markov_distances.discounted_wl.discounted_wl_infty(MX: Tensor, MY: Tensor, distance_matrix: Tensor, muX: Tensor | None = None, muY: Tensor | None = None, delta: float = 0.4, sinkhorn_reg: float = 0.01, max_iter: int = 50, convergence_threshold_rtol: float = 0.005, convergence_threshold_atol: float = 1e-06, sinkhorn_iter: int = 100, sinkhorn_iter_schedule: int = 10, x_is_sparse: bool | None = None, y_is_sparse: bool | None = None)

Discounted WL infinity distance

Computes the discounted WL infinity distance between (MX, muX) and (MY, muY) with cost matrix distance_matrix and discount factor delta.
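Assuming the depth-\(\infty\) distance is the fixed point of a discounted update in which the previous distances enter with weight \(\delta < 1\), repeated updates settle down geometrically. The sketch below illustrates only the stopping rule suggested by max_iter, convergence_threshold_rtol and convergence_threshold_atol in the signature above: consecutive iterates are compared elementwise the way torch.allclose does. Comparing consecutive iterates is an assumption about how the library applies these tolerances, and the OT term is replaced by a plain mean to keep the example tiny; the toy update and helper names are hypothetical.

```python
def iterate_to_fixed_point(update, x0, max_iter=50, rtol=0.005, atol=1e-6):
    """Apply `update` until two consecutive iterates are close, as torch.allclose would judge.

    Returns (final iterate, number of updates performed).
    """
    x = x0
    for it in range(1, max_iter + 1):
        new = update(x)
        # torch.allclose(new, x, rtol, atol) checks |new - x| <= atol + rtol * |x| elementwise.
        if all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(new, x)):
            return new, it
        x = new
    return x, max_iter


# Toy delta-contraction standing in for the WL update: the mean replaces the OT term.
delta = 0.4
C = [0.0, 1.0, 2.0]


def toy_update(D):
    mean_D = sum(D) / len(D)
    return [(1 - delta) * c + delta * mean_D for c in C]


D, n_updates = iterate_to_fixed_point(toy_update, [0.0, 0.0, 0.0])
```

The exact fixed point of this toy map is \(0.6\,C + 0.4 = (0.4, 1.0, 1.6)\); with these tolerances the loop stops after a few updates, within roughly \(10^{-3}\) of it.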

Parameters:
  • MX – (b, n, n) first transition tensor

  • MY – (b, m, m) second transition tensor

  • distance_matrix – (b, n, m) cost matrix between the states of MX and those of MY

  • muX – initial distribution for MX (if omitted, the stationary distribution will be used instead)

  • muY – initial distribution for MY (if omitted, the stationary distribution will be used instead)

  • delta – discount factor

  • sinkhorn_reg – regularization parameter for the Sinkhorn algorithm

  • max_iter – maximum number of iterations.

  • convergence_threshold_rtol – relative tolerance for convergence criterion (see torch.allclose)

  • convergence_threshold_atol – absolute tolerance for convergence criterion (see torch.allclose)

  • sinkhorn_iter – maximum number of Sinkhorn iterations

  • sinkhorn_iter_schedule – (int) [TODO: description]

  • x_is_sparse – whether to use the accelerated algorithm that treats MX as sparse (default: compute the degree of MX and treat MX as sparse if it is lower than 2/3 n)

  • y_is_sparse – whether to use the accelerated algorithm that treats MY as sparse (default: compute the degree of MY and treat MY as sparse if it is lower than 2/3 m)
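The default for x_is_sparse and y_is_sparse compares a degree with \(2/3\) of the number of states. The docstring does not say which degree is meant; the sketch below assumes it is the largest number of nonzero entries in any row (the maximum out-degree). It only illustrates the rule, and the helper name is hypothetical.

```python
def looks_sparse(M):
    """Hypothetical default heuristic: is the maximum out-degree of M below 2/3 of n?"""
    n = len(M)
    # Out-degree of a state = number of nonzero entries in its row.
    max_degree = max(sum(1 for p in row if p != 0) for row in M)
    return max_degree < 2 * n / 3


dense = [[1 / 3] * 3 for _ in range(3)]
cycle = [[1.0 if j == (i + 1) % 6 else 0.0 for j in range(6)] for i in range(6)]
```

A fully dense 3-state chain has degree 3, which is not below 2, so it is treated as dense; a 6-state cycle has degree 1, below 4, so it is treated as sparse.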

ot_markov_distances.discounted_wl.discounted_wl_infty_cost_matrix(MX: Tensor, MY: Tensor, distance_matrix: Tensor, delta: float = 0.4, sinkhorn_reg: float = 0.01, max_iter: int = 50, convergence_threshold_rtol: float = 0.005, convergence_threshold_atol: float = 1e-06, sinkhorn_iter: int = 100, return_differences: bool = False, sinkhorn_iter_schedule: int = 10, x_is_sparse: bool | None = None, y_is_sparse: bool | None = None)