ot_markov_distances package
Submodules
Module contents
This module contains the implementation of the discounted WL distance from the paper “Distances for Markov Chains, and Their Differentiation” [BWW23].
- ot_markov_distances.sinkhorn_distance(a: Tensor, b: Tensor, C: Tensor, epsilon: float, max_iter: int = 100, *, check_convergence_interval: int | float = 0.1, cv_atol=0.0001, cv_rtol=1e-05, return_has_converged: Literal[True, False] = False) Tensor
Differentiable sinkhorn distance
This is a pytorch implementation of sinkhorn, batched over a, b and C. It is compatible with pytorch autograd gradient computations.
See the documentation of Sinkhorn for details.
- Parameters:
a – (*batch, n) vector of the first distribution
b – (*batch, m) vector of the second distribution
C – (*batch, n, m) cost matrix
epsilon – regularisation term for sinkhorn
max_iter – max number of sinkhorn iterations (default 100)
check_convergence_interval – if int, check for convergence every check_convergence_interval iterations. If float, check for convergence every check_convergence_interval * max_iter iterations. If 0, never check for convergence (apart from the last iteration if return_has_converged == True). If convergence is reached early, the algorithm returns.
cv_atol – absolute tolerance for the convergence criterion
cv_rtol – relative tolerance for the convergence criterion
return_has_converged – whether to return a boolean indicating whether the algorithm has converged. Setting this to True means that the function will always check for convergence at the last iteration (regardless of the value of check_convergence_interval)
- Returns:
Tensor – (*batch). result of the sinkhorn computation
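The scaling iteration behind this function can be pictured in a few lines. Below is an illustrative, unbatched numpy sketch (the library's implementation is batched pytorch and autograd-compatible); the name sinkhorn_sketch is ours, not part of the package:

```python
import numpy as np

def sinkhorn_sketch(a, b, C, epsilon, max_iter=100, atol=1e-4, rtol=1e-5):
    """Entropic-regularised OT cost between histograms a (n,) and b (m,)
    under cost matrix C (n, m).  Unbatched, illustrative only."""
    K = np.exp(-C / epsilon)              # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(max_iter):
        u_prev = u
        v = b / (K.T @ u)                 # column scaling
        u = a / (K @ v)                   # row scaling
        if np.allclose(u, u_prev, atol=atol, rtol=rtol):
            break                         # scalings have stabilised
    P = u[:, None] * K * v[None, :]       # transport plan
    return float(np.sum(P * C))           # transport cost <P, C>
```

For small epsilon the plan concentrates on the cheapest couplings, so two identical distributions with a cost matrix that is zero on the diagonal get a near-zero cost.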
- ot_markov_distances.wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {})
Computes the WL distance
Computes the WL distance between two Markov transition matrices (represented as torch tensors).
Batched over the first dimension (b)
- Parameters:
MX – (b, n, n) first transition tensor
MY – (b, m, m) second transition tensor
l1 – (b, n,) label values for the first space
l2 – (b, m,) label values for the second space
k – number of steps (k parameter for the WL distance)
muX – stationary distribution for MX (if omitted, it will be recomputed)
muY – stationary distribution for MY (if omitted, it will be recomputed)
reg – regularization parameter for sinkhorn
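When scalar node labels are given, one natural ground cost between states is the absolute label difference, built by broadcasting. A hypothetical sketch (the package may use a different label metric internally; the toy values here are ours):

```python
import numpy as np

# scalar node labels for two spaces of sizes n=3 and m=2 (toy values)
l1 = np.array([0.0, 1.0, 2.0])
l2 = np.array([0.5, 2.0])

# (n, m) cost matrix of absolute label differences via broadcasting
cost_matrix = np.abs(l1[:, None] - l2[None, :])
```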
- ot_markov_distances.markov_measure(M: Tensor) Tensor
Takes a (batched) Markov transition matrix and outputs its stationary distribution
- Parameters:
M – (*b, n, n) the Markov transition matrix
- Returns:
Tensor – m (*b, n) such that m @ M = m
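For an ergodic chain, the left fixed point m @ M = m can be approximated by power iteration; a minimal unbatched numpy sketch (not necessarily how the library computes it):

```python
import numpy as np

def stationary_sketch(M, n_iter=1000):
    """Approximate the stationary distribution m of a row-stochastic
    matrix M (n, n), i.e. a left fixed point m @ M = m, by power iteration."""
    n = M.shape[0]
    m = np.full(n, 1.0 / n)      # start from the uniform distribution
    for _ in range(n_iter):
        m = m @ M                # advance the chain by one step
    return m
```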
- ot_markov_distances.discounted_wl_infty(MX: Tensor, MY: Tensor, distance_matrix: Tensor, muX: Tensor | None = None, muY: Tensor | None = None, delta: float = 0.4, sinkhorn_reg: float = 0.01, max_iter: int = 50, convergence_threshold_rtol: float = 0.005, convergence_threshold_atol: float = 1e-06, sinkhorn_iter: int = 100, sinkhorn_iter_schedule: int = 10, x_is_sparse: bool | None = None, y_is_sparse: bool | None = None)
Discounted WL infinity distance
Computes the discounted WL infinity distance between (MX, muX) and (MY, muY) with cost matrix distance_matrix and discount factor delta.
- Parameters:
MX – (b, n, n) first transition tensor
MY – (b, m, m) second transition tensor
distance_matrix – (b, n, m) ground cost matrix between the states of MX and MY
muX – initial distribution for MX (if omitted, the stationary distribution will be used instead)
muY – initial distribution for MY (if omitted, the stationary distribution will be used instead)
delta – discount factor
sinkhorn_reg – regularization parameter for the sinkhorn algorithm
max_iter – maximum number of iterations.
convergence_threshold_rtol – relative tolerance for the convergence criterion (see torch.allclose)
convergence_threshold_atol – absolute tolerance for the convergence criterion (see torch.allclose)
sinkhorn_iter – maximum number of sinkhorn iterations
sinkhorn_iter_schedule – [TODO:description]
x_is_sparse – whether to use the accelerated algorithm, considering MX sparse (default: compute the degree and check whether it is lower than 2/3 n; if so, consider MX sparse)
y_is_sparse – whether to use the accelerated algorithm, considering MY sparse (default: compute the degree and check whether it is lower than 2/3 m; if so, consider MY sparse)
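Schematically, the discounted distance is the fixed point of a contraction: each sweep blends the ground cost with one-step entropic OT couplings, discounted by delta. A rough, unbatched numpy sketch of that outer loop (our simplification; the exact recursion and normalisation in the paper may differ, and we use a larger regularisation than the library's 0.01 default to keep the kernel numerically stable):

```python
import numpy as np

def sinkhorn_cost(a, b, C, reg, n_iter=100):
    """Entropic OT cost between histograms a and b under cost C."""
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    return np.sum(P * C)

def discounted_wl_fixed_point(MX, MY, D, delta=0.4, reg=0.1,
                              max_iter=50, rtol=5e-3, atol=1e-6):
    """Iterate d <- (1 - delta) * D + delta * OT(MX[i], MY[j]; d)
    until the cost matrix stabilises.  Illustrative sketch only."""
    d = D.copy()
    for _ in range(max_iter):
        d_new = np.empty_like(d)
        for i in range(MX.shape[0]):
            for j in range(MY.shape[0]):
                d_new[i, j] = (1 - delta) * D[i, j] \
                    + delta * sinkhorn_cost(MX[i], MY[j], d, reg)
        if np.allclose(d_new, d, rtol=rtol, atol=atol):
            return d_new             # converged fixed point
        d = d_new
    return d
```

Two identical chains with a zero ground cost sit at the fixed point immediately, so the sketch returns an all-zero cost matrix for them.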
- ot_markov_distances.discounted_wl_k(MX: Tensor, MY: Tensor, l1: Tensor | None = None, l2: Tensor | None = None, *, cost_matrix: Tensor | None = None, delta: Tensor | float = 0.4, k: int, muX: Tensor | None = None, muY: Tensor | None = None, reg: float = 0.1, sinkhorn_parameters: dict = {}, return_differences: bool = False)
computes the discounted WL distance
Computes the WL-delta distance between two Markov transition matrices (represented as torch tensors).
This function does not have the backward pass mentioned in the paper, because that formula is only valid for the case \(k=\infty\)
Batched over first dimension (b)
- Parameters:
MX – (b, n, n) first transition tensor
MY – (b, m, m) second transition tensor
l1 – (b, n,) label values for the first space
l2 – (b, m,) label values for the second space
cost_matrix – (b, n, m) allows specifying the cost matrix instead
k – number of steps (k parameter for the WL distance)
muX – distribution for MX (if omitted, the stationary distribution will be used)
muY – distribution for MY (if omitted, the stationary distribution will be used)
reg – regularization parameter for sinkhorn