ot_markov_distances package

Submodules

Module contents

This module contains the implementation of the discounted WL distance for the paper “Distances for Markov Chains, and Their Differentiation” [BWW23].

ot_markov_distances.sinkhorn_distance(a: Tensor, b: Tensor, C: Tensor, epsilon: float, max_iter: int = 100, *, check_convergence_interval: int | float = 0.1, cv_atol=0.0001, cv_rtol=1e-05, return_has_converged: Literal[True, False] = False) → Tensor

Differentiable Sinkhorn distance

This is a PyTorch implementation of the Sinkhorn algorithm, batched (over a, b and C).

It is compatible with PyTorch autograd gradient computations.

See the documentation of Sinkhorn for details.

Parameters:
  • a – (*batch, n) vector of the first distribution

  • b – (*batch, m) vector of the second distribution

  • C – (*batch, n, m) cost matrix

  • epsilon – entropic regularisation term for Sinkhorn

  • max_iter – max number of sinkhorn iterations (default 100)

  • check_convergence_interval – if an int, check for convergence every check_convergence_interval iterations; if a float, check every check_convergence_interval * max_iter iterations. If 0, never check for convergence (apart from the last iteration, when return_has_converged==True). If convergence is reached early, the algorithm returns early.

  • cv_atol – absolute tolerance for the convergence criterion

  • cv_rtol – relative tolerance for the convergence criterion

  • return_has_converged – whether to also return a boolean indicating whether the algorithm has converged. Setting this to True means that the function will always check for convergence at the last iteration (regardless of the value of check_convergence_interval).

Returns:

Tensor – (*batch) the result of the Sinkhorn computation
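
As an illustration of what a single (unbatched) Sinkhorn solve computes, here is a minimal, dependency-free Python sketch. The function name and structure are illustrative only; the actual implementation is batched over torch tensors, supports autograd, and checks for convergence:

```python
import math

def sinkhorn_cost(a, b, C, epsilon, max_iter=200):
    """Minimal Sinkhorn sketch: entropic OT cost between histograms a (n,)
    and b (m,) under cost matrix C (n x m). Not the library implementation."""
    n, m = len(a), len(b)
    # Gibbs kernel K = exp(-C / epsilon)
    K = [[math.exp(-C[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(max_iter):
        # alternating scaling: u <- a / (K v), then v <- b / (K^T u)
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # transport plan P_ij = u_i K_ij v_j; return <P, C>
    return sum(u[i] * K[i][j] * v[j] * C[i][j]
               for i in range(n) for j in range(m))
```

With a == b and a zero-diagonal cost matrix, the regularised plan concentrates on the diagonal, so the returned cost approaches 0 as epsilon shrinks.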

ot_markov_distances.wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {})

Computes the WL distance

Computes the WL distance between two Markov transition matrices (represented as torch tensors).

Batched over the first dimension (b)

Parameters:
  • MX – (b, n, n) first transition tensor

  • MY – (b, m, m) second transition tensor

  • l1 – (b, n,) label values for the first space

  • l2 – (b, m,) label values for the second space

  • cost_matrix – (b, n, m) allows specifying the cost matrix directly, instead of the labels

  • k – number of steps (k parameter for the WL distance)

  • muX – stationary distribution for MX (if omitted, it will be recomputed)

  • muY – stationary distribution for MY (if omitted, it will be recomputed)

  • reg – regularization parameter for sinkhorn

  • sinkhorn_parameters – additional keyword arguments forwarded to the Sinkhorn computation
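
As a rough sketch of the recursion (one common reading of the k-step WL construction, not the library's exact code): start from the label ground distance, lift it k times through OT problems between transition rows, then compare muX and muY under the resulting cost. All names below are illustrative, and a plain entropic solver stands in for the library's batched Sinkhorn:

```python
import math

def _sinkhorn_cost(a, b, C, epsilon=0.01, n_iter=200):
    # entropic OT cost between histograms a, b with cost matrix C
    n, m = len(a), len(b)
    K = [[math.exp(-C[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return sum(u[i] * K[i][j] * v[j] * C[i][j]
               for i in range(n) for j in range(m))

def wl_k_sketch(MX, MY, l1, l2, k, muX, muY, reg=0.01):
    n, m = len(MX), len(MY)
    # d0(x, y) = |l1[x] - l2[y]| : ground distance between labels
    D = [[abs(l1[i] - l2[j]) for j in range(m)] for i in range(n)]
    for _ in range(k):
        # lift: d_{t+1}(x, y) = OT(MX[x, :], MY[y, :]; cost d_t)
        D = [[_sinkhorn_cost(MX[i], MY[j], D, reg) for j in range(m)]
             for i in range(n)]
    # final comparison of the two distributions under the lifted cost
    return _sinkhorn_cost(muX, muY, D, reg)
```

Two identical chains with identical labels yield a distance near 0 (up to the entropic regularisation).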

ot_markov_distances.markov_measure(M: Tensor) → Tensor

Takes a (batched) Markov transition matrix and outputs its stationary distribution

Parameters:

M – (*b, n, n) the markov transition matrix

Returns:

Tensor – m of shape (*b, n) such that m @ M = m
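
The returned m is a fixed point of right-multiplication by M. Here is a minimal pure-Python sketch of the same idea via power iteration, assuming an ergodic chain (the name is illustrative; the library version is batched and differentiable):

```python
def stationary_distribution(M, n_iter=500):
    """Power-iteration sketch: repeatedly apply mu <- mu @ M until mu
    stabilises at the stationary distribution (assumes ergodicity)."""
    n = len(M)
    mu = [1.0 / n] * n  # start from the uniform distribution
    for _ in range(n_iter):
        mu = [sum(mu[i] * M[i][j] for i in range(n)) for j in range(n)]
    return mu
```

For M = [[0.9, 0.1], [0.2, 0.8]], the stationary distribution is [2/3, 1/3].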

ot_markov_distances.discounted_wl_infty(MX: Tensor, MY: Tensor, distance_matrix: Tensor, muX: Tensor | None = None, muY: Tensor | None = None, delta: float = 0.4, sinkhorn_reg: float = 0.01, max_iter: int = 50, convergence_threshold_rtol: float = 0.005, convergence_threshold_atol: float = 1e-06, sinkhorn_iter: int = 100, sinkhorn_iter_schedule: int = 10, x_is_sparse: bool | None = None, y_is_sparse: bool | None = None)

Discounted WL infinity distance

Computes the discounted WL infinity distance between (MX, muX) and (MY, muY) with cost matrix distance_matrix and discount factor delta.

Parameters:
  • MX – (b, n, n) first transition tensor

  • MY – (b, m, m) second transition tensor

  • distance_matrix – (b, n, m) cost matrix giving the distance between the states of MX and the states of MY

  • muX – initial distribution for MX (if omitted, the stationary distribution will be used instead)

  • muY – initial distribution for MY (if omitted, the stationary distribution will be used instead)

  • delta – discount factor

  • sinkhorn_reg – regularization parameter for the sinkhorn algorithm

  • max_iter – maximum number of iterations.

  • convergence_threshold_rtol – relative tolerance for convergence criterion (see torch.allclose)

  • convergence_threshold_atol – absolute tolerance for convergence criterion (see torch.allclose)

  • sinkhorn_iter – maximum number of sinkhorn iterations

  • sinkhorn_iter_schedule – (int) schedule controlling how the number of sinkhorn iterations is adapted across the outer iterations

  • x_is_sparse – whether to use the accelerated algorithm that assumes MX is sparse (default: compute the degree and check whether it is lower than 2n/3; if so, consider MX sparse)

  • y_is_sparse – whether to use the accelerated algorithm that assumes MY is sparse (default: compute the degree and check whether it is lower than 2m/3; if so, consider MY sparse)
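
The computed quantity can be read as the fixed point of D(x, y) = (1 - delta) · d(x, y) + delta · OT(MX[x, :], MY[y, :]; D). Below is a dependency-free sketch of that iteration; names are illustrative, and a plain entropic solver stands in for the library's batched, sparsity-aware Sinkhorn:

```python
import math

def _sinkhorn_cost(a, b, C, epsilon=0.01, n_iter=200):
    # entropic OT cost between histograms a, b with cost matrix C
    n, m = len(a), len(b)
    K = [[math.exp(-C[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return sum(u[i] * K[i][j] * v[j] * C[i][j]
               for i in range(n) for j in range(m))

def discounted_wl_infty_sketch(MX, MY, dist, muX, muY, delta=0.4,
                               reg=0.01, max_iter=50, atol=1e-5):
    n, m = len(MX), len(MY)
    D = [row[:] for row in dist]
    for _ in range(max_iter):
        # one application of the contraction (Lipschitz factor delta):
        # D(x, y) <- (1 - delta) * d(x, y) + delta * OT(MX[x], MY[y]; D)
        D_new = [[(1 - delta) * dist[i][j]
                  + delta * _sinkhorn_cost(MX[i], MY[j], D, reg)
                  for j in range(m)] for i in range(n)]
        converged = max(abs(D_new[i][j] - D[i][j])
                        for i in range(n) for j in range(m)) < atol
        D = D_new
        if converged:
            break
    # compare the two distributions under the fixed-point cost
    return _sinkhorn_cost(muX, muY, D, reg)
```

Because each update is a delta-contraction, convergence is geometric; identical chains with a zero-diagonal distance_matrix give a distance near 0.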

ot_markov_distances.discounted_wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, delta: Tensor | float = 0.4, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {}, return_differences: bool = False)

Computes the discounted WL distance

Computes the WL-δ distance between two Markov transition matrices (represented as torch tensors).

This function does not provide the backward pass mentioned in the paper, because that formula is only valid for the case \(k=\infty\).

Batched over first dimension (b)

Parameters:
  • MX – (b, n, n) first transition tensor

  • MY – (b, m, m) second transition tensor

  • l1 – (b, n,) label values for the first space

  • l2 – (b, m,) label values for the second space

  • cost_matrix – (b, n, m) allows specifying the cost matrix directly, instead of the labels

  • k – number of steps (k parameter for the WL distance)

  • muX – distribution for MX (if omitted, the stationary distribution will be used)

  • muY – distribution for MY (if omitted, the stationary distribution will be used)

  • reg – regularization parameter for sinkhorn

  • sinkhorn_parameters – additional keyword arguments forwarded to the Sinkhorn computation
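
For intuition about k and delta, here is an exact, dependency-free sketch restricted to two-state chains: OT between two-point histograms has a closed form (the plan has a single free parameter and the cost is linear in it, so an endpoint is optimal), which avoids any Sinkhorn regularisation. The names, and the choice to start the recursion from the raw cost matrix, are illustrative assumptions:

```python
def _ot2(p, q, D):
    """Exact OT cost between two-point histograms (p, 1-p) and (q, 1-q)
    under a 2x2 cost matrix D. The plan's only free parameter is
    t = P[0][0]; the cost is linear in t, so an endpoint is optimal."""
    lo, hi = max(0.0, p + q - 1.0), min(p, q)
    def cost(t):
        return (t * D[0][0] + (p - t) * D[0][1]
                + (q - t) * D[1][0] + (1.0 - p - q + t) * D[1][1])
    return min(cost(lo), cost(hi))

def discounted_wl_k_sketch(MX, MY, cost_matrix, k, muX, muY, delta=0.4):
    """k-step truncation of the discounted WL recursion on 2-state chains."""
    D = [row[:] for row in cost_matrix]
    for _ in range(k):
        # D(x, y) <- (1 - delta) * c(x, y) + delta * OT(MX[x, :], MY[y, :]; D)
        D = [[(1.0 - delta) * cost_matrix[i][j]
              + delta * _ot2(MX[i][0], MY[j][0], D)
              for j in range(2)] for i in range(2)]
    return _ot2(muX[0], muY[0], D)
```

Identical chains with a zero-diagonal cost matrix give exactly 0 for every k; as k grows, the k-step value approaches the \(k=\infty\) fixed point by contraction.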