diff --git a/ot/optim.py b/ot/optim.py index c4198a00a..8700f75d1 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -297,7 +297,7 @@ def cost(G): loop = 0 abs_delta_cost_G = abs(cost_G - old_cost_G) - relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) + relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) if cost_G != 0. else np.nan if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2: loop = 0 diff --git a/test/test_gromov.py b/test/test_gromov.py index 559d07855..846e69f2b 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -8,10 +8,12 @@ # License: MIT License import numpy as np +import pytest +import warnings + import ot from ot.backend import NumpyBackend from ot.backend import torch, tf -import pytest def test_gromov(nx): @@ -146,8 +148,10 @@ def test_gromov_dtype_device(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) + with warnings.catch_warnings(): + warnings.filterwarnings('error') + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb)