Skip to content

Commit 03ca4ef

Browse files
authored
[MRG] make alpha parameter in FGW differentiable (#463)
* make alpha differentiable * update release file * debug tensorflow to_numpy
1 parent 25d72db commit 03ca4ef

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#### New features
66

7+
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
8+
79
#### Closed issues
810

911
- Fix circleci-redirector action and codecov (PR #460)

ot/backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1694,10 +1694,12 @@ def backward(ctx, grad_output):
16941694
self.ValFunction = ValFunction
16951695

16961696
def _to_numpy(self, a):
1697+
if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
1698+
return np.array(a)
16971699
return a.cpu().detach().numpy()
16981700

16991701
def _from_numpy(self, a, type_as=None):
1700-
if isinstance(a, float):
1702+
if isinstance(a, float) or isinstance(a, int):
17011703
a = np.array(a)
17021704
if type_as is None:
17031705
return torch.from_numpy(a)
@@ -2501,6 +2503,8 @@ def __init__(self):
25012503
)
25022504

25032505
def _to_numpy(self, a):
2506+
if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
2507+
return np.array(a)
25042508
return a.numpy()
25052509

25062510
def _from_numpy(self, a, type_as=None):

ot/gromov/_gw.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
370370
Information and Inference: A Journal of the IMA, 8(4), 757-787.
371371
"""
372372
p, q = list_to_array(p, q)
373-
p0, q0, C10, C20, M0 = p, q, C1, C2, M
373+
p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
374374
if G0 is None:
375375
nx = get_backend(p0, q0, C10, C20, M0)
376376
else:
@@ -382,6 +382,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
382382
C1 = nx.to_numpy(C10)
383383
C2 = nx.to_numpy(C20)
384384
M = nx.to_numpy(M0)
385+
alpha = nx.to_numpy(alpha0)
385386

386387
if symmetric is None:
387388
symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
@@ -535,10 +536,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
535536
if loss_fun == 'square_loss':
536537
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
537538
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
538-
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
539-
(log_fgw['u'] - nx.mean(log_fgw['u']),
540-
log_fgw['v'] - nx.mean(log_fgw['v']),
541-
alpha * gC1, alpha * gC2, (1 - alpha) * T))
539+
if isinstance(alpha, int) or isinstance(alpha, float):
540+
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
541+
(log_fgw['u'] - nx.mean(log_fgw['u']),
542+
log_fgw['v'] - nx.mean(log_fgw['v']),
543+
alpha * gC1, alpha * gC2, (1 - alpha) * T))
544+
else:
545+
lin_term = nx.sum(T * M)
546+
gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha
547+
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
548+
(log_fgw['u'] - nx.mean(log_fgw['u']),
549+
log_fgw['v'] - nx.mean(log_fgw['v']),
550+
alpha * gC1, alpha * gC2, (1 - alpha) * T,
551+
gw_term - lin_term))
542552

543553
if log:
544554
return fgw_dist, log_fgw

test/test_gromov.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def test_gromov2_gradients():
209209
if torch.cuda.is_available():
210210
devices.append(torch.device("cuda"))
211211
for device in devices:
212+
213+
# classical gradients
212214
p1 = torch.tensor(p, requires_grad=True, device=device)
213215
q1 = torch.tensor(q, requires_grad=True, device=device)
214216
C11 = torch.tensor(C1, requires_grad=True, device=device)
@@ -226,6 +228,12 @@ def test_gromov2_gradients():
226228
assert C12.shape == C12.grad.shape
227229

228230
# Test with armijo line-search
231+
# classical gradients
232+
p1 = torch.tensor(p, requires_grad=True, device=device)
233+
q1 = torch.tensor(q, requires_grad=True, device=device)
234+
C11 = torch.tensor(C1, requires_grad=True, device=device)
235+
C12 = torch.tensor(C2, requires_grad=True, device=device)
236+
229237
q1.grad = None
230238
p1.grad = None
231239
C11.grad = None
@@ -830,6 +838,25 @@ def test_fgw2_gradients():
830838
assert C12.shape == C12.grad.shape
831839
assert M1.shape == M1.grad.shape
832840

841+
# full gradients with alpha
842+
p1 = torch.tensor(p, requires_grad=True, device=device)
843+
q1 = torch.tensor(q, requires_grad=True, device=device)
844+
C11 = torch.tensor(C1, requires_grad=True, device=device)
845+
C12 = torch.tensor(C2, requires_grad=True, device=device)
846+
M1 = torch.tensor(M, requires_grad=True, device=device)
847+
alpha = torch.tensor(0.5, requires_grad=True, device=device)
848+
849+
val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha)
850+
851+
val.backward()
852+
853+
assert val.device == p1.device
854+
assert q1.shape == q1.grad.shape
855+
assert p1.shape == p1.grad.shape
856+
assert C11.shape == C11.grad.shape
857+
assert C12.shape == C12.grad.shape
858+
assert alpha.shape == alpha.grad.shape
859+
833860

834861
def test_fgw_helper_backend(nx):
835862
n_samples = 20 # nb samples

0 commit comments

Comments
 (0)