ZoomLineSearch under vmap ends up calling failure_diagnostic() even when safe_stepsize > 0, as shown here. This can produce a lot of printouts, and I didn't see a way to disable the failure diagnostic printouts in the current implementation. I think the relevant commit is 614dc7b. Below, you will find minimal reproducible examples.
WARNING: jaxopt.ZoomLineSearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
INFO: jaxopt.ZoomLineSearch: Iter: 1, Stepsize: 1.0, Decrease error: -0.0, Curvature error: 0.0
WARNING: jaxopt.ZoomLineSearch: The linesearch failed because the provided direction is not a descent direction. The slope (=-0.0) at stepsize=0 should be negative
WARNING: jaxopt.ZoomLineSearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: jaxopt.ZoomLineSearch: Computed stepsize (=1.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.0). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Cannot even make a step without getting Inf or Nan. The linesearch won't make a step and the optimizer is stuck.
WARNING: jaxopt.ZoomLineSearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
By contrast, the following code does not produce any warnings.
Hello @tare,
Thanks for pointing this out. Under vmap, when the predicate is batched, a cond is converted into a select and both branches are evaluated (this does not happen without vmap); see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html.
I'm not sure how we could then have failure diagnostics under vmap.
For now, I have patched zoom in #544 so that it does not display failure diagnostics unless verbose is set to True. That will avoid unnecessary prints.
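The idea of gating diagnostics behind verbose can be pictured with the sketch below. It is not the jaxopt code or the actual change in #544: failure_diagnostic, finalize, _report, and the calls list are hypothetical stand-ins, and the diagnostic records values through jax.debug.callback instead of printing so that the effect can be checked. Because verbose is a static Python bool, the diagnostic cond is never traced when it is False, so vmap has nothing to turn into a select.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

calls = []  # hypothetical sink standing in for printed warnings


def failure_diagnostic(stepsize):
    # Hypothetical diagnostic: records the stepsize on the host instead of printing.
    jax.debug.callback(lambda s: calls.append(float(s)), stepsize)


def _report(stepsize):
    # Hypothetical "failure" branch: emit the diagnostic, pass the value through.
    failure_diagnostic(stepsize)
    return stepsize


def finalize(stepsize, failed, verbose=False):
    # `verbose` is a plain Python bool, so this `if` is resolved at trace time.
    # With verbose=False the cond below is never traced, so no diagnostic can run.
    if verbose:
        stepsize = lax.cond(failed, _report, lambda s: s, stepsize)
    return stepsize


stepsizes = jnp.array([1.0, 0.5])
failed = jnp.array([False, False])  # no element actually failed

# Quiet by default: nothing is recorded, even under vmap.
jax.vmap(functools.partial(finalize, verbose=False))(stepsizes, failed)
jax.effects_barrier()
quiet_calls = len(calls)

# With verbose=True, vmap converts the batched cond into a select, so the
# diagnostic branch runs even though no element failed.
jax.vmap(functools.partial(finalize, verbose=True))(stepsizes, failed)
jax.effects_barrier()
verbose_calls = len(calls) - quiet_calls
print(quiet_calls, verbose_calls)
```

The trade-off in this sketch is that verbose=True under vmap still reports spurious failures; the flag only makes the default quiet.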
Environment
% pip list | grep jax
jax 0.4.20
jaxlib 0.4.20
jaxopt 0.8.2
% python --version
Python 3.10.11
Description
Here is a minimal reproducible example illustrating the issue with jax.debug.print, cond, and vmap; the following code gives the following output.
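A sketch of the same pattern, assuming only public JAX APIs (lax.cond, jax.vmap, and jax.debug.callback, used here in place of jax.debug.print so that branch executions can be recorded and checked). The function f, the branch names, and the executed list are illustrative. Without vmap only the taken branch runs; with a batched predicate, vmap converts the cond into a select and both branches run, which is how a diagnostic placed in a "failure" branch can fire even when no element failed.

```python
import jax
import jax.numpy as jnp
from jax import lax

executed = []  # branch tags recorded on the host each time a callback runs


def _recorder(tag):
    # Build a host callback that logs `tag`; the traced value itself is ignored.
    def _callback(_value):
        executed.append(tag)
    return _callback


def f(x):
    def true_branch(x):
        jax.debug.callback(_recorder("true"), x)
        return x + 1.0

    def false_branch(x):
        jax.debug.callback(_recorder("false"), x)
        return x - 1.0

    return lax.cond(x > 0, true_branch, false_branch, x)


# Unbatched call: only the branch selected by the predicate executes.
executed.clear()
f(jnp.float32(1.0))
jax.effects_barrier()
unbatched_branches = sorted(set(executed))

# Batched predicate: vmap lowers the cond to a select, so both branches execute.
executed.clear()
out = jax.vmap(f)(jnp.array([1.0, -1.0]))
jax.effects_barrier()
batched_branches = sorted(set(executed))

print(unbatched_branches)  # ['true']
print(batched_branches)    # ['false', 'true']
print(out)                 # [ 2. -2.]
```

Note that the selected outputs are still correct under vmap; only the side effects of the untaken branch leak through.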