@@ -1441,11 +1441,20 @@ def test_fgw_barycenter(nx):
1441
1441
p = ot .unif (n_samples )
1442
1442
1443
1443
ysb , ytb , C1b , C2b , p1b , p2b , pb = nx .from_numpy (ys , yt , C1 , C2 , p1 , p2 , p )
1444
-
1445
- Xb , Cb = ot .gromov .fgw_barycenters (
1446
- n_samples , [ysb , ytb ], [C1b , C2b ], None , [.5 , .5 ], 0.5 , fixed_structure = False ,
1447
- fixed_features = False , p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 , random_state = 12345
1444
+ lambdas = [.5 , .5 ]
1445
+ Csb = [C1b , C2b ]
1446
+ Ysb = [ysb , ytb ]
1447
+ Xb , Cb , logb = ot .gromov .fgw_barycenters (
1448
+ n_samples , Ysb , Csb , None , lambdas , 0.5 , fixed_structure = False ,
1449
+ fixed_features = False , p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1450
+ random_state = 12345 , log = True
1448
1451
)
1452
+ # test correspondance with utils function
1453
+ recovered_Cb = ot .gromov .update_square_loss (pb , lambdas , logb ['Ts_iter' ][- 1 ], Csb )
1454
+ recovered_Xb = ot .gromov .update_feature_matrix (lambdas , [y .T for y in Ysb ], logb ['Ts_iter' ][- 1 ], pb ).T
1455
+
1456
+ np .testing .assert_allclose (Cb , recovered_Cb )
1457
+ np .testing .assert_allclose (Xb , recovered_Xb )
1449
1458
1450
1459
xalea = rng .randn (n_samples , 2 )
1451
1460
init_C = ot .dist (xalea , xalea )
@@ -1454,7 +1463,7 @@ def test_fgw_barycenter(nx):
1454
1463
1455
1464
with pytest .raises (ot .utils .UndefinedParameter ): # to raise warning when `fixed_structure=True`and `init_C=None`
1456
1465
Xb , Cb = ot .gromov .fgw_barycenters (
1457
- n_samples , [ ysb , ytb ], [ C1b , C2b ] , ps = [p1b , p2b ], lambdas = None ,
1466
+ n_samples , Ysb , Csb , ps = [p1b , p2b ], lambdas = None ,
1458
1467
alpha = 0.5 , fixed_structure = True , init_C = None , fixed_features = False ,
1459
1468
p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3
1460
1469
)
@@ -1490,14 +1499,19 @@ def test_fgw_barycenter(nx):
1490
1499
np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
1491
1500
1492
1501
# add test with 'kl_loss'
1493
- X , C = ot .gromov .fgw_barycenters (
1502
+ X , C , log = ot .gromov .fgw_barycenters (
1494
1503
n_samples , [ys , yt ], [C1 , C2 ], [p1 , p2 ], [.5 , .5 ], 0.5 ,
1495
1504
fixed_structure = False , fixed_features = False , p = p , loss_fun = 'kl_loss' ,
1496
- max_iter = 100 , tol = 1e-3 , init_C = C , init_X = X , warmstartT = True , random_state = 12345
1505
+ max_iter = 100 , tol = 1e-3 , init_C = C , init_X = X , warmstartT = True ,
1506
+ random_state = 12345 , log = True
1497
1507
)
1498
1508
np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
1499
1509
np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
1500
1510
1511
+ # test correspondance with utils function
1512
+ recovered_C = ot .gromov .update_kl_loss (p , lambdas , log ['Ts_iter' ][- 1 ], [C1 , C2 ])
1513
+ np .testing .assert_allclose (C , recovered_C )
1514
+
1501
1515
1502
1516
def test_gromov_wasserstein_linear_unmixing (nx ):
1503
1517
n = 4
0 commit comments