Skip to content

Commit

Permalink
Enhanced zoom linesearch with earlier stopping, actionable messages, …
Browse files Browse the repository at this point in the history
…avoid getting stuck and rather warn. Let max_stepsize be only handled in initial guess for linesearches, unless necessary as in lbfgsb.
  • Loading branch information
vroulet committed Sep 23, 2023
1 parent d82e448 commit 86b0cf6
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 107 deletions.
6 changes: 3 additions & 3 deletions jaxopt/_src/backtracking_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
c2: constant strictly less than 1 used by the (strong) Wolfe condition.
decrease_factor: factor by which to decrease the stepsize during line search
(default: 0.8).
max_stepsize: upper bound on stepsize.
max_stepsize: upper bound on stepsize (unused)
maxiter: maximum number of line search iterations.
tol: tolerance of the stopping criterion.
Expand All @@ -87,6 +87,8 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
c1: float = 1e-4
c2: float = 0.9
decrease_factor: float = 0.8
# TODO(vroulet): remove max_stepsize argument as it is not used here.
# It's handled by the initial guess taken by the linesearch
max_stepsize: float = 1.0

verbose: int = 0
Expand Down Expand Up @@ -167,8 +169,6 @@ def update(
Returns:
(params, state)
"""
# Ensure that stepsize does not exceed upper bound.
stepsize = jnp.minimum(self.max_stepsize, stepsize)
num_fun_eval = state.num_fun_eval
num_grad_eval = state.num_grad_eval

Expand Down
12 changes: 9 additions & 3 deletions jaxopt/_src/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,14 @@ class BFGS(base.IterativeSolver):
backtracking line search (default: 0.8).
increase_factor: factor by which to increase the stepsize during line search
(default: 1.5).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
max_stepsize: upper bound on stepsize guess at start of each linesearch run
for linesearch_init='increase'.
Note that the linesearch is allowed to take a larger stepsize to satisfy
curvature conditions.
min_stepsize: lower bound on stepsize guess at start of each linesearch run
for linesearch_init='increase'.
Note that the linesearch is allowed to take a smaller stepsize to satisfy
decrease conditions.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
Expand Down Expand Up @@ -286,7 +292,7 @@ def __post_init__(self):
value_and_grad=True,
has_aux=True,
maxlsiter=self.maxls,
max_stepsize=self.max_stepsize,
max_stepsize=None,
jit=self.jit,
unroll=unroll,
verbose=self.verbose,
Expand Down
5 changes: 3 additions & 2 deletions jaxopt/_src/hager_zhang_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
c1: constant used by the Wolfe and Approximate Wolfe condition.
c2: constant strictly less than 1 used by the Wolfe and Approximate Wolfe
condition.
max_stepsize: upper bound on stepsize.
max_stepsize: upper bound on stepsize (unused).
maxiter: maximum number of line search iterations.
tol: tolerance of the stopping criterion.
Expand All @@ -103,7 +103,8 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
expansion_factor: float = 5.0
shrinkage_factor: float = 0.66
approximate_wolfe_threshold = 1e-6
max_stepsize: float = 1.0
# TODO(vroulet): remove max_stepsize argument as it is not used
max_stepsize: float = 1.0

verbose: int = 0
jit: base.AutoOrBoolean = "auto"
Expand Down
12 changes: 9 additions & 3 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,14 @@ class LBFGS(base.IterativeSolver):
line search when using backtracking linesearch (default: 0.8).
increase_factor: factor by which to increase the stepsize during line search
(default: 1.5).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
max_stepsize: upper bound on stepsize guess at start of each linesearch run
for linesearch_init='increase'.
Note that the linesearch is allowed to take a larger stepsize to satisfy
curvature conditions.
min_stepsize: lower bound on stepsize guess at start of each linesearch run
for linesearch_init='increase'.
Note that the linesearch is allowed to take a smaller stepsize to satisfy
decrease conditions.
history_size: size of the memory to use.
use_gamma: whether to initialize the inverse Hessian approximation with
gamma * I, where gamma is chosen following equation (7.20) of 'Numerical
Expand Down Expand Up @@ -438,7 +444,7 @@ def __post_init__(self):
value_and_grad=True,
has_aux=True,
maxlsiter=self.maxls,
max_stepsize=self.max_stepsize,
max_stepsize=None,
jit=self.jit,
unroll=unroll,
verbose=self.verbose,
Expand Down
12 changes: 8 additions & 4 deletions jaxopt/_src/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,13 @@ class LBFGSB(base.IterativeSolver):
backtracking line search (default: 0.8).
increase_factor: factor by which to increase the stepsize during line search
(default: 1.5).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
history_size: size of the memory to use.
max_stepsize: upper bound on acceptable stepsize. By default update directions
are defined such that the max_stepsize should be 1. to avoid violating
constraints.
min_stepsize: lower bound on stepsize guess at start of each linesearch run
for linesearch_init='increase'.
Note that the linesearch is allowed to take a smaller stepsize to satisfy
decrease conditions. history_size: size of the memory to use.
use_gamma: whether to initialize the Hessian approximation with gamma *
theta, where gamma is chosen following equation (7.20) of 'Numerical
Optimization' [2]. If use_gamma is set to False, theta is used as
Expand All @@ -289,7 +293,7 @@ class LBFGSB(base.IterativeSolver):
linesearch_init: str = "increase"
stop_if_linesearch_fails: bool = False
condition: Any = None # deprecated in v0.8
maxls: int = 20
maxls: int = 30
decrease_factor: Any = None # deprecated in v0.8
increase_factor: float = 1.5
max_stepsize: float = 1.0
Expand Down
8 changes: 6 additions & 2 deletions jaxopt/_src/linesearch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def _setup_linesearch(
value_and_grad=value_and_grad,
has_aux=has_aux,
maxiter=maxlsiter,
max_stepsize=max_stepsize,
# NOTE(vroulet): max_stepsize has no effect in the solver
# max_stepsize=max_stepsize,
jit=jit,
unroll=unroll,
verbose=verbose,
Expand All @@ -63,7 +64,8 @@ def _setup_linesearch(
value_and_grad=value_and_grad,
has_aux=has_aux,
maxiter=maxlsiter,
max_stepsize=max_stepsize,
# NOTE(vroulet): max_stepsize has no effect in the solver
# max_stepsize=max_stepsize,
jit=jit,
unroll=unroll,
verbose=verbose,
Expand Down Expand Up @@ -95,6 +97,8 @@ def _init_stepsize(
# Else, we increase a bit the previous one.
stepsize * increase_factor,
)
# Never guess higher than max_stepsize
init_stepsize = jnp.minimum(init_stepsize, max_stepsize)
else:
raise ValueError(
f"Strategy {strategy} not available/tested. "
Expand Down
16 changes: 10 additions & 6 deletions jaxopt/_src/nonlinear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,14 @@ class NonlinearCG(base.IterativeSolver):
line search when using backtracking linesearch (default: 0.8).
increase_factor: factor by which to increase the stepsize during line search
(default: 1.2).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
max_stepsize: upper bound on stepsize guess at start of each linesearch run
for linesearch_init='increase'.
Note that the linesearch is allowed to take a larger stepsize to satisfy
curvature conditions.
min_stepsize: lower bound on stepsize guess at start of each linesearch run
for linesearch_init='increase'.
Note that the linesearch is allowed to take a smaller stepsize to satisfy
decrease conditions.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
Expand Down Expand Up @@ -123,7 +128,7 @@ class NonlinearCG(base.IterativeSolver):
linesearch: str = "zoom"
linesearch_init: str = "increase"
condition: Any = None # deprecated in v0.8
maxls: int = 15
maxls: int = 30
decrease_factor: Any = None # deprecated in v0.8
increase_factor: float = 1.2
max_stepsize: float = 1.0
Expand Down Expand Up @@ -293,14 +298,13 @@ def __post_init__(self):
self.reference_signature = self.fun

unroll = self._get_unroll_option()

linesearch_solver = _setup_linesearch(
linesearch=self.linesearch,
fun=self._value_and_grad_with_aux,
value_and_grad=True,
has_aux=True,
maxlsiter=self.maxls,
max_stepsize=self.max_stepsize,
max_stepsize=None,
jit=self.jit,
unroll=unroll,
verbose=self.verbose
Expand Down
Loading

0 comments on commit 86b0cf6

Please sign in to comment.