Hi, and thanks for maintaining this great library! I'm currently using POT (with the PyTorch backend) to compute OT-based losses, and I would like to confirm how gradients are handled.

Minimal reproducible example:

```python
import torch
import ot

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

n = 10
a = torch.ones(n, device=device) / n  # uniform source weights
b = torch.ones(n, device=device) / n  # uniform target weights
M = torch.randn(n, n, device=device, requires_grad=True)

loss_emd2 = ot.emd2(a, b, M)
loss_emd2.backward()
M.grad.zero_()
```

Output:

```
Sinkhorn loss: -1.5583062171936035
```

Please clarify whether:

- `ot.emd2` is intentionally non-differentiable (since it solves a linear program);
Replies: 1 comment 1 reply
Yes, we have implemented proper backward gradient propagation w.r.t. the exact OT objective, but keep in mind that in both cases these are only sub-gradients, since the objective is indeed not differentiable (ReLU is not either). The documentation has examples showing that you can optimize through our loss.
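
Below is a minimal sketch (not from the thread) of what "optimizing through the loss" can look like: taking sub-gradient steps through `ot.emd2` with the PyTorch backend. The point-cloud setup, uniform weights, squared-Euclidean cost, learning rate, and iteration count are illustrative assumptions, not part of the original discussion.

```python
# Illustrative sketch (assumed setup): optimize source points by
# back-propagating through the exact OT loss ot.emd2.
import torch
import ot

torch.manual_seed(0)
n = 10
x = torch.randn(n, 2, requires_grad=True)  # source points, updated below
y = torch.randn(n, 2)                      # fixed target points
a = torch.ones(n) / n                      # uniform source weights
b = torch.ones(n) / n                      # uniform target weights

optimizer = torch.optim.SGD([x], lr=0.1)
for _ in range(100):
    optimizer.zero_grad()
    M = ot.dist(x, y)        # squared-Euclidean cost matrix, differentiable in x
    loss = ot.emd2(a, b, M)  # exact OT loss; backward yields a sub-gradient
    loss.backward()
    optimizer.step()

print(float(loss))  # the loss should have decreased over the iterations
```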