Skip to content

Commit ffdd1cf

Browse files
kachayevrflamarycedricvincentcuaz
authored
Correctly handle cost = 0. (#505)
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com> Co-authored-by: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
1 parent 7856700 commit ffdd1cf

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

ot/optim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def cost(G):
297297
loop = 0
298298

299299
abs_delta_cost_G = abs(cost_G - old_cost_G)
300-
relative_delta_cost_G = abs_delta_cost_G / abs(cost_G)
300+
relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) if cost_G != 0. else np.nan
301301
if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
302302
loop = 0
303303

test/test_gromov.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
# License: MIT License
99

1010
import numpy as np
11+
import pytest
12+
import warnings
13+
1114
import ot
1215
from ot.backend import NumpyBackend
1316
from ot.backend import torch, tf
14-
import pytest
1517

1618

1719
def test_gromov(nx):
@@ -146,8 +148,10 @@ def test_gromov_dtype_device(nx):
146148

147149
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
148150

149-
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
150-
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
151+
with warnings.catch_warnings():
152+
warnings.filterwarnings('error')
153+
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
154+
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
151155

152156
nx.assert_same_dtype_device(C1b, Gb)
153157
nx.assert_same_dtype_device(C1b, gw_valb)

0 commit comments

Comments
 (0)