Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ot/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def phi(alpha1):
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)

return min(1, alpha), fc[0], phi1
# scalar_search_armijo can return alpha > 1
if alpha is not None:
alpha = min(1, alpha)
return alpha, fc[0], phi1


def solve_linesearch(cost, G, deltaG, Mi, f_val,
Expand Down
10 changes: 10 additions & 0 deletions test/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,13 @@ def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)


def test_line_search_armijo():
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
# Should not throw an exception and return None for alpha
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
assert alpha is None