From f3324a64e0714e8afd8549d22e90acacf57d2e1f Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Thu, 17 Aug 2023 21:50:38 +0200 Subject: [PATCH 1/4] Explicitly check that SinkhornL1l2Transport.fit works with no warnings --- test/test_da.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_da.py b/test/test_da.py index c95d48850..95d3ee4fd 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -7,6 +7,7 @@ import numpy as np from numpy.testing import assert_allclose, assert_equal import pytest +import warnings import ot from ot.datasets import make_data_classif @@ -166,7 +167,9 @@ def test_sinkhorn_l1l2_transport_class(nx): otda = ot.da.SinkhornL1l2Transport() # test its computed - otda.fit(Xs=Xs, ys=ys, Xt=Xt) + with warnings.catch_warnings(): + warnings.simplefilter("error") + otda.fit(Xs=Xs, ys=ys, Xt=Xt) assert hasattr(otda, "cost_") assert hasattr(otda, "coupling_") assert hasattr(otda, "log_") From 3478d0ab9d2aa6e7c9a9555176649e9619a69273 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Thu, 17 Aug 2023 22:05:05 +0200 Subject: [PATCH 2/4] Default value for alpha_min is set to 0 --- ot/optim.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index 9e65e8141..c4198a00a 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -27,7 +27,7 @@ def line_search_armijo( f, xk, pk, gfk, old_fval, args=(), c1=1e-4, - alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs + alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs ): r""" Armijo linesearch function that works with matrices @@ -56,7 +56,7 @@ def line_search_armijo( :math:`c_1` const in armijo rule (>0) alpha0 : float, optional initial step (>0) - alpha_min : float, optional + alpha_min : float, default=0. minimum value for alpha alpha_max : float, optional maximum value for alpha @@ -89,6 +89,14 @@ def line_search_armijo( fc = [0] def phi(alpha1): + # it's necessary to check boundary condition here for the coefficient + # as the callback could be evaluated for negative value of alpha by + # `scalar_search_armijo` function here: + # + # https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686 + # + # see more details https://github.com/PythonOT/POT/issues/502 + alpha1 = np.clip(alpha1, alpha_min, alpha_max) # The callable function operates on nx backend fc[0] += 1 alpha10 = nx.from_numpy(alpha1) @@ -109,13 +117,12 @@ def phi(alpha1): derphi0 = np.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( - phi, phi0, derphi0, c1=c1, alpha0=alpha0) + phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min) if alpha is None: return 0., fc[0], nx.from_numpy(phi0, type_as=xk0) else: - if alpha_min is not None or alpha_max is not None: - alpha = np.clip(alpha, alpha_min, alpha_max) + alpha = np.clip(alpha, alpha_min, alpha_max) return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0) From aad4169f1c359ed9f782660fa1d9f2e9bb19548e Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Thu, 17 Aug 2023 22:40:04 +0200 Subject: [PATCH 3/4] Fix random_state for SinkhornL1l2Transport test --- test/test_da.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index 95d3ee4fd..46d1106d0 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -159,12 +159,12 @@ def test_sinkhorn_l1l2_transport_class(nx): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif('3gauss', ns, random_state=42) + Xt, yt = make_data_classif('3gauss2', nt, random_state=43) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) - otda = ot.da.SinkhornL1l2Transport() + otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500) # test its computed with warnings.catch_warnings(): From 905bd4db51742a3bd0001ff3f58ad17ed2216472 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Wed, 20 Sep 2023 16:34:25 +0200 Subject: [PATCH 4/4] Mention changes in releases --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index 1b98d51bf..4eeea9c66 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,7 @@ + Tweaked `get_backend` to ignore `None` inputs (PR # 525) #### Closed issues +- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) ## 0.9.1