ot_markov_distances.discounted_wl module
This module contains the implementation of the discounted WL distance, with its forward and backward passes (implemented as a torch.autograd.Function).
The depth-\(\infty\) version can be computed with the function wl_reg_infty().
The depth-\(k\) version can be computed with the function wl_reg_k().
- ot_markov_distances.discounted_wl.discounted_wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, delta: Tensor | float = 0.4, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {}, return_differences: bool = False)
Computes the discounted WL distance.
Computes the WL-delta distance between two Markov transition matrices (represented as torch tensors).
This function does not have the backward pass mentioned in the paper, because that formula is only valid for the case \(k=\infty\).
Batched over the first dimension (b).
- Parameters:
MX – (b, n, n) first transition tensor
MY – (b, m, m) second transition tensor
l1 – (b, n,) label values for the first space
l2 – (b, m,) label values for the second space
cost_matrix – (b, n, m) allows specifying the cost matrix instead
k – number of steps (k parameter for the WL distance)
muX – distribution for MX (if omitted, the stationary distribution will be used)
muY – distribution for MY (if omitted, the stationary distribution will be used)
reg – regularization parameter for sinkhorn
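When muX or muY is omitted, the stationary distribution of the corresponding transition matrix is used. As a rough illustration of what that default means, the following pure-Python sketch approximates a stationary distribution by power iteration on a single (unbatched) row-stochastic matrix. The helper name `stationary_distribution` is hypothetical, the chain is assumed to be ergodic so that the iteration converges, and this is not necessarily how the library computes it.

```python
def stationary_distribution(M, n_iter=1000, tol=1e-12):
    """Approximate mu such that mu @ M == mu.

    M is a row-stochastic matrix given as a list of rows.
    Hypothetical helper for illustration only.
    """
    n = len(M)
    mu = [1.0 / n] * n  # start from the uniform distribution
    for _ in range(n_iter):
        # one step of the chain: new_mu[j] = sum_i mu[i] * M[i][j]
        new_mu = [sum(mu[i] * M[i][j] for i in range(n)) for j in range(n)]
        if max(abs(a - b) for a, b in zip(new_mu, mu)) < tol:
            return new_mu
        mu = new_mu
    return mu


# A small two-state chain; each row sums to 1.
M = [[0.9, 0.1],
     [0.5, 0.5]]
mu = stationary_distribution(M)
```

For this chain the exact stationary distribution is (5/6, 1/6), which can be checked by hand against mu M = mu.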
- ot_markov_distances.discounted_wl.discounted_wl_infty(MX: Tensor, MY: Tensor, distance_matrix: Tensor, muX: Tensor | None = None, muY: Tensor | None = None, delta: float = 0.4, sinkhorn_reg: float = 0.01, max_iter: int = 50, convergence_threshold_rtol: float = 0.005, convergence_threshold_atol: float = 1e-06, sinkhorn_iter: int = 100, sinkhorn_iter_schedule: int = 10, x_is_sparse: bool | None = None, y_is_sparse: bool | None = None)
Discounted WL infinity distance
Computes the discounted WL infinity distance between (MX, muX) and (MY, muY) with cost matrix distance_matrix and discount factor delta.
- Parameters:
MX – (b, n, n) first transition tensor
MY – (b, m, m) second transition tensor
distance_matrix – [TODO:description]
muX – initial distribution for MX (if omitted, the stationary distribution will be used instead)
muY – initial distribution for MY (if omitted, the stationary distribution will be used instead)
delta – discount factor
sinkhorn_reg – regularization parameter for the sinkhorn algorithm
max_iter – maximum number of iterations.
convergence_threshold_rtol – relative tolerance for convergence criterion (see torch.allclose)
convergence_threshold_atol – absolute tolerance for convergence criterion (see torch.allclose)
sinkhorn_iter – maximum number of sinkhorn iterations
sinkhorn_iter_schedule – [TODO:description]
x_is_sparse – whether to use the accelerated algorithm, considering MX is sparse (default: compute the degree, and check whether it is lower than 2/3 n. If so, consider MX sparse)
y_is_sparse – whether to use the accelerated algorithm, considering MY is sparse (default: compute the degree, and check whether it is lower than 2/3 m. If so, consider MY sparse)
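Two of the parameters above describe simple tests that can be sketched in plain Python. The first mirrors the element-wise criterion documented for torch.allclose, |new - old| <= atol + rtol * |old|, using the default tolerances from the signature. The second illustrates the default sparsity heuristic. Both helper names are hypothetical. Which iterate plays the role of the reference value, and the reading of "degree" as the largest number of nonzero entries in any row, are assumptions and not confirmed by the documentation.

```python
def is_converged(new, old, rtol=0.005, atol=1e-6):
    """torch.allclose-style test on two flat lists of floats.

    Assumption: `old` is treated as the reference value.
    """
    return all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(new, old))


def looks_sparse(M, threshold=2 / 3):
    """Default sparsity heuristic for a square transition matrix M.

    Assumption: the degree is the maximum count of nonzero entries per row.
    """
    n = len(M)
    degree = max(sum(1 for x in row if x != 0) for row in M)
    return degree < threshold * n


# Random walk on a 4-cycle: every row has exactly 2 nonzero entries.
cycle = [[0.0, 0.5, 0.0, 0.5],
         [0.5, 0.0, 0.5, 0.0],
         [0.0, 0.5, 0.0, 0.5],
         [0.5, 0.0, 0.5, 0.0]]
# Fully dense 3-state chain: every row has 3 nonzero entries.
dense = [[1 / 3, 1 / 3, 1 / 3]] * 3

cycle_is_sparse = looks_sparse(cycle)   # degree 2 compared with 2/3 * 4
dense_is_sparse = looks_sparse(dense)   # degree 3 compared with 2/3 * 3
```

Passing x_is_sparse or y_is_sparse explicitly skips the default check altogether, which may be preferable when the structure of the chain is already known.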
- ot_markov_distances.discounted_wl.discounted_wl_infty_cost_matrix(MX: Tensor, MY: Tensor, distance_matrix: Tensor, delta: float = 0.4, sinkhorn_reg: float = 0.01, max_iter: int = 50, convergence_threshold_rtol: float = 0.005, convergence_threshold_atol: float = 1e-06, sinkhorn_iter: int = 100, return_differences: bool = False, sinkhorn_iter_schedule: int = 10, x_is_sparse: bool | None = None, y_is_sparse: bool | None = None)