ot_markov_distances.wl module

ot_markov_distances.wl.wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {})

computes the WL distance

computes the WL distance between two Markov transition matrices (represented as torch tensors)

Batched over the first dimension (b)

Parameters:
  • MX – (b, n, n) first transition tensor

  • MY – (b, m, m) second transition tensor

  • l1 – (b, n,) label values for the first space

  • l2 – (b, m,) label values for the second space

  • k – number of steps (k parameter for the WL distance)

  • muX – stationary distribution for MX (if omitted, it is recomputed)

  • muY – stationary distribution for MY (if omitted, it is recomputed)

  • reg – entropic regularization parameter for the Sinkhorn algorithm
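The shapes above can be made concrete with a small setup. The helpers `random_transition_batch` and `stationary_distribution` are hypothetical illustrations, not part of ot_markov_distances. Power iteration is just one simple way to approximate a stationary distribution and assumes each chain is irreducible and aperiodic; the library's own recomputation of muX/muY may work differently.

```python
import torch


def random_transition_batch(b: int, n: int, generator=None) -> torch.Tensor:
    """Hypothetical helper: a (b, n, n) batch of row-stochastic matrices."""
    M = torch.rand(b, n, n, generator=generator)
    # Normalize each row so it is a probability vector over next states.
    return M / M.sum(dim=-1, keepdim=True)


def stationary_distribution(M: torch.Tensor, n_iter: int = 1000) -> torch.Tensor:
    """Hypothetical helper: approximate stationary distributions of a (b, n, n)
    batch by power iteration mu <- mu @ M.

    Assumes every chain is irreducible and aperiodic so the iteration converges.
    """
    b, n, _ = M.shape
    mu = torch.full((b, n), 1.0 / n, dtype=M.dtype)  # start from uniform
    for _ in range(n_iter):
        mu = torch.einsum("bi,bij->bj", mu, M)  # one step of the chain, per batch
    return mu / mu.sum(dim=-1, keepdim=True)


g = torch.Generator().manual_seed(0)
MX = random_transition_batch(4, 5, g)  # (b, n, n)
MY = random_transition_batch(4, 7, g)  # (b, m, m)
l1 = torch.rand(4, 5, generator=g)     # (b, n) label values
l2 = torch.rand(4, 7, generator=g)     # (b, m) label values
muX = stationary_distribution(MX)      # (b, n)
muY = stationary_distribution(MY)      # (b, m)
```

Following the signature above, these tensors could then be passed as `wl_k(MX, MY, l1, l2, k=2, muX=muX, muY=muY)`; supplying muX and muY avoids the recomputation mentioned in the parameter list.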

ot_markov_distances.wl.wl_k_sparse(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {})

computes the WL distance, sparse version

computes the WL distance between two Markov transition matrices (represented as torch tensors)

Batched over the first dimension (b)

Parameters:
  • MX – (b, n, n) first transition tensor

  • MY – (b, m, m) second transition tensor

  • l1 – (b, n,) label values for the first space

  • l2 – (b, m,) label values for the second space

  • k – number of steps (k parameter for the WL distance)

  • muX – stationary distribution for MX (if omitted, it is recomputed)

  • muY – stationary distribution for MY (if omitted, it is recomputed)

  • reg – entropic regularization parameter for the Sinkhorn algorithm
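To make the roles of k, the labels and reg concrete, here is a minimal dense, unbatched sketch of the WL recursion as it is commonly defined. It sets d_0(x, y) = |l1(x) - l2(y)|. Each step d_{t+1}(x, y) is the optimal transport cost between the rows MX[x] and MY[y] under cost d_t. The result is the transport cost between muX and muY under d_k, with every OT problem solved approximately by plain entropic Sinkhorn iterations at regularization reg. The names `sinkhorn_cost` and `wl_k_sketch` are hypothetical, and the absolute-difference label cost is an assumption. This is not the library's implementation, and in particular it says nothing about how wl_k_sparse exploits sparsity.

```python
import torch


def sinkhorn_cost(a: torch.Tensor, b: torch.Tensor, C: torch.Tensor,
                  reg: float, n_iter: int = 200) -> torch.Tensor:
    """Entropic OT cost <P, C> for batched marginals.

    a: (..., n), b: (..., m), C: (n, m). Returns a tensor of shape (...).
    """
    # Shifting C by a constant only rescales K, which the scalings u, v absorb,
    # so the plan is unchanged while exp() is less likely to underflow.
    K = torch.exp(-(C - C.min()) / reg)       # (n, m) Gibbs kernel
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iter):
        v = b / (u @ K)                        # (..., m): match column marginals
        u = a / (v @ K.T)                      # (..., n): match row marginals
    P = u.unsqueeze(-1) * K * v.unsqueeze(-2)  # (..., n, m) transport plans
    return (P * C).sum(dim=(-2, -1))


def wl_k_sketch(MX, MY, l1, l2, muX, muY, k: int, reg: float = 0.1):
    """Hypothetical, unbatched sketch: MX (n, n), MY (m, m), l1 (n,), l2 (m,)."""
    n, m = MX.shape[0], MY.shape[0]
    d = (l1[:, None] - l2[None, :]).abs()      # d_0(x, y) = |l1(x) - l2(y)|
    for _ in range(k):
        # d_{t+1}(x, y): OT cost between MX[x] and MY[y] under cost d_t,
        # solved for all (x, y) pairs at once.
        a = MX[:, None, :].expand(-1, m, -1)   # (n, m, n)
        b = MY[None, :, :].expand(n, -1, -1)   # (n, m, m)
        d = sinkhorn_cost(a, b, d, reg)        # (n, m)
    return sinkhorn_cost(muX, muY, d, reg)     # scalar tensor
```

Because each Sinkhorn plan has total mass one, every d_t stays between the smallest and largest label differences, which is an easy sanity check on outputs. Larger reg smooths the plans and speeds up convergence at the price of a more biased cost.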