Notes on Gradient computation

We compute the gradient using the technique described in Brugère et al. [BWW23].

But we push the formulae a little further to simplify the computation:

With the notation of the paper, denote

\[\begin{split}\begin{gather} K = I_{nm} - (1 - \delta)P\\ \Delta := \left(\Delta_{ij}^{kl}\right){}_{1 \leq i \leq n, 1 \leq j \leq m }^{ 1 \leq k \leq n, 1 \leq l \leq m}, \, \Gamma := \left(\Gamma_{ij}^{kk'}\right){}_{1 \leq i \leq n, 1 \leq j \leq m }^{ 1 \leq k \leq n, 1 \leq k' \leq n}, \Theta := \left(\Theta_{ij}^{ll'}\right){}_{1 \leq i \leq n, 1 \leq j \leq m }^{ 1 \leq l \leq m, 1 \leq l' \leq m}, \\ \text{where}~~~ \Delta_{ij}^{kl} := \frac{\partial C^{\epsilon,\delta, (\infty)}_{ij}}{\partial C_{kl}}, \, \Gamma_{ij}^{kk'} := \frac{\partial C^{\epsilon,\delta, (\infty)}_{ij}}{\partial m^{\setX}_{kk'}}, \, \Theta_{ij}^{ll'} := \frac{\partial C^{\epsilon,\delta, (\infty)}_{ij}}{\partial m^{\setY}_{ll'}}.\\ \end{gather}\end{split}\]

We also denote

\[\begin{split}\begin{gather} G^O_{ij} := \frac{\partial \text{loss}}{\partial C^{\epsilon,\delta, (\infty)}_{ij}}\\ G^C_{kl} := \frac{\partial \text{loss}}{\partial C_{kl}}, \\ G^X_{kk'} := \frac{\partial \text{loss}}{\partial m^{\setX}_{kk'}}\\ G^Y_{ll'} := \frac{\partial \text{loss}}{\partial m^{\setY}_{ll'}}\\ \end{gather}\end{split}\]

Then (in matrix notation, i.e. with dimensions/codims flattened together):

\[\begin{split}\begin{gather} G^C = \Delta^T G^O\\ G^X = \Gamma^T G^O\\ G^Y = \Theta^T G^O\\ \end{gather}\end{split}\]
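To make the flattening concrete, here is a minimal NumPy sketch (not the actual implementation, which uses PyTorch) checking that the index form \(G^C_{kl} = \sum_{ij} \Delta_{ij}^{kl} G^O_{ij}\) agrees with the flattened matrix form \(G^C = \Delta^T G^O\). The sizes and the random values of `Delta` and `G_O` are hypothetical placeholders, chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 2, 3

# Hypothetical placeholder values: Delta[i, j, k, l] stands for Delta_{ij}^{kl},
# and G_O[i, j] for the upstream gradient G^O_{ij}.
Delta = rng.standard_normal((n, m, n, m))
G_O = rng.standard_normal((n, m))

# Index form: G^C_{kl} = sum_{ij} Delta_{ij}^{kl} G^O_{ij}
G_C_index = np.einsum("ijkl,ij->kl", Delta, G_O)

# Flattened form: (i, j) -> rows, (k, l) -> columns, then G^C = Delta^T G^O
Delta_flat = Delta.reshape(n * m, n * m)
G_C_flat = (Delta_flat.T @ G_O.reshape(n * m)).reshape(n, m)
```

Both arrays hold the same values; the flattened form is the one used in the rest of these notes.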

Expanding, we get

\[\begin{split}\begin{gather} G^C = \delta (K^T)^{-1} G^O\\ G^X = (1-\delta) F^T (K^T)^{-1} G^O\\ G^Y = (1-\delta) G^T (K^T)^{-1} G^O\\ \end{gather}\end{split}\]

Thus we save some compute by applying the formulae above and computing \((K^T)^{-1} G^O\) only once.

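Note also that \((K^T)^{-1} G^O\) can be computed by solving the linear system \(K^T L = G^O\) (e.g. with torch.linalg.solve) instead of explicitly inverting \(K^T\) (e.g. with torch.linalg.inv), which is both more efficient and more numerically stable.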

We call this matrix K_Tm1_grad in the implementation.
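The following is a minimal NumPy sketch of this idea, not the actual implementation (which uses PyTorch). It assumes, for illustration only, that \(P\) is a row-stochastic \(nm \times nm\) matrix, and it uses random placeholders `F_mat` and `G_mat` for the flattened \(F\) and \(G\); `np.linalg.solve` plays the role of the linear solve.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, delta = 3, 4, 0.3

# Assumption for illustration: P is row-stochastic, so K below is invertible.
P = rng.random((n * m, n * m))
P /= P.sum(axis=1, keepdims=True)
K = np.eye(n * m) - (1.0 - delta) * P

# Hypothetical placeholders for the upstream gradient and the flattened F, G.
G_O = rng.standard_normal(n * m)
F_mat = rng.standard_normal((n * m, n * n))
G_mat = rng.standard_normal((n * m, m * m))

# Solve K^T L = G^O once, without forming (K^T)^{-1}.
K_Tm1_grad = np.linalg.solve(K.T, G_O)

# Reuse the same solution for all three gradients.
G_C = delta * K_Tm1_grad
G_X = (1.0 - delta) * F_mat.T @ K_Tm1_grad
G_Y = (1.0 - delta) * G_mat.T @ K_Tm1_grad

# Reference computation with an explicit inverse, for comparison only.
G_C_ref = delta * np.linalg.inv(K.T) @ G_O
```

The single call to the solver replaces three products with an explicit inverse, which is where the saving comes from.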

Note also that \(F\) and \(G\) do not need to be explicitly computed: denoting \(L := (K^T)^{-1} G^O\), we have

\[\begin{split}G^X_{kk'} &= (1-\delta) (F^T L)_{kk'} \\ &= (1-\delta) \sum_{i,j} (F^{T})_{kk'}^{ij} L_{ij} \\ &= (1-\delta) \sum_{i,j} F^{kk'}_{ij} L_{ij} \\ &= (1-\delta) \sum_{i,j} f^{k'}_{ij}\1_{i=k} L_{ij} \\ &= (1-\delta) \sum_{j} f^{k'}_{kj} L_{kj} \\\end{split}\]

And similarly for \(G^Y\).
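As a sanity check on this shortcut, here is a minimal NumPy sketch (again not the actual PyTorch implementation) comparing the contraction over \(j\) with the result obtained by building \(F\) explicitly, using \(G^X = (1-\delta) F^T L\). The array `f[i, j, k']`, standing for \(f^{k'}_{ij}\), and `L[i, j]` are filled with hypothetical random values.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, delta = 3, 4, 0.3

# Hypothetical placeholder values: f[i, j, q] stands for f^{k'}_{ij} with q = k',
# and L[i, j] for the entries of (K^T)^{-1} G^O reshaped to (n, m).
f = rng.standard_normal((n, m, n))
L = rng.standard_normal((n, m))

# Implicit form: G^X_{kk'} = (1 - delta) * sum_j f^{k'}_{kj} L_{kj}
G_X = (1.0 - delta) * np.einsum("kjq,kj->kq", f, L)

# Explicit reference: F[i, j, k, q] = f[i, j, q] * 1_{i = k}
F = np.einsum("ijq,ik->ijkq", f, np.eye(n))
F_flat = F.reshape(n * m, n * n)
G_X_ref = (1.0 - delta) * (F_flat.T @ L.reshape(n * m)).reshape(n, n)
```

The implicit form never allocates the \(nm \times n^2\) matrix \(F\), most of whose entries are zero because of the indicator.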