ot_markov_distances.wl module
- ot_markov_distances.wl.wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {})
Computes the WL distance.
Computes the WL distance between two Markov transition matrices (represented as torch tensors).
Batched over the first dimension (b).
- Parameters:
MX – (b, n, n) first transition tensor
MY – (b, m, m) second transition tensor
l1 – (b, n,) label values for the first space
l2 – (b, m,) label values for the second space
k – number of steps (k parameter for the WL distance)
muX – stationary distribution for MX (recomputed if omitted)
muY – stationary distribution for MY (recomputed if omitted)
reg – regularization parameter for Sinkhorn
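The recursion behind the WL distance can be illustrated with a small, unbatched NumPy sketch. It is not the implementation of `wl_k`, and several things are assumptions: the helper names (`stationary_distribution`, `sinkhorn_cost`, `wl_k_sketch`) are hypothetical, the ground cost |l1(x) - l2(y)| stands in for the default cost (the role of `cost_matrix` is not documented above), plain Sinkhorn replaces whatever solver `sinkhorn_parameters` configures, and each chain is assumed to have a unique stationary distribution. The cost matrix is refined k times: entry (x, y) becomes the entropic OT cost between the rows MX[x] and MY[y] under the previous cost. The result is the OT cost between muX and muY under the final cost.

```python
import numpy as np


def stationary_distribution(M):
    """Stationary distribution mu with mu @ M = mu and sum(mu) = 1.

    Solved as a least-squares system; assumes the chain has a unique
    stationary distribution (e.g. it is irreducible).
    """
    n = M.shape[0]
    A = np.vstack([M.T - np.eye(n), np.ones((1, n))])
    rhs = np.zeros(n + 1)
    rhs[-1] = 1.0
    mu, *_ = np.linalg.lstsq(A, rhs, rcond=None)
    mu = np.clip(mu, 0.0, None)  # remove tiny negative round-off
    return mu / mu.sum()


def sinkhorn_cost(a, b, C, reg, n_iter=200):
    """Entropic OT cost <P, C> between histograms a and b (plain Sinkhorn)."""
    K = np.exp(-C / reg)
    v = (b > 0).astype(float)  # start from the support of b
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    P = u[:, None] * K * v[None, :]  # transport plan
    return float(np.sum(P * C))


def wl_k_sketch(MX, MY, l1, l2, k, reg=0.1, muX=None, muY=None):
    """Unbatched WL_k sketch: MX (n, n), MY (m, m), l1 (n,), l2 (m,)."""
    if muX is None:
        muX = stationary_distribution(MX)
    if muY is None:
        muY = stationary_distribution(MY)
    # Step 0: ground cost between labels (absolute difference, an assumption).
    C = np.abs(l1[:, None] - l2[None, :])
    n, m = C.shape
    for _ in range(k):
        # Step i: cost(x, y) = OT cost between next-state distributions
        # MX[x] and MY[y], measured with the previous cost matrix.
        C = np.array([[sinkhorn_cost(MX[x], MY[y], C, reg) for y in range(m)]
                      for x in range(n)])
    # Final step: OT between the stationary distributions under the k-step cost.
    return sinkhorn_cost(muX, muY, C, reg)


if __name__ == "__main__":
    MX = np.array([[0.5, 0.5], [0.3, 0.7]])
    MY = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    print(wl_k_sketch(MX, MY, np.array([0.0, 1.0]), np.array([0.0, 1.0, 0.5]), k=2))
```

Plain Sinkhorn underflows once entries of C / reg exceed roughly 700, so large costs or a very small reg would call for a log-domain variant.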
- ot_markov_distances.wl.wl_k_sparse(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {})
Computes the WL distance, sparse version.
Computes the WL distance between two Markov transition matrices (represented as torch tensors).
Batched over the first dimension (b).
- Parameters:
MX – (b, n, n) first transition tensor
MY – (b, m, m) second transition tensor
l1 – (b, n,) label values for the first space
l2 – (b, m,) label values for the second space
k – number of steps (k parameter for the WL distance)
muX – stationary distribution for MX (recomputed if omitted)
muY – stationary distribution for MY (recomputed if omitted)
reg – regularization parameter for Sinkhorn
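The documentation does not say how `wl_k_sparse` exploits sparsity. One plausible way, offered here only as an assumption, is that each OT problem between MX[x] and MY[y] involves only the states those rows can reach in one step, so it can be solved on the supports instead of the full n × m grid. The sketch below (hypothetical helpers, NumPy, unbatched, plain Sinkhorn) checks that restricting one refinement step to the supports reproduces the dense step's cost matrix.

```python
import numpy as np


def sinkhorn_cost(a, b, C, reg, n_iter=200):
    """Entropic OT cost <P, C> between histograms a and b (plain Sinkhorn)."""
    K = np.exp(-C / reg)
    v = (b > 0).astype(float)  # start from the support of b
    for _ in range(n_iter):
        u = a / (K @ v)  # zero wherever a is zero
        v = b / (K.T @ u)  # zero wherever b is zero
    P = u[:, None] * K * v[None, :]
    return float(np.sum(P * C))


def wl_step_dense(MX, MY, C, reg):
    """One WL refinement step solving every OT problem on the full grid."""
    n, m = C.shape
    return np.array([[sinkhorn_cost(MX[x], MY[y], C, reg) for y in range(m)]
                     for x in range(n)])


def wl_step_support(MX, MY, C, reg):
    """Same step, but each OT problem only sees the supports of MX[x] and MY[y]."""
    n, m = C.shape
    out = np.empty((n, m))
    for x in range(n):
        sx = np.flatnonzero(MX[x])  # states reachable from x in one step
        for y in range(m):
            sy = np.flatnonzero(MY[y])
            out[x, y] = sinkhorn_cost(MX[x, sx], MY[y, sy], C[np.ix_(sx, sy)], reg)
    return out


if __name__ == "__main__":
    MX = np.array([[0.0, 1.0, 0.0, 0.0],
                   [0.0, 0.0, 0.5, 0.5],
                   [1.0, 0.0, 0.0, 0.0],
                   [0.5, 0.0, 0.0, 0.5]])
    MY = np.array([[0.2, 0.8, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    C0 = np.abs(np.array([0.0, 1.0, 0.5, 0.2])[:, None] - np.array([0.0, 1.0, 0.3])[None, :])
    print(np.allclose(wl_step_dense(MX, MY, C0, 0.1), wl_step_support(MX, MY, C0, 0.1)))
```

Because Sinkhorn keeps the scaling vectors at zero outside the marginals' supports, the dense and restricted iterations coincide; the saving is that each small problem costs O(|supp MX[x]| × |supp MY[y]|) per iteration instead of O(n × m).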