Skip to content

Commit b7179ad

Browse files
corrections Remi
1 parent 6501a7b commit b7179ad

File tree

2 files changed

+2
-11
lines changed

2 files changed

+2
-11
lines changed

ot/gromov/_utils.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,6 @@ def update_square_loss(p, lambdas, T, Cs, nx=None):
294294
295295
"""
296296
if nx is None:
297-
T = list_to_array(*T)
298-
Cs = list_to_array(*Cs)
299-
p = list_to_array(p)
300297
nx = get_backend(p, *T, *Cs)
301298

302299
# Correct order mistake in Equation 14 in [12]
@@ -353,9 +350,6 @@ def update_kl_loss(p, lambdas, T, Cs, nx=None):
353350
354351
"""
355352
if nx is None:
356-
Cs = list_to_array(*Cs)
357-
T = list_to_array(*T)
358-
p = list_to_array(p)
359353
nx = get_backend(p, *T, *Cs)
360354

361355
# Correct order mistake in Equation 15 in [12]
@@ -403,9 +397,6 @@ def update_feature_matrix(lambdas, Ys, Ts, p, nx=None):
403397
International Conference on Machine Learning (ICML). 2019.
404398
"""
405399
if nx is None:
406-
p = list_to_array(p)
407-
Ts = list_to_array(*Ts)
408-
Ys = list_to_array(*Ys)
409400
nx = get_backend(*Ys, *Ts, p)
410401

411402
p = 1. / p

test/test_gromov.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,7 @@ def test_fgw_barycenter(nx):
14611461
init_C /= init_C.max()
14621462
init_Cb = nx.from_numpy(init_C)
14631463

1464-
with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_structure=True`and `init_C=None`
1464+
with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None`
14651465
Xb, Cb = ot.gromov.fgw_barycenters(
14661466
n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None,
14671467
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
@@ -1480,7 +1480,7 @@ def test_fgw_barycenter(nx):
14801480
init_X = rng.randn(n_samples, ys.shape[1])
14811481
init_Xb = nx.from_numpy(init_X)
14821482

1483-
with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_features=True`and `init_X=None`
1483+
with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None`
14841484
Xb, Cb, logb = ot.gromov.fgw_barycenters(
14851485
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
14861486
fixed_structure=False, fixed_features=True, init_X=None,

0 commit comments

Comments
 (0)