Skip to content

Commit a52a97c

Browse files
complete coverage
1 parent 432062c commit a52a97c

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

test/test_gromov.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,11 +1441,20 @@ def test_fgw_barycenter(nx):
14411441
p = ot.unif(n_samples)
14421442

14431443
ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
1444-
1445-
Xb, Cb = ot.gromov.fgw_barycenters(
1446-
n_samples, [ysb, ytb], [C1b, C2b], None, [.5, .5], 0.5, fixed_structure=False,
1447-
fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
1444+
lambdas = [.5, .5]
1445+
Csb = [C1b, C2b]
1446+
Ysb = [ysb, ytb]
1447+
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1448+
n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False,
1449+
fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1450+
random_state=12345, log=True
14481451
)
1452+
# test correspondance with utils function
1453+
recovered_Cb = ot.gromov.update_square_loss(pb, lambdas, logb['Ts_iter'][-1], Csb)
1454+
recovered_Xb = ot.gromov.update_feature_matrix(lambdas, [y.T for y in Ysb], logb['Ts_iter'][-1], pb).T
1455+
1456+
np.testing.assert_allclose(Cb, recovered_Cb)
1457+
np.testing.assert_allclose(Xb, recovered_Xb)
14491458

14501459
xalea = rng.randn(n_samples, 2)
14511460
init_C = ot.dist(xalea, xalea)
@@ -1454,7 +1463,7 @@ def test_fgw_barycenter(nx):
14541463

14551464
with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_structure=True`and `init_C=None`
14561465
Xb, Cb = ot.gromov.fgw_barycenters(
1457-
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1466+
n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None,
14581467
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
14591468
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
14601469
)
@@ -1490,14 +1499,19 @@ def test_fgw_barycenter(nx):
14901499
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
14911500

14921501
# add test with 'kl_loss'
1493-
X, C = ot.gromov.fgw_barycenters(
1502+
X, C, log = ot.gromov.fgw_barycenters(
14941503
n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5,
14951504
fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss',
1496-
max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345
1505+
max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True,
1506+
random_state=12345, log=True
14971507
)
14981508
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
14991509
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
15001510

1511+
# test correspondance with utils function
1512+
recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['Ts_iter'][-1], [C1, C2])
1513+
np.testing.assert_allclose(C, recovered_C)
1514+
15011515

15021516
def test_gromov_wasserstein_linear_unmixing(nx):
15031517
n = 4

0 commit comments

Comments
 (0)