From 86b0cf6cc6507e0047fd142ab455a6df1df7e18a Mon Sep 17 00:00:00 2001
From: Vincent Roulet
Date: Fri, 22 Sep 2023 17:49:35 -0700
Subject: [PATCH] Enhance zoom linesearch: stop earlier, emit actionable
 messages, and warn instead of getting stuck. Handle max_stepsize only in the
 initial guess of linesearches, unless necessary as in LBFGSB.

---
 jaxopt/_src/backtracking_linesearch.py |   6 +-
 jaxopt/_src/bfgs.py                    |  12 +-
 jaxopt/_src/hager_zhang_linesearch.py  |   5 +-
 jaxopt/_src/lbfgs.py                   |  12 +-
 jaxopt/_src/lbfgsb.py                  |  12 +-
 jaxopt/_src/linesearch_util.py         |   8 +-
 jaxopt/_src/nonlinear_cg.py            |  16 +-
 jaxopt/_src/zoom_linesearch.py         | 209 ++++++++++++++++++-------
 tests/lbfgs_test.py                    |   6 +-
 tests/lbfgsb_test.py                   |   2 +
 tests/nonlinear_cg_test.py             |   2 +
 tests/zoom_linesearch_test.py          |  57 ++++---
 12 files changed, 240 insertions(+), 107 deletions(-)

diff --git a/jaxopt/_src/backtracking_linesearch.py b/jaxopt/_src/backtracking_linesearch.py
index 12235ace..f6c66e20 100644
--- a/jaxopt/_src/backtracking_linesearch.py
+++ b/jaxopt/_src/backtracking_linesearch.py
@@ -66,7 +66,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
     c2: constant strictly less than 1 used by the (strong) Wolfe condition.
     decrease_factor: factor by which to decrease the stepsize during line
       search (default: 0.8).
-    max_stepsize: upper bound on stepsize.
+    max_stepsize: upper bound on stepsize (unused).
     maxiter: maximum number of line search iterations.
     tol: tolerance of the stopping criterion.
 
@@ -87,6 +87,8 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
   c1: float = 1e-4
   c2: float = 0.9
   decrease_factor: float = 0.8
+  # TODO(vroulet): remove max_stepsize argument as it is not used here.
+  # It's handled by the initial guess taken by the linesearch.
   max_stepsize: float = 1.0
 
   verbose: int = 0
@@ -167,8 +169,6 @@ def update(
     Returns:
       (params, state)
     """
-    # Ensure that stepsize does not exceed upper bound.
-    stepsize = jnp.minimum(self.max_stepsize, stepsize)
     num_fun_eval = state.num_fun_eval
     num_grad_eval = state.num_grad_eval
 
diff --git a/jaxopt/_src/bfgs.py b/jaxopt/_src/bfgs.py
index 33ca697a..29c8c217 100644
--- a/jaxopt/_src/bfgs.py
+++ b/jaxopt/_src/bfgs.py
@@ -100,8 +100,14 @@ class BFGS(base.IterativeSolver):
       backtracking line search (default: 0.8).
     increase_factor: factor by which to increase the stepsize during line
       search (default: 1.5).
-    max_stepsize: upper bound on stepsize.
-    min_stepsize: lower bound on stepsize.
+    max_stepsize: upper bound on stepsize guess at start of each linesearch run
+      for linesearch_init='increase'.
+      Note that the linesearch is allowed to take a larger stepsize to satisfy
+      curvature conditions.
+    min_stepsize: lower bound on stepsize guess at start of each linesearch run
+      for linesearch_init='increase'.
+      Note that the linesearch is allowed to take a smaller stepsize to satisfy
+      decrease conditions.
     implicit_diff: whether to enable implicit diff or autodiff of unrolled
       iterations.
     implicit_diff_solve: the linear system solver to use.
@@ -286,7 +292,7 @@ def __post_init__(self):
         value_and_grad=True,
         has_aux=True,
         maxlsiter=self.maxls,
-        max_stepsize=self.max_stepsize,
+        max_stepsize=None,
         jit=self.jit,
         unroll=unroll,
         verbose=self.verbose,
diff --git a/jaxopt/_src/hager_zhang_linesearch.py b/jaxopt/_src/hager_zhang_linesearch.py
index 88439c23..f432b112 100644
--- a/jaxopt/_src/hager_zhang_linesearch.py
+++ b/jaxopt/_src/hager_zhang_linesearch.py
@@ -81,7 +81,7 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
     c1: constant used by the Wolfe and Approximate Wolfe condition.
     c2: constant strictly less than 1 used by the Wolfe and Approximate Wolfe
      condition.
-    max_stepsize: upper bound on stepsize.
+    max_stepsize: upper bound on stepsize (unused).
     maxiter: maximum number of line search iterations.
     tol: tolerance of the stopping criterion.
 
@@ -103,7 +103,8 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
   expansion_factor: float = 5.0
   shrinkage_factor: float = 0.66
   approximate_wolfe_threshold = 1e-6
-  max_stepsize: float = 1.0
+  # TODO(vroulet): remove max_stepsize argument as it is not used.
+  max_stepsize: float = 1.0
 
   verbose: int = 0
   jit: base.AutoOrBoolean = "auto"
diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py
index 5253c0e7..58b33d33 100644
--- a/jaxopt/_src/lbfgs.py
+++ b/jaxopt/_src/lbfgs.py
@@ -183,8 +183,14 @@ class LBFGS(base.IterativeSolver):
       line search when using backtracking linesearch (default: 0.8).
     increase_factor: factor by which to increase the stepsize during line
       search (default: 1.5).
-    max_stepsize: upper bound on stepsize.
-    min_stepsize: lower bound on stepsize.
+    max_stepsize: upper bound on stepsize guess at start of each linesearch run
+      for linesearch_init='increase'.
+      Note that the linesearch is allowed to take a larger stepsize to satisfy
+      curvature conditions.
+    min_stepsize: lower bound on stepsize guess at start of each linesearch run
+      for linesearch_init='increase'.
+      Note that the linesearch is allowed to take a smaller stepsize to satisfy
+      decrease conditions.
     history_size: size of the memory to use.
     use_gamma: whether to initialize the inverse Hessian approximation with
       gamma * I, where gamma is chosen following equation (7.20) of 'Numerical
@@ -438,7 +444,7 @@ def __post_init__(self):
         value_and_grad=True,
         has_aux=True,
         maxlsiter=self.maxls,
-        max_stepsize=self.max_stepsize,
+        max_stepsize=None,
         jit=self.jit,
         unroll=unroll,
         verbose=self.verbose,
diff --git a/jaxopt/_src/lbfgsb.py b/jaxopt/_src/lbfgsb.py
index 06855b4f..8b3d92d9 100644
--- a/jaxopt/_src/lbfgsb.py
+++ b/jaxopt/_src/lbfgsb.py
@@ -261,9 +261,14 @@ class LBFGSB(base.IterativeSolver):
       backtracking line search (default: 0.8).
     increase_factor: factor by which to increase the stepsize during line
       search (default: 1.5).
-    max_stepsize: upper bound on stepsize.
-    min_stepsize: lower bound on stepsize.
-    history_size: size of the memory to use.
+    max_stepsize: upper bound on acceptable stepsize. By default, update
+      directions are defined such that max_stepsize should be 1.0 to avoid
+      violating the constraints.
+    min_stepsize: lower bound on stepsize guess at start of each linesearch run
+      for linesearch_init='increase'.
+      Note that the linesearch is allowed to take a smaller stepsize to satisfy
+      decrease conditions.
+    history_size: size of the memory to use.
     use_gamma: whether to initialize the Hessian approximation with
       gamma * theta, where gamma is chosen following equation (7.20) of
      'Numerical Optimization' [2].
If use_gamma is set to False, theta is used as @@ -289,7 +293,7 @@ class LBFGSB(base.IterativeSolver): linesearch_init: str = "increase" stop_if_linesearch_fails: bool = False condition: Any = None # deprecated in v0.8 - maxls: int = 20 + maxls: int = 30 decrease_factor: Any = None # deprecated in v0.8 increase_factor: float = 1.5 max_stepsize: float = 1.0 diff --git a/jaxopt/_src/linesearch_util.py b/jaxopt/_src/linesearch_util.py index 63bd4145..6822d733 100644 --- a/jaxopt/_src/linesearch_util.py +++ b/jaxopt/_src/linesearch_util.py @@ -41,7 +41,8 @@ def _setup_linesearch( value_and_grad=value_and_grad, has_aux=has_aux, maxiter=maxlsiter, - max_stepsize=max_stepsize, + # NOTE(vroulet): max_stepsize has no effect in the solver + # max_stepsize=max_stepsize, jit=jit, unroll=unroll, verbose=verbose, @@ -63,7 +64,8 @@ def _setup_linesearch( value_and_grad=value_and_grad, has_aux=has_aux, maxiter=maxlsiter, - max_stepsize=max_stepsize, + # NOTE(vroulet): max_stepsize has no effect in the solver + # max_stepsize=max_stepsize, jit=jit, unroll=unroll, verbose=verbose, @@ -95,6 +97,8 @@ def _init_stepsize( # Else, we increase a bit the previous one. stepsize * increase_factor, ) + # Never guess higher than max_stepsize + init_stepsize = jnp.minimum(init_stepsize, max_stepsize) else: raise ValueError( f"Strategy {strategy} not available/tested. " diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 53f97d7b..76fb55cc 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -90,9 +90,14 @@ class NonlinearCG(base.IterativeSolver): line search when using backtracking linesearch (default: 0.8). increase_factor: factor by which to increase the stepsize during line search (default: 1.2). - max_stepsize: upper bound on stepsize. - min_stepsize: lower bound on stepsize. - + max_stepsize: upper bound on stepsize guess at start of each linesearch run + for linesearch_init='increase'. + Note that the linesearch is allowed to take a larger stepsize to satisfy + curvature conditions. + min_stepsize: lower bound on stepsize guess at start of each linesearch run + for linesearch_init='increase'. + Note that the linesearch is allowed to take a smaller stepsize to satisfy + decrease conditions. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. 
@@ -123,7 +128,7 @@ class NonlinearCG(base.IterativeSolver):
   linesearch: str = "zoom"
   linesearch_init: str = "increase"
   condition: Any = None  # deprecated in v0.8
-  maxls: int = 15
+  maxls: int = 30
   decrease_factor: Any = None  # deprecated in v0.8
   increase_factor: float = 1.2
   max_stepsize: float = 1.0
@@ -293,14 +298,13 @@ def __post_init__(self):
       self.reference_signature = self.fun
     unroll = self._get_unroll_option()
-
     linesearch_solver = _setup_linesearch(
         linesearch=self.linesearch,
         fun=self._value_and_grad_with_aux,
         value_and_grad=True,
         has_aux=True,
         maxlsiter=self.maxls,
-        max_stepsize=self.max_stepsize,
+        max_stepsize=None,
         jit=self.jit,
         unroll=unroll,
         verbose=self.verbose
diff --git a/jaxopt/_src/zoom_linesearch.py b/jaxopt/_src/zoom_linesearch.py
index 4972822f..23c00848 100644
--- a/jaxopt/_src/zoom_linesearch.py
+++ b/jaxopt/_src/zoom_linesearch.py
@@ -33,29 +33,29 @@
 from jaxopt.tree_util import tree_scalar_mul
 from jaxopt.tree_util import tree_vdot_real
 from jaxopt.tree_util import tree_conj
+from jaxopt.tree_util import tree_l2_norm
 
 # pylint: disable=g-bare-generic
 # pylint: disable=invalid-name
 
 # Flags to print errors, used in tests
 WARNING_PREAMBLE = \
-    "ZoomLineSearchWarning: "
-FLAG_NOT_A_DESCENT_DIRECTION = WARNING_PREAMBLE + \
-    "Provided linesearch direction is not a descent direction. " + \
-    "The linesearch will probably fail."
+    "WARNING: jaxopt.ZoomLineSearch: "
 FLAG_NAN_INF_VALUES = WARNING_PREAMBLE + \
     "NaN or Inf values encountered in function values."
 FLAG_INTERVAL_NOT_FOUND = WARNING_PREAMBLE + \
-    "No interval satisfying curvature condition. " + \
-    "Try increasing maximal stepsize."
+    "No interval satisfying curvature condition. " + \
+    "Consider increasing the maximal possible stepsize of the linesearch."
 FLAG_INTERVAL_TOO_SMALL = WARNING_PREAMBLE + \
-    "Length of searched interval has been reduced below machine precision."
+    "Length of searched interval has been reduced below threshold."
 FLAG_CURVATURE_COND_NOT_SATSIFIED = WARNING_PREAMBLE + \
     "Returning stepsize with sufficient decrease " + \
     "but curvature condition not satisfied."
 FLAG_NO_STEPSIZE_FOUND = WARNING_PREAMBLE + \
-    "Linesearch failed, no stepsize satisfying sufficient decrease found. " + \
-    "Try increasing maximal number of linesearch iterations."
+    "Linesearch failed, no stepsize satisfying sufficient decrease found."
+FLAG_NOT_A_DESCENT_DIRECTION = WARNING_PREAMBLE + \
+    "The linesearch failed because the provided direction " + \
+    "is not a descent direction. "
 
 _dot = functools.partial(jnp.dot, precision=lax.Precision.HIGHEST)
 
@@ -127,8 +127,8 @@ def _set_val(x, y):
   return jax.tree_util.tree_map(_set_val, candidate, default)
 
-def _cond_print(condition, message):
-  jax.lax.cond(condition, lambda _: jax.debug.print(message), lambda _: None, None)
+def _cond_print(condition, message, **kwargs):
+  jax.lax.cond(condition, lambda _: jax.debug.print(message, **kwargs), lambda _: None, None)
 
 class ZoomLineSearchState(NamedTuple):
@@ -147,6 +147,8 @@ class ZoomLineSearchState(NamedTuple):
   num_fun_eval: int
   num_grad_eval: int
 
+  decrease_error: float
+  curvature_error: float
   error: float
   done: bool
   failed: bool  # comply to semantic used by other line searches
@@ -212,9 +214,13 @@ class ZoomLineSearch(base.IterativeLineSearch):
      rel_tol_cubic*interval_size (default: 0.2)
    rel_tol_quad: point computed by quadratic interpolation accepted if inside
      rel_tol_quad*interval_size (default: 0.1)
+    interval_threshold: if the size of the searched interval is below this
+      threshold and a stepsize with sufficient decrease has been found,
+      then the linesearch takes that stepsize and moves on (default: 5e-5,
+      which corresponds to 15 linesearch iterations for init_stepsize=1).
    increase_factor: factor to mutliply stepsize at initialization until
      finding interval satisfying curvature condition (default: 2.)
-    max_stepsize: maximal possible stepsize. (default: 2**30)
+    max_stepsize: maximal possible stepsize (default: 2**maxiter).
    tol: tolerance of the stopping criterion. (default: 0.0)
    maxiter: maximum number of line search iterations. (default: 30)
    verbose: whether to print error on every iteration or not. verbose=True will
@@ -232,13 +238,12 @@ class ZoomLineSearch(base.IterativeLineSearch):
   c3: float = 1e-6
   rel_tol_cubic: float = 0.2
   rel_tol_quad: float = 0.1
+  interval_threshold: float = 5e-5
   increase_factor: float = 2.0
 
   tol: float = 0.0
   maxiter: int = 30
-  # max_stepsize needs to be large enough for the linesearch to be able
-  # to find a good stepsize
-  max_stepsize: float = 2**30
+  max_stepsize: Optional[float] = None
 
   verbose: bool = False
   jit: base.AutoOrBoolean = "auto"
@@ -263,7 +268,7 @@ def _decrease_error(
         value_step - value_init - self.c1 * stepsize * slope_init
     )
     # or an approximate decrease condition, see equation (23) of [2]
-    approx_decrease_error_ = slope_step - (2 * self.c1 - 1.0) * slope_init
+    approx_decrease_error = slope_step - (2 * self.c1 - 1.0) * slope_init
 
     # The classical Armijo condition may fail to be satisfied if we are too
     # close to a minimum, causing the optimizer to fail as explained in [2]
@@ -271,7 +276,7 @@ def _decrease_error(
     # We switch to approximate Wolfe conditions only if we are close enough to
     # the minimizer which is captured by the following criterion.
     delta_values = value_step - value_init - self.c3 * jnp.abs(value_init)
-    approx_decrease_error = jnp.maximum(approx_decrease_error_, delta_values)
+    approx_decrease_error = jnp.maximum(approx_decrease_error, delta_values)
     # We take then the *minimum* of both errors.
     return jnp.minimum(approx_decrease_error, exact_decrease_error)
 
   def _curvature_error(self, slope_step, slope_init):
     # See equation (3.7b) of [1].
return jnp.abs(slope_step) - self.c2 * jnp.abs(slope_init) - def _make_safe_step(self, _, state, args, kwargs): + def _make_safe_step(self, stepsize, state, args, kwargs): safe_stepsize = state.safe_stepsize - _cond_print(safe_stepsize == 0.0, FLAG_NO_STEPSIZE_FOUND) if self.verbose: _cond_print((safe_stepsize > 0.), FLAG_CURVATURE_COND_NOT_SATSIFIED) + final_stepsize = jax.lax.cond( + safe_stepsize > 0., + lambda safe_stepsize, *_: safe_stepsize, + self.failure_diagnostic, + safe_stepsize, stepsize, state + ) step = tree_add_scalar_mul( - state.params, safe_stepsize, state.descent_direction + state.params, final_stepsize, state.descent_direction ) (value_step, aux_step), grad_step = self._value_and_grad_fun_with_aux( step, *args, **kwargs @@ -296,7 +306,7 @@ def _make_safe_step(self, _, state, args, kwargs): grad=grad_step, aux=aux_step, ) - return safe_stepsize, new_state + return final_stepsize, new_state def _keep_step(self, stepsize, state, _, __): return stepsize, state @@ -337,16 +347,16 @@ def _search_interval(self, init_stepsize, state, args, kwargs): if self.verbose: _cond_print(is_value_nan, FLAG_NAN_INF_VALUES) - decrease_error_ = self._decrease_error( + decrease_error = self._decrease_error( new_stepsize, new_value_step, new_slope_step, value_init, slope_init ) - decrease_error = jnp.maximum(decrease_error_, 0.0) + decrease_error = jnp.maximum(decrease_error, 0.0) decrease_error = jnp.where( jnp.isnan(decrease_error), jnp.inf, decrease_error ) - curvature_error_ = self._curvature_error(new_slope_step, slope_init) - curvature_error = jnp.maximum(curvature_error_, 0.0) + curvature_error = self._curvature_error(new_slope_step, slope_init) + curvature_error = jnp.maximum(curvature_error, 0.0) curvature_error = jnp.where( jnp.isnan(curvature_error), jnp.inf, curvature_error ) @@ -403,21 +413,21 @@ def _search_interval(self, init_stepsize, state, args, kwargs): # the linesearch is done for either of the two reasons above, we set # directly the new parameters, gradient, value and aux to the ones found. 
     done = (new_error <= self.tol) | (max_stepsize_reached & ~interval_found)
-    default = [0.0, params_init, value_init, grad_init, aux_init]
+    default = [params_init, value_init, grad_init, aux_init]
     candidate = [
-        new_stepsize,
         new_step,
         new_value_step,
         new_grad_step,
         new_aux_step,
     ]
-    best_stepsize, next_params, next_value, next_grad, next_aux = _set_values(
+    next_params, next_value, next_grad, next_aux = _set_values(
         done, candidate, default
     )
     if self.verbose:
       _cond_print(
           (max_stepsize_reached & ~interval_found),
-          FLAG_INTERVAL_NOT_FOUND + '\n' + FLAG_CURVATURE_COND_NOT_SATSIFIED
+          FLAG_INTERVAL_NOT_FOUND + '\n'
+          + FLAG_CURVATURE_COND_NOT_SATSIFIED
       )
     max_iter_reached = (iter_num + 1 >= self.maxiter) & (~done)
@@ -428,6 +438,8 @@ def _search_interval(self, init_stepsize, state, args, kwargs):
         grad=next_grad,
         aux=next_aux,
         #
+        decrease_error=decrease_error,
+        curvature_error=curvature_error,
         error=new_error,
         done=done,
         failed=jnp.asarray(max_iter_reached),
@@ -448,7 +460,7 @@ def _search_interval(self, init_stepsize, state, args, kwargs):
         #
         safe_stepsize=new_safe_stepsize,
     )
-    return base.LineSearchStep(stepsize=best_stepsize, state=new_state)
+    return base.LineSearchStep(stepsize=new_stepsize, state=new_state)
 
   def _zoom_into_interval(self, stepsize, state, args, kwargs):
     """Zoom procedure described in Algorithm 3.6 of [1]."""
@@ -485,7 +497,11 @@ def _zoom_into_interval(self, stepsize, state, args, kwargs):
     right = jnp.maximum(high, low)
     cubic_chk = self.rel_tol_cubic * delta
     quad_chk = self.rel_tol_quad * delta
-    threshold = jnp.where((jnp.finfo(delta).bits < 64), 1e-5, 1e-10)
+
+    # The threshold is rather large compared to machine precision so that we
+    # avoid wasting iterations on the curvature condition (a stepsize that
+    # decreases the value is taken, if it exists, once the threshold is met).
+    threshold = self.interval_threshold
     too_small_int = delta <= threshold
     if self.verbose:
       _cond_print(too_small_int, FLAG_INTERVAL_TOO_SMALL)
@@ -520,16 +536,16 @@ def _zoom_into_interval(self, stepsize, state, args, kwargs):
     if self.verbose:
       _cond_print(is_value_nan, FLAG_NAN_INF_VALUES)
 
-    decrease_error_ = self._decrease_error(
+    decrease_error = self._decrease_error(
         middle, value_middle, slope_middle, value_init, slope_init
     )
-    decrease_error = jnp.maximum(decrease_error_, 0.0)
+    decrease_error = jnp.maximum(decrease_error, 0.0)
     decrease_error = jnp.where(
         jnp.isnan(decrease_error), jnp.inf, decrease_error
     )
 
-    curvature_error_ = self._curvature_error(slope_middle, slope_init)
-    curvature_error = jnp.maximum(curvature_error_, 0.0)
+    curvature_error = self._curvature_error(slope_middle, slope_init)
+    curvature_error = jnp.maximum(curvature_error, 0.0)
     curvature_error = jnp.where(
         jnp.isnan(curvature_error), jnp.inf, curvature_error
     )
@@ -539,14 +555,14 @@ def _zoom_into_interval(self, stepsize, state, args, kwargs):
     # If the new point satisfies at least the decrease error we keep it in case
     # the curvature error cannot be satisfied. We take the largest possible one
     safe_decrease = decrease_error <= self.tol
-    new_safe_stepsize_ = jnp.where(safe_decrease, middle, safe_stepsize)
-    new_safe_stepsize = jnp.maximum(new_safe_stepsize_, safe_stepsize)
+    new_safe_stepsize = jnp.where(safe_decrease, middle, safe_stepsize)
+    new_safe_stepsize = jnp.maximum(new_safe_stepsize, safe_stepsize)
 
     # If both armijo and curvature conditions are satisfied, we are done.
     done = new_error <= self.tol
-    default = [0.0, params_init, value_init, grad_init, aux_init]
-    candidate = [middle, step, value_middle, grad_step, aux_step]
-    best_stepsize, next_params, next_value, next_grad, next_aux = _set_values(
+    default = [params_init, value_init, grad_init, aux_init]
+    candidate = [step, value_middle, grad_step, aux_step]
+    next_params, next_value, next_grad, next_aux = _set_values(
         new_error <= self.tol, candidate, default
     )
 
@@ -584,9 +600,15 @@ def _zoom_into_interval(self, stepsize, state, args, kwargs):
         [high, value_high],
         [low, value_low],
     )
-
-    max_iter_reached = (iter_num + 1 >= self.maxiter) & (~done)
-
+    # We stop if the searched interval has been reduced below the interval
+    # threshold and we have already found a positive stepsize ensuring
+    # sufficient decrease. If no stepsize with sufficient decrease has been
+    # found, we keep going (some extremely steep functions require very small
+    # stepsizes, see the zakharov test in lbfgs_test.py).
+    max_iter_reached = (iter_num + 1 >= self.maxiter)
+    presumably_failed = jnp.asarray(max_iter_reached) | \
+        (too_small_int & (new_safe_stepsize > 0.))
+    failed = presumably_failed & ~done
     new_state = state._replace(
         iter_num=iter_num + 1,
         params=next_params,
         value=next_value,
         grad=next_grad,
         aux=next_aux,
         #
+        decrease_error=decrease_error,
+        curvature_error=curvature_error,
         error=new_error,
         done=done,
-        failed=jnp.asarray(max_iter_reached),
+        failed=failed,
         #
         low=new_low,
         value_low=new_value_low,
@@ -609,7 +633,7 @@ def _zoom_into_interval(self, stepsize, state, args, kwargs):
         #
         safe_stepsize=new_safe_stepsize,
     )
-    return base.LineSearchStep(stepsize=best_stepsize, state=new_state)
+    return base.LineSearchStep(stepsize=middle, state=new_state)
 
   def init_state(
       self,
@@ -665,8 +689,6 @@ def init_state(
 
     slope = tree_vdot_real(tree_conj(grad), descent_direction)
 
-    _cond_print(slope > 0, FLAG_NOT_A_DESCENT_DIRECTION)
-
     return ZoomLineSearchState(
         iter_num=jnp.asarray(0),
         params=params,
@@ -678,6 +700,8 @@ def init_state(
         slope_init=slope,
         descent_direction=descent_direction,
         #
+        decrease_error=jnp.asarray(jnp.inf),
+        curvature_error=jnp.asarray(jnp.inf),
         error=jnp.asarray(jnp.inf),
         done=jnp.asarray(False),
         failed=jnp.asarray(False),
@@ -701,15 +725,6 @@ def init_state(
         num_grad_eval=num_grad_eval,
     )
 
-  def _cond_fun(self, inputs):
-    # Stop the linesearch according to done rather than the error as one may
-    # reach the maximal stepsize and no decrease of the curvature error may be
-    # possible.
-    _, state = inputs[0]
-    if self.verbose:
-      jax.debug.print("Solver: ZoomLineSearch, Error: {error}", error=state.error)
-    return ~state.done
-
   def update(
       self,
       stepsize: float,
@@ -769,7 +784,7 @@ def update(
         (new_state_.failed) & (new_state_.iter_num == self.maxiter)
     ).astype(base.NUM_EVAL_DTYPE)
     best_stepsize, new_state = cond(
-        (new_state_.failed) & (new_state_.iter_num == self.maxiter),
+        new_state_.failed,
         self._make_safe_step,
         self._keep_step,
         best_stepsize_,
@@ -785,9 +800,89 @@ def update(
 
     return base.LineSearchStep(stepsize=best_stepsize, state=new_state)
 
+  def _cond_fun(self, inputs):
+    # Stop the linesearch when done or failed rather than on the error alone:
+    # the maximal stepsize may be reached with no further decrease of the
+    # curvature error possible, or the interval may have shrunk too much.
+    stepsize, state = inputs[0]
+    if self.verbose:
+      self._log_info(stepsize, state)
+    return ~(state.done | state.failed)
+
+  def _log_info(self, stepsize, state):
+    jax.debug.print(
+        "INFO: jaxopt.ZoomLineSearch: " + \
+        "Iter: {iter}, " + \
+        "Stepsize: {stepsize}, " + \
+        "Decrease error: {decrease_error}, " + \
+        "Curvature error: {curvature_error}",
+        iter=state.iter_num,
+        stepsize=stepsize,
+        decrease_error=state.decrease_error,
+        curvature_error=state.curvature_error
+    )
+
+  def failure_diagnostic(self, safe_stepsize, stepsize, state):
+    jax.debug.print(FLAG_NO_STEPSIZE_FOUND)
+    self._log_info(stepsize, state)
+
+    slope_init = state.slope_init
+    is_descent_dir = slope_init < 0.
+    _cond_print(
+        ~is_descent_dir,
+        FLAG_NOT_A_DESCENT_DIRECTION + \
+        "The slope (={slope_init}) at stepsize=0 should be negative.",
+        slope_init=slope_init
+    )
+    _cond_print(
+        is_descent_dir,
+        WARNING_PREAMBLE + \
+        "Consider increasing the maximal number of linesearch iterations."
+    )
+    eps = jnp.finfo(stepsize).eps
+    below_eps = stepsize < eps
+    _cond_print(
+        below_eps & is_descent_dir,
+        WARNING_PREAMBLE + \
+        "Computed stepsize (={stepsize}) " + \
+        "is below machine precision (={eps}). " + \
+        "Consider switching to higher precision, e.g. x64, using " + \
+        "jax.config.update('jax_enable_x64', True).",
+        stepsize=stepsize,
+        eps=eps
+    )
+    abs_slope_init = jnp.abs(slope_init)
+    high_slope = abs_slope_init > 1e16
+    _cond_print(
+        high_slope & is_descent_dir,
+        WARNING_PREAMBLE + \
+        "Very large absolute slope at stepsize=0 (|slope|={abs_slope_init}). " + \
+        "The objective is badly conditioned. " + \
+        "Consider reparameterizing the objective (e.g., normalizing parameters) " + \
+        "or finding a better initial guess for the solver.",
+        abs_slope_init=abs_slope_init
+    )
+    outside_domain = jnp.isinf(state.decrease_error)
+    _cond_print(
+        outside_domain,
+        WARNING_PREAMBLE + \
+        "Cannot even make a step without getting Inf or NaN. " + \
+        "The linesearch won't make a step and the optimizer is stuck."
+    )
+    _cond_print(
+        ~outside_domain,
+        WARNING_PREAMBLE + \
+        "Making an unsafe step that does not decrease the objective enough. " + \
+        "Convergence of the solver is compromised as it does not reduce values."
+    )
+    final_stepsize = jnp.where(outside_domain, safe_stepsize, stepsize)
+    return final_stepsize
+
   def __post_init__(self):
     self._fun_with_aux, _, self._value_and_grad_fun_with_aux = (
         _make_funs_with_aux(
             self.fun, value_and_grad=self.value_and_grad, has_aux=self.has_aux
         )
     )
+    if self.max_stepsize is None:
+      self.max_stepsize = float(2**self.maxiter)
diff --git a/tests/lbfgs_test.py b/tests/lbfgs_test.py
index 54d94a59..fba7e5b5 100644
--- a/tests/lbfgs_test.py
+++ b/tests/lbfgs_test.py
@@ -517,11 +517,7 @@ def test_against_scipy(self, fun_init_and_opt):
     tol = 1e-15 if jax.config.jax_enable_x64 else 1e-6
     fun_name, x0, opt = fun_init_and_opt
     jnp_fun, onp_fun = get_fun(fun_name, jnp), get_fun(fun_name, onp)
-    jaxopt_options = {}
-    if fun_name == 'zakharov':
-      # zakharov function requires more linesearch iterations
-      jaxopt_options.update(dict(maxls = 50))
-    jaxopt_res = LBFGS(jnp_fun, tol=tol, **jaxopt_options).run(x0).params
+    jaxopt_res = LBFGS(jnp_fun, tol=tol).run(x0).params
     scipy_res = scipy_opt.minimize(onp_fun, x0, method='BFGS').x
     # scipy not good for matyas and zakharov functions,
     # compare to true minimum, zero
diff --git a/tests/lbfgsb_test.py b/tests/lbfgsb_test.py
index 65fa7f42..4fd37bb8 100644
--- a/tests/lbfgsb_test.py
+++ b/tests/lbfgsb_test.py
@@ -29,6 +29,8 @@
 
 from sklearn import datasets
 
+# Uncomment this line to test in x64
+# jax.config.update('jax_enable_x64', True)
 
 class LbfgsbTest(test_util.JaxoptTestCase):
 
diff --git a/tests/nonlinear_cg_test.py b/tests/nonlinear_cg_test.py
index 4d5b692c..d278fd96 100644
--- a/tests/nonlinear_cg_test.py
+++ b/tests/nonlinear_cg_test.py
@@ -26,6 +26,8 @@
 from jaxopt._src import test_util
 from sklearn import datasets
 
+# Uncomment this line to test in x64
+# jax.config.update('jax_enable_x64', True)
 
 def get_random_pytree():
   key = jax.random.PRNGKey(1213)
diff --git a/tests/zoom_linesearch_test.py b/tests/zoom_linesearch_test.py
index 7b39d545..7702e7b2 100644
--- a/tests/zoom_linesearch_test.py
+++ b/tests/zoom_linesearch_test.py
@@ -36,7 +36,8 @@
 
 # pylint: disable=invalid-name
 
-
+# Uncomment the line below in order to run in float64.
+# jax.config.update("jax_enable_x64", True)
 
 class ZoomLinesearchTest(test_util.JaxoptTestCase):
   """Tests for ZoomLineSearch."""
@@ -216,7 +217,7 @@ def fun_(w):
         w_init, descent_dir, stepsize, fun_, jax.grad(fun_), state
     )
 
-  def test_failure_cases(self):
+  def test_failure_descent_direction(self):
     # See gh-7475
 
     # For this f and p, starting at a point on axis 0, the strong Wolfe
@@ -225,44 +226,56 @@ def fun(x):
       return jnp.dot(x, x)
 
-    def fun_der(x):
-      return 2.0 * x
-
-    c2 = 0.5
     p = jnp.array([1.0, 0.0])
-
-    # 1. Test that the line search fails for p not a descent direction
     x = 60 * p
-    ls = ZoomLineSearch(fun, c2=c2, maxiter=10)
+
+    # Test that the line search fails for p not a descent direction.
+    # For large maxiter, the approximate Wolfe condition may still find a
+    # stepsize with sufficient decrease, so we reduce maxiter.
+    ls = ZoomLineSearch(fun, c2=0.5, maxiter=18)
     stdout = io.StringIO()
     with redirect_stdout(stdout):
       s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p)
-    self._check_step_in_state(x, p, s, fun, fun_der, state)
+    self._check_step_in_state(x, p, s, fun, jax.grad(fun), state)
 
     # Check that we were not able to make a step or an infinitesimal one
-    self.assertTrue(s == 0.)
+    self.assertLess(s, 1e-5)
     self.assertTrue(FLAG_NOT_A_DESCENT_DIRECTION in stdout.getvalue())
     self.assertTrue(FLAG_NO_STEPSIZE_FOUND in stdout.getvalue())
 
-    # 2.
Test that the line search fails if the maximum stepsize is too small - # Here, smallest s satisfying strong Wolfe conditions for c2=0.5 is 30 + def test_failure_too_small_max_stepsize(self): + def fun(x): + return jnp.dot(x, x) + + p = jnp.array([1.0, 0.0]) x = -60 * p - ls = ZoomLineSearch(fun, c2=c2, max_stepsize=10, verbose=True) + + # Test that the line search fails if the maximum stepsize is too small + # Here, smallest s satisfying strong Wolfe conditions for c2=0.5 is 30 + ls = ZoomLineSearch(fun, c2=0.5, max_stepsize=10, verbose=True) stdout = io.StringIO() with redirect_stdout(stdout): s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) - self._check_step_in_state(x, p, s, fun, fun_der, state) + self._check_step_in_state(x, p, s, fun, jax.grad(fun), state) # Check that we still made a step self.assertTrue(s == 10.0) self.assertTrue(FLAG_INTERVAL_NOT_FOUND in stdout.getvalue()) self.assertTrue(FLAG_CURVATURE_COND_NOT_SATSIFIED in stdout.getvalue()) - # 3. s=30 will only be tried on the 6th iteration, so this fails because + def test_failure_not_enough_iter(self): + def fun(x): + return jnp.dot(x, x) + + p = jnp.array([1.0, 0.0]) + x = -60 * p + + c2 = 0.5 + # s=30 will only be tried on the 6th iteration, so this fails because # the maximum number of iterations is reached. ls = ZoomLineSearch(fun, c2=c2, maxiter=5, verbose=True) stdout = io.StringIO() with redirect_stdout(stdout): s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) - self._check_step_in_state(x, p, s, fun, fun_der, state) + self._check_step_in_state(x, p, s, fun, jax.grad(fun), state) # Check that we still made a step self.assertTrue(s == 16.0) self.assertTrue(state.failed) @@ -272,10 +285,11 @@ def fun_der(x): # Check if it works normally ls = ZoomLineSearch(fun, c2=c2) s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) - self._assert_line_conds(x, p, s, fun, fun_der, c1=ls.c1, c2=c2) - self._check_step_in_state(x, p, s, fun, fun_der, state) + self._assert_line_conds(x, p, s, fun, jax.grad(fun), c1=ls.c1, c2=c2) + self._check_step_in_state(x, p, s, fun, jax.grad(fun), state) self.assertTrue(s >= 30.0) + def test_failure_flat_fun(self): # Check failure for a very flat function def fun_flat(x): return jnp.exp(-1 / x**2) @@ -289,13 +303,14 @@ def fun_flat(x): ls.run(init_stepsize=1.0, params=x) self.assertTrue(FLAG_INTERVAL_TOO_SMALL in stdout.getvalue()) + def test_failure_inf_value(self): # Check behavior for inf/nan values def fun_inf(x): return jnp.log(x) x = 1.0 p = -2.0 - ls = ZoomLineSearch(fun_inf, verbose=True, jit=False) + ls = ZoomLineSearch(fun_inf, verbose=True) stdout = io.StringIO() with redirect_stdout(stdout): ls.run(init_stepsize=1.0, params=x, descent_direction=p) @@ -423,6 +438,4 @@ def run_ls(): if __name__ == "__main__": - # Uncomment the line below in order to run in float64. - # jax.config.update("jax_enable_x64", True) absltest.main()
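
Usage note (not part of the patch): a minimal sketch of the reworked
max_stepsize semantics, using only the public API exercised by the tests
above. With this patch, ZoomLineSearch defaults max_stepsize to 2**maxiter,
while solvers pass max_stepsize=None to the linesearch, so the solver-level
max_stepsize/min_stepsize only bound the initial stepsize guess
(linesearch_init='increase'). The objective and arrays below are
illustrative, borrowed from the quadratic used in zoom_linesearch_test.py.

    import jax.numpy as jnp
    from jaxopt import LBFGS, ZoomLineSearch

    def fun(x):
      # Simple quadratic; its gradient at x is 2 * x.
      return jnp.dot(x, x)

    x = jnp.array([-60.0, 0.0])
    p = jnp.array([1.0, 0.0])  # descent direction: <grad(x), p> = -120 < 0

    # Standalone linesearch: with the new default max_stepsize = 2**maxiter,
    # the bracketing phase can grow the stepsize until the strong Wolfe
    # conditions hold (s >= 30 for c2=0.5 on this quadratic, per the tests).
    ls = ZoomLineSearch(fun, c2=0.5)
    stepsize, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p)

    # Inside a solver, max_stepsize now only caps the initial guess of each
    # linesearch run; the zoom linesearch may exceed it to satisfy the
    # curvature condition.
    solver = LBFGS(fun, linesearch="zoom", linesearch_init="increase",
                   max_stepsize=1.0)
    params, solver_state = solver.run(jnp.array([1.0, 2.0]))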