Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
+ Tweaked `get_backend` to ignore `None` inputs (PR # 525)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)


## 0.9.1
Expand Down
17 changes: 12 additions & 5 deletions ot/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs
):
r"""
Armijo linesearch function that works with matrices
Expand Down Expand Up @@ -56,7 +56,7 @@ def line_search_armijo(
:math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
alpha_min : float, optional
alpha_min : float, default=0.
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
Expand Down Expand Up @@ -89,6 +89,14 @@ def line_search_armijo(
fc = [0]

def phi(alpha1):
# it's necessary to check boundary condition here for the coefficient
# as the callback could be evaluated for negative value of alpha by
# `scalar_search_armijo` function here:
#
# https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
#
# see more details https://github.com/PythonOT/POT/issues/502
alpha1 = np.clip(alpha1, alpha_min, alpha_max)
# The callable function operates on nx backend
fc[0] += 1
alpha10 = nx.from_numpy(alpha1)
Expand All @@ -109,13 +117,12 @@ def phi(alpha1):

derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min)

if alpha is None:
return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
alpha = np.clip(alpha, alpha_min, alpha_max)
return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)


Expand Down
11 changes: 7 additions & 4 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
import pytest
import warnings

import ot
from ot.datasets import make_data_classif
Expand Down Expand Up @@ -158,15 +159,17 @@ def test_sinkhorn_l1l2_transport_class(nx):
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Xs, ys = make_data_classif('3gauss', ns, random_state=42)
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)

Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)

otda = ot.da.SinkhornL1l2Transport()
otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)

# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
with warnings.catch_warnings():
warnings.simplefilter("error")
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "cost_")
assert hasattr(otda, "coupling_")
assert hasattr(otda, "log_")
Expand Down