Adding the possibility to stop LBFGS after a line search failure #323

Merged

zaccharieramzi
Contributor

This PR addresses the issue mentioned here, by allowing the user to decide whether LBFGS should fail when the line search fails, using an allow_line_search_fail attribute (set to False by default to ensure backward compatibility).

I am not sure exactly which unit tests I should add for this addition: typically it will fail on the Rosenbrock function.

Collaborator

@mblondel mblondel left a comment

Thanks a lot for doing this. I left some comments and suggestions.

jaxopt/_src/lbfgs.py (outdated)
aux=aux)
gamma=jnp.asarray(1.0, dtype=dtype),
aux=aux,
failed=jnp.asarray(False))
Collaborator

I'm not sure we need asarray for booleans. CC @froystig

Contributor Author

I can try without and tell you

Contributor Author

it works without the asarray, so removed it

Member

If we want them as explicit scalar arrays, rather than Python scalars, then use np.array(False), where np is the original NumPy. If we moreover want them on device for some reason, then use jnp.asarray(False).

Collaborator

Here the main thing we want to be careful of is that init_state and update output states with the same types. Otherwise, it triggers recompilations.

Member

I believe we're fine with just a plain literal False:

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.jit
... def f(x):
...   print('compiling')
...   return x
... 
>>> f(False)
compiling
DeviceArray(False, dtype=bool)
>>> f(jnp.asarray(False))
DeviceArray(False, dtype=bool)

Contributor Author

this would mean rewriting the unit tests that check the return types, so maybe it's for a different PR?

Collaborator

Yes, I agree. Let's squash the commits and merge this PR.

@zaccharieramzi
Contributor Author

@mblondel I went over your comments.
The only thing is that, as I mentioned, I don't have a unit test for this.

Collaborator

@mblondel mblondel left a comment

Thanks. LGTM apart from 2 minor comments.

@zaccharieramzi
Contributor Author

@mblondel all corrections incorporated and tests are passing at least locally

Collaborator

@mblondel mblondel left a comment

LGTM apart from squashing commits. Thanks again for your contribution!

@@ -211,6 +213,7 @@ class LBFGS(base.IterativeSolver):

stepsize: Union[float, Callable] = 0.0
linesearch: str = "zoom"
allow_failed_linesearch: bool = False
Collaborator

So the default behavior now is the same as that of LBFGS in JAX core, right? (just to confirm)

Contributor Author

No, I left the default behaviour unchanged: if I changed it, the Rosenbrock test would not pass, which is indeed what happens in JAX core (#322 (comment)).

A simple solution to make this test pass would be to set use_gamma=False, i.e. not update the initial H_0 at each iteration and keep it as the identity (as is done in SciPy).

If I were to make a proposal I would say:

  • let's have the default behavior be the same as JAX core (and therefore a change of behavior for JAXopt), i.e. allow_failed_linesearch: bool = True
  • correct the Rosenbrock test by using use_gamma=False
  • have a warning somewhere in the doc that use_gamma can result in failures for very ill-conditioned functions

Waiting for your input on this before taking any action/squashing commits
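
For context on what use_gamma toggles: in the textbook L-BFGS recipe, the initial inverse-Hessian guess H_0 is rescaled at every iteration to gamma * I, with gamma = s^T y / y^T y computed from the latest parameter difference s and gradient difference y. The sketch below is a pure-Python illustration of that choice. The helper names `dot` and `initial_scaling` are hypothetical, and it is an assumption that jaxopt computes gamma exactly this way.

```python
from typing import Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Plain inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def initial_scaling(s: Sequence[float], y: Sequence[float],
                    use_gamma: bool = True) -> float:
    """Return gamma such that the initial inverse-Hessian guess is gamma * I.

    s is the latest parameter difference x_{k+1} - x_k and y the matching
    gradient difference. Hypothetical helper, for illustration only.
    """
    if not use_gamma:
        return 1.0  # keep H_0 fixed at the identity
    return dot(s, y) / dot(y, y)  # textbook scaling s^T y / y^T y


# Toy quadratic f(x) = 0.5 * (x1^2 + 100 * x2^2): a unit step along x2
# changes the gradient by 100, so the scaling shrinks H_0 to about 0.01 * I.
s, y = [0.0, 1.0], [0.0, 100.0]
gamma_scaled = initial_scaling(s, y)
gamma_fixed = initial_scaling(s, y, use_gamma=False)
```

With use_gamma=False the search direction keeps the raw gradient scale, which is the trade-off discussed above for ill-conditioned problems.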

Collaborator

I'm confused. I thought allow_failed_linesearch == True was JAXopt's behavior? (since the optimization loop keeps going even in the case of failed line search).

Contributor Author

Ah, it's because there is some confusion about the naming. Originally I had called it allow_linesearch_failure, which meant that we authorize the linesearch failure to be passed on. Now with the new name I understand why there is some confusion. Maybe we can just change the name?

Collaborator

The old name is confusing too IMO. How about stop_if_linesearch_fails or continue_if_linesearch_fails?

Contributor Author

Sure, switching to stop_if_linesearch_fails and keeping it False by default (which is the current JAXopt behavior), but feel free to let me know what you think of my previous plan of action (changing it with the new name):

  • let's have the default behavior be the same as JAX core (and therefore a change of behavior for JAXopt), i.e. stop_if_linesearch_fails: bool = True
  • correct the Rosenbrock test by using use_gamma=False
  • have a warning somewhere in the doc that use_gamma can result in failures for very ill-conditioned functions

Collaborator

My opinion is that we should keep stop_if_linesearch_fails=False by default for the time being, because we already have users relying on LBFGS as an (inner) solver and we can't risk breaking their code/tests without any warning. If there is strong evidence that stopping is the right thing to do and that it should be the default, we need to use a deprecation cycle to warn that the default value is going to change in release 0.x.
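
A deprecation cycle of this kind is often implemented with a sentinel default, so that the code can tell "argument not passed" apart from an explicit False and warn only in the first case. The following is a minimal sketch of that pattern, not jaxopt code: `resolve_stop_flag` and `_UNSET` are hypothetical names, and the warning text deliberately names no specific release.

```python
import warnings

_UNSET = object()  # sentinel: distinguishes "not passed" from an explicit False


def resolve_stop_flag(stop_if_linesearch_fails=_UNSET) -> bool:
    """Return the effective flag, warning when the caller relies on the default.

    Hypothetical helper illustrating a deprecation cycle for a default value.
    """
    if stop_if_linesearch_fails is _UNSET:
        warnings.warn(
            "The default of stop_if_linesearch_fails will change from False "
            "to True in a future release; pass it explicitly to silence "
            "this warning.",
            FutureWarning,
            stacklevel=2,
        )
        return False  # the current default stays in effect during the cycle
    return bool(stop_if_linesearch_fails)
```

Callers who pass the flag explicitly see no warning, so only code that depends on the old default is notified before the switch.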

Contributor Author

Ok for now it is stop_if_linesearch_fails=False, keeping it as is.
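
To make the two behaviors concrete, here is a toy 1-D gradient descent with a backtracking line search that reports failure when its trial budget runs out. It is a sketch under stated assumptions, not jaxopt's LBFGS: `minimize`, `backtracking`, and `Result` are hypothetical, only the flag name mirrors the one agreed on above, and the strict `c1=0.9` with 3 trials is chosen on purpose so that every line search fails on f(x) = x^2.

```python
from typing import Callable, NamedTuple, Tuple


class Result(NamedTuple):
    x: float
    num_iter: int
    stopped_early: bool  # True only if a failed line search ended the run


def backtracking(f: Callable[[float], float], x: float, g: float,
                 max_steps: int, c1: float,
                 shrink: float = 0.5) -> Tuple[float, bool]:
    """Armijo backtracking along -g; returns (stepsize, failed)."""
    t, fx = 1.0, f(x)
    for _ in range(max_steps):
        if f(x - t * g) <= fx - c1 * t * g * g:
            return t, False
        t *= shrink
    return t / shrink, True  # budget exhausted: last stepsize actually tried


def minimize(f, grad, x0: float, stop_if_linesearch_fails: bool = False,
             maxiter: int = 50, tol: float = 1e-8,
             max_steps: int = 3, c1: float = 0.9) -> Result:
    x = x0
    for it in range(maxiter):
        g = grad(x)
        if abs(g) <= tol:
            return Result(x, it, False)  # converged
        t, failed = backtracking(f, x, g, max_steps, c1)
        if failed and stop_if_linesearch_fails:
            return Result(x, it, True)  # give up at the current iterate
        x = x - t * g  # otherwise take the step even if the search failed
    return Result(x, maxiter, False)


f = lambda x: x * x
grad = lambda x: 2.0 * x
stopped = minimize(f, grad, 1.0, stop_if_linesearch_fails=True)
continued = minimize(f, grad, 1.0, stop_if_linesearch_fails=False)
```

In this toy setup `stopped` returns the starting point immediately, while `continued` keeps taking the last tried stepsize and still reaches the minimum, which is one reason continuing after a failed line search can be a reasonable default.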

zaccharieramzi force-pushed the line-search-failed-breaks-lbfgs branch from 16a1053 to 3ecda44 on October 18, 2022 12:00
@zaccharieramzi
Contributor Author

@mblondel squashed the commits, and tests are green locally, you can merge when green in CI.
