
Wrong failure diagnostic print outs from ZoomLineSearch under vmap #555

Closed
tare opened this issue Nov 18, 2023 · 3 comments

@tare

tare commented Nov 18, 2023

Environment

% pip list|grep jax   
jax                       0.4.20
jaxlib                    0.4.20
jaxopt                    0.8.2

% python --version
Python 3.10.11

Description

ZoomLineSearch under vmap ends up calling failure_diagnostic() even when safe_stepsize > 0, as shown here. This can produce a large number of printouts, and I didn't see a way to disable the failure diagnostic printouts with the current implementation. I think the relevant commit is 614dc7b. Below are minimal reproducible examples.

The following code

import jax.numpy as jnp
from jax import vmap
from jaxopt import LBFGS

def solve(x, y):
    solver = LBFGS(lambda x, y: jnp.square(y-x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)

vmap(solve, in_axes=(None, 0))(x_init, ys)

gives the following warnings

WARNING: jaxopt.ZoomLineSearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
INFO: jaxopt.ZoomLineSearch: Iter: 1, Stepsize: 1.0, Decrease error: -0.0, Curvature error: 0.0
WARNING: jaxopt.ZoomLineSearch: The linesearch failed because the provided direction is not a descent direction. The slope (=-0.0) at stepsize=0 should be negative
WARNING: jaxopt.ZoomLineSearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: jaxopt.ZoomLineSearch: Computed stepsize (=1.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.0). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Cannot even make a step without getting Inf or Nan. The linesearch won't make a step and the optimizer is stuck.
WARNING: jaxopt.ZoomLineSearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.

By contrast, the following code, which uses lax.map and jit instead of vmap, does not produce any warnings

import jax.numpy as jnp
from jax import jit
from jaxopt import LBFGS
from jax.lax import map

def solve(x, y):
    solver = LBFGS(lambda x, y: jnp.square(y-x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)

res = map(jit(lambda y: solve(x_init, y)), ys)
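One way to see why the lax.map variant stays quiet is to inspect the traced jaxprs. The sketch below uses a hypothetical toy function `branchy` (not part of jaxopt) and assumes only standard JAX APIs (`jax.make_jaxpr`, `jax.lax.cond`, `jax.lax.map`). Under `lax.map` each call sees a scalar predicate and the `cond` primitive survives; under `vmap` the predicate is batched and the `cond` is replaced by a `select_n`, which evaluates both branches.

```python
import jax
import jax.numpy as jnp
from jax import lax


def branchy(x):
    # Hypothetical toy function with a data-dependent branch.
    return lax.cond(x < 3, lambda v: v, lambda v: -v, x)


xs = jnp.arange(5)

# lax.map runs `branchy` sequentially (via lax.scan), so the predicate
# is an unbatched scalar at each step and a real `cond` is kept.
map_jaxpr = str(jax.make_jaxpr(lambda v: lax.map(branchy, v))(xs))

# vmap batches the predicate, so JAX rewrites the `cond` as a select:
# both branches are traced and executed for every element.
vmap_jaxpr = str(jax.make_jaxpr(jax.vmap(branchy))(xs))

print("cond" in map_jaxpr)       # True
print("cond" in vmap_jaxpr)      # False
print("select_n" in vmap_jaxpr)  # True
```

Any side effect placed in the "failure" branch, such as a diagnostic print, therefore runs for every batch element under vmap, regardless of the predicate.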

Here is a minimal reproducible example that isolates the issue using only jax.debug.print, cond, and vmap. The following code

import jax.numpy as jnp
from jax import vmap, jit
from jax.lax import cond, map  # jax.lax.map; Python's builtin map is lazy and would run nothing
import jax.debug

def test(x):
    def true_fun(x):
        pass
    def false_fun(x):
        jax.debug.print("{}", x)
    cond(x < 3, true_fun, false_fun, x)

print("map and jit")
map(jit(test), jnp.arange(5))
print("vmap")
vmap(test)(jnp.arange(5))

gives the following output

map and jit
3
4
vmap
0
1
2
3
4
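A possible workaround for the toy example, sketched under the assumption that deciding on the host is acceptable: instead of calling jax.debug.print inside one branch of a cond, pass the condition itself to `jax.debug.callback` and test it in Python. Under vmap the callback is invoked once per batch element with concrete values, so the host-side check filters correctly. The names `diagnose`, `_report_if_failed`, and `printed` are hypothetical and only for illustration; this is not how jaxopt implements its diagnostics.

```python
import jax
import jax.numpy as jnp

printed = []  # hypothetical: records what the callback prints, for checking


def _report_if_failed(value, ok):
    # Runs on the host with concrete per-element values; under vmap,
    # jax.debug.callback is invoked once per batch element.
    if not bool(ok):
        printed.append(int(value))
        print(int(value))


def diagnose(x):
    ok = x < 3
    # No cond here: the condition is evaluated on the host instead, so a
    # batched predicate cannot force the "print" branch to run everywhere.
    jax.debug.callback(_report_if_failed, x, ok)


jax.vmap(diagnose)(jnp.arange(5))
jax.effects_barrier()   # wait until all pending callbacks have run
print(sorted(printed))  # [3, 4]
```

The order of the individual prints is not guaranteed for unordered callbacks, which is why the check sorts the recorded values.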
@vroulet
Collaborator

vroulet commented Nov 22, 2023

Hello @tare,
Thanks for pointing this out. Under vmap, a cond whose predicate is batched is converted into a select, so both branches are evaluated (this is not the case without vmap); see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html.
I'm not sure how we could then have failure diagnostics under vmap.
At least, in #544 I have patched zoom so that it does not display failure diagnostics unless verbose is set to True. That will avoid unnecessary prints.
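A minimal sketch of the idea behind gating diagnostics on `verbose`, assuming the flag is an ordinary Python bool checked at trace time. `make_step` and the failure test are hypothetical stand-ins, not jaxopt's actual code. When `verbose=False` no print operation is traced at all, so vmap has nothing to execute; when `verbose=True` the prints are traced and, under vmap, still fire for every element.

```python
import jax
import jax.numpy as jnp
from jax import lax


def make_step(verbose: bool = False):
    """Hypothetical helper: diagnostics are traced only if `verbose` is True."""

    def step(x):
        failed = x >= 3  # hypothetical stand-in for "line search failed"
        if verbose:  # Python-level check, resolved at trace time
            # Under vmap this cond becomes a select, so the print would
            # run for every element, not only the failed ones.
            lax.cond(failed,
                     lambda v: jax.debug.print("failed at x={}", v),
                     lambda v: None,
                     x)
        return jnp.where(failed, 0, x)

    return step


quiet = jax.vmap(make_step(verbose=False))
print(quiet(jnp.arange(5)))  # [0 1 2 0 0], with no diagnostic output
```

This avoids spurious output by default, but it does not make the diagnostics accurate under vmap when `verbose=True`.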

@tare
Author

tare commented Nov 22, 2023

Thanks for the quick reply and for pointing out #544! I hope that PR gets merged soon.

@vroulet
Collaborator

vroulet commented Jan 10, 2024

Closing as #544 has been merged.

vroulet closed this as completed Jan 10, 2024