Adding the possibility to stop LBFGS after a line search failure #323
Thanks a lot for doing this. I left some comments and suggestions.
jaxopt/_src/lbfgs.py
Outdated
aux=aux)
gamma=jnp.asarray(1.0, dtype=dtype),
aux=aux,
failed=jnp.asarray(False))
I'm not sure we need asarray for booleans. CC @froystig
I can try without and tell you
it works without the asarray, so removed it
If we want them as explicit scalar arrays, rather than Python scalars, then np.array(False), where np is the original NumPy. If we moreover want them on device for some reason, then jnp.asarray(False).
Here the main thing we want to be careful of is that init_state and update output states with the same types. Otherwise, it triggers recompilations.
I believe we're fine with just a plain literal False:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.jit
... def f(x):
...   print('compiling')
...   return x
...
>>> f(False)
compiling
DeviceArray(False, dtype=bool)
>>> f(jnp.asarray(False))
DeviceArray(False, dtype=bool)
this would mean rewriting the unit tests that check the return types, so maybe it's for a different PR?
Yes, I agree. Let's squash the commits and merge this PR.
@mblondel went over your comments.
Thanks. LGTM apart from 2 minor comments.
@mblondel all corrections incorporated and tests are passing at least locally
LGTM apart from squashing commits. Thanks again for your contribution!
jaxopt/_src/lbfgs.py
Outdated
@@ -211,6 +213,7 @@ class LBFGS(base.IterativeSolver):

stepsize: Union[float, Callable] = 0.0
linesearch: str = "zoom"
allow_failed_linesearch: bool = False
So the default behavior now is the same as that of LBFGS in JAX core, right? (just to confirm)
No, I left the default behaviour unchanged: if I changed it, the Rosenbrock test would not pass, and indeed this is what happens in JAX core (#322 (comment)).
A simple solution to make this test pass would be to have use_gamma=False, i.e. not update the initial H_0 at each iteration and keep it as the identity (as is done in scipy).
If I were to make a proposal I would say:
- let's have the default behavior same as JAX core (and therefore a change of behavior for jaxopt), i.e. allow_failed_linesearch: bool = True
- correct the Rosenbrock test by using use_gamma=False
- have a warning somewhere in the doc that use_gamma can result in failures for very ill-conditioned functions
Waiting for your input on this before taking any action/squashing commits
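For readers unfamiliar with the option discussed above: in the L-BFGS literature the initial inverse-Hessian approximation is commonly taken as H_0 = gamma * I with gamma = <s, y> / <y, y>, where s is the latest parameter difference and y the latest gradient difference. The snippet below is a minimal sketch assuming that standard formula; `dot` and `initial_scaling` are hypothetical helper names for illustration and are not jaxopt functions.

```python
def dot(u, v):
    """Inner product of two equal-length sequences of floats."""
    return sum(a * b for a, b in zip(u, v))


def initial_scaling(s, y, use_gamma=True):
    """Return the scalar gamma such that H_0 = gamma * I.

    Assumption: gamma = <s, y> / <y, y>, the usual textbook choice.
    With use_gamma=False, H_0 stays the identity (gamma = 1).
    """
    if not use_gamma:
        return 1.0
    return dot(s, y) / dot(y, y)


s = [0.5, -0.25]   # latest parameter difference (made-up numbers)
y = [2.0, -1.0]    # latest gradient difference (made-up numbers)
print(initial_scaling(s, y))                   # 0.25
print(initial_scaling(s, y, use_gamma=False))  # 1.0
```

When y is large relative to s, as can happen on badly conditioned problems, gamma becomes small and rescales the search direction accordingly, which is one plausible reason the rescaling interacts with line-search failures.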
I'm confused. I thought allow_failed_linesearch == True was JAXopt's behavior? (since the optimization loop keeps going even in the case of failed line search).
Ah it's because there is a confusion about the naming. Originally I had called it allow_linesearch_failure, which meant that we authorize the linesearch failure to be passed on. Now with the new name I understand why there is some confusion. Maybe we can just change the name?
The old name is confusing too IMO. How about stop_if_linesearch_fails or continue_if_linesearch_fails?
Sure, switching to stop_if_linesearch_fails and keeping it False by default (which is the current Jaxopt behavior), but feel free to let me know what you think of my previous plan of action (changing it with the new name):
- let's have the default behavior same as JAX core (and therefore a change of behavior for jaxopt), i.e. allow_failed_linesearch: bool = True
- correct the Rosenbrock test by using use_gamma=False
- have a warning somewhere in the doc that use_gamma can result in failures for very ill-conditioned functions
My opinion is that I would keep stop_if_linesearch_fails=False by default for the time being because we already have users relying on LBFGS as (inner) solver and we can't take the risk to break their code/tests without any warning. If there is strong evidence that stopping is the right thing to do and this should be the default, we need to use a deprecation cycle to warn that the default value is going to change in release 0.x
Ok for now it is stop_if_linesearch_fails=False, keeping it as is.
16a1053 to 3ecda44
@mblondel squashed the commits, and tests are green locally, you can merge when green in CI.
This PR addresses the issue mentioned here, by allowing the user to decide whether LBFGS should fail when the line search fails, using an allow_line_search_fail attribute (set to False by default to ensure backward compatibility). I am not sure exactly which unit tests I should add for this new addition: typically it will fail on the Rosenbrock function.
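The behaviour this option toggles can be sketched on a toy 1-D problem. This is a minimal illustration and not jaxopt's LBFGS: `backtracking`, `minimize`, `steep` and `steep_grad` are hypothetical names, plain gradient descent stands in for L-BFGS, and the flag uses the name settled on during review (stop_if_linesearch_fails). Applying the last trial step even after a failed search is a choice made for this sketch only.

```python
def backtracking(f, x, g, stepsize=1.0, maxls=5, c1=1e-4):
    """Armijo backtracking along -g for a 1-D function.

    Returns (stepsize, failed); failed is True when no trial step
    gave sufficient decrease within maxls trials.
    """
    fx = f(x)
    for _ in range(maxls):
        if f(x - stepsize * g) <= fx - c1 * stepsize * g * g:
            return stepsize, False
        stepsize *= 0.5
    return stepsize, True


def minimize(f, grad, x0, maxiter=10, tol=1e-8, maxls=5,
             stop_if_linesearch_fails=False):
    """Toy gradient descent. Returns (x, n_iter, failed)."""
    x, failed, n_iter = x0, False, 0
    while n_iter < maxiter and abs(grad(x)) > tol:
        if failed and stop_if_linesearch_fails:
            break  # give up as soon as a line search has failed
        g = grad(x)
        stepsize, failed = backtracking(f, x, g, maxls=maxls)
        x = x - stepsize * g  # sketch choice: the step is taken either way
        n_iter += 1
    return x, n_iter, failed


def steep(x):
    """Very steep quadratic: unit-scale steps overshoot badly."""
    return 1e4 * x * x


def steep_grad(x):
    return 2e4 * x


# Stops after the first failed line search.
print(minimize(steep, steep_grad, 1.0, stop_if_linesearch_fails=True)[1:])
# Keeps iterating until maxiter despite repeated failures.
print(minimize(steep, steep_grad, 1.0, stop_if_linesearch_fails=False)[1:])
```

With only five backtracking trials the step never becomes small enough for this function, so the first call reports one iteration with a failed search, while the second runs all ten iterations and drifts away from the minimum.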