Inconsistency between jax.scipy.minimize and jaxopt.LBFGS
#322
-
Hi! I tried a minimal example where the results are a bit different from my use case but which still points to an inconsistency:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jaxopt import LBFGS
from jaxopt._src import test_util


def fun(x, *args, **kwargs):
    # Simplified Rosenbrock function (coefficient 15 instead of the usual 100).
    return 15.0 * (x[1] - x[0] ** 2.0) ** 2.0 + (1 - x[0]) ** 2.0


x0 = jnp.zeros(2)

lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, maxls=20)
x_jaxopt, lbfgs_state_jaxopt = lbfgs.run(x0)

soln_jaxminimize = minimize(
    fun,
    x0,
    method="l-bfgs-experimental-do-not-rely-on-this",
    options={"gtol": 1e-3, "maxiter": 500, "maxls": 20},
)
x_jax_minimize = soln_jaxminimize.x

print(soln_jaxminimize.success.item(), soln_jaxminimize.status.item(), soln_jaxminimize.nit.item())
test_util.JaxoptTestCase().assertArraysAllClose(x_jaxopt, x_jax_minimize, atol=1e-3)
```

which should output
What this shows is that, in the case of a simplified Rosenbrock function (i.e. a coefficient of 15 instead of 100), the implementation of

I am about to start a debugging quest to find out what the differences in the line search implementations are, but potentially this is already a known topic. It's a tricky one because again, in my actual use case the
-
I also see that there was a similar issue pointed out here, but again in this example
-
We use the zoom line search by default, which comes from
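(For reference, selecting the line search on the jaxopt side might look like the sketch below. The `linesearch` constructor argument and its accepted values, `"zoom"` as the default and `"backtracking"` as an alternative, are assumptions here, so check the docs of the installed version.)

```python
# Sketch: choosing the line search used by jaxopt's LBFGS.
# Assumes a `linesearch` constructor argument with "zoom" (the default)
# and "backtracking" as accepted values.
lbfgs_zoom = LBFGS(fun=fun, tol=1e-3, maxiter=500, maxls=20, linesearch="zoom")
lbfgs_backtracking = LBFGS(fun=fun, tol=1e-3, maxiter=500, maxls=20,
                           linesearch="backtracking")

print(lbfgs_zoom.run(x0).params)
print(lbfgs_backtracking.run(x0).params)
```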
-
I found out! Took quite some time of debugging, but here you go: the discrepancy comes from the fact that in `jax.scipy.minimize`, when the line search fails, i.e. `ls_results.failed == True`, the `status` is then `5`. In turn, this makes the `failed` attribute of the L-BFGS state `True`. When this attribute is `True`, the iterations of course stop.

This is not the case in `jaxopt`. Indeed, no matter whether the line search fails, the step size is always accepted. This is because the LBFGS state in `jaxopt` is less rich than in `jax.scipy.minimize`, and it does not allow some more custom logic in the `IterativeSolver`'s `_cond_fun`. Maybe one way to get back that consistency would be to enrich the `LBFGSState` tuple w…

However, I am not sure it's necessarily something we want, given that the algorithm works even though the line search fails. Also, this explains the discrepancy in this minimal example but not in my actual use case; I will need to do some more logging to see exactly what's going on.
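For reference, here is a minimal sketch (reusing `fun` and `x0` from the example above) of how to observe this on the `jax.scipy` side; the meaning of status code `5` is the one described in this comment:

```python
# When the internal line search fails, jax.scipy's L-BFGS reports status 5
# and stops iterating (per the behaviour described above).
soln = minimize(
    fun,
    x0,
    method="l-bfgs-experimental-do-not-rely-on-this",
    options={"gtol": 1e-3, "maxiter": 500, "maxls": 20},
)
if not soln.success.item() and soln.status.item() == 5:
    print(f"line search failed; stopped after {soln.nit.item()} iterations")
```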
-
And indeed, if I comment out the failure line in the
-
And actually it also explains exactly the discrepancy I noticed in my real use case!! I commented out the same line, and there I got a lot more iterations with

I will actually try to implement the fix I was mentioning; maybe it will be worth considering for
-
Thanks a lot for the investigation! So do you think it's better to stop the algorithm altogether when the line search fails? We need either a way to override the stopping criterion or a mechanism to tell the optimization loop that it has to stop.
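One possible shape for such a mechanism, as a rough sketch rather than a proposal for jaxopt's internals, is to drive the solver manually and let the caller decide when to stop. This assumes the public `init_state`/`update` API of the solvers and that `state.error` carries the value used by the stopping criterion:

```python
# Sketch: step the solver manually so a user-defined rule can stop the loop.
lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, maxls=20)
params, state = x0, lbfgs.init_state(x0)
for _ in range(lbfgs.maxiter):
    params, state = lbfgs.update(params, state)
    if state.error <= lbfgs.tol:
        break
    # A custom rule could also break here, e.g. on a (hypothetical)
    # line-search-failure flag carried by the state.
```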
-
I honestly have no idea atm... That's why I was suggesting just allowing the user to decide for themselves with an attribute

The reason I am confused is that in the Rosenbrock case it appears to be better to ignore the failure, while in my use case it appears to be better not to ignore it... It definitely depends on the situation; the question is "what is the discriminating factor?". Anyway, in the meantime I am submitting a PR with the fix implemented, although it's branched off my previous PR.
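For illustration, usage of such an opt-in attribute might look like the sketch below; the name `stop_if_linesearch_fails` is hypothetical here, not an existing jaxopt option at the time of this discussion:

```python
# Hypothetical opt-in attribute (name assumed for illustration only):
# let the user choose whether a failed line search should stop the solver.
lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, maxls=20,
              stop_if_linesearch_fails=True)
x_opt, state = lbfgs.run(x0)
```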
-
Let's merge your previous PR first, for simplicity. There is just a small fix needed in the way you initialize
-
I have obtained similar discrepancies between
-
@zaccharieramzi I'm having the same issue. Did you ever submit a PR where you implemented