Computation hangs forever with jax > 0.4.31 #24219
Comments
I am finding that adding a |
Ok, here is a simpler reproduction.

```python
import jax
import jax.numpy as jnp
from invrs_opt.optimizers import lbfgsb
from totypes import types


def optimization_with_vmap(steps):
    params = types.Density2DArray(array=jnp.ones((2, 3, 3)))
    opt = lbfgsb.density_lbfgsb(beta=2, maxcor=20)
    state = jax.vmap(opt.init)(params)

    @jax.jit
    @jax.vmap
    def step_fn(state):
        params = opt.params(state)
        dummy_value = jnp.array(1.0, dtype=float)
        dummy_grad = jax.tree_util.tree_map(jnp.ones_like, params)
        state = opt.update(grad=dummy_grad, value=dummy_value, params=params, state=state)
        return state, dummy_value

    for i in range(steps):
        print(f"vmap step {i}", flush=True)
        state, value = step_fn(state)
        # Without this `block_until_ready`, this hangs after ~35 steps.
        jax.block_until_ready((state, value))


optimization_with_vmap(steps=100)
```

There is a |
It seems this issue is related to use of jax types inside functions called via |
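The reproduction above avoids the hang only when it calls `jax.block_until_ready` after every step. JAX dispatches jitted computations asynchronously, so a plain Python loop can otherwise enqueue many steps before any of them finishes; whether that is what triggers the hang has not been established in this thread. The following sketch isolates the workaround pattern without `invrs_opt`: `toy_step` is a hypothetical stand-in for `opt.update` that takes a gradient-descent step on a batch of arrays, and the `sync_each_step` flag is likewise illustrative. It is not expected to reproduce the hang by itself.

```python
import jax
import jax.numpy as jnp


def run_steps(steps: int, sync_each_step: bool = True):
    """Run a jit + vmap'd toy update loop, optionally blocking after each step.

    `toy_step` is a hypothetical stand-in for `opt.update` from invrs_opt:
    it takes one gradient-descent step with a dummy all-ones gradient.
    """
    # Batch of 2 arrays of shape (3, 3), matching the repro's leading axis.
    state = jnp.ones((2, 3, 3))

    @jax.jit
    @jax.vmap
    def toy_step(x):
        dummy_value = jnp.array(1.0)
        dummy_grad = jnp.ones_like(x)
        # vmap broadcasts the unbatched `dummy_value` to shape (2,).
        return x - 0.1 * dummy_grad, dummy_value

    value = None
    for _ in range(steps):
        state, value = toy_step(state)
        if sync_each_step:
            # Wait for this step's results before dispatching the next one,
            # mirroring the workaround in the reproduction.
            jax.block_until_ready((state, value))
    return state, value


state, value = run_steps(steps=100)
print(state.shape, value.shape)  # (2, 3, 3) (2,)
```

Blocking on every step gives up overlap between Python dispatch and device execution, so it may slow down loops with many cheap steps; it is a workaround, not a fix for the regression.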
Description
JAX hangs forever on certain calculations for versions > 0.4.31.
This problem is proving extremely difficult to diagnose and to reduce to a minimal repro, but it is popping up in several repos that I maintain, including fmmax and invrs_opt.
I will try to get a simplified reproduction, but here is some code that triggers the problem (colab).
With the print statement commented out, the code hangs on jax 0.4.34 but executes as expected on jax 0.4.30. Bizarrely, it runs on both versions if the print statement is uncommented. As I mentioned, the problem is popping up in several different places, so I don't believe anything very specific to this code is to blame. I have experienced it locally, on GitHub workflows, and on Colab.
I'll try to create a smaller repro, but I wanted to get this issue filed first.
System info (python version, jaxlib version, accelerator, etc.)
Colab CPU, jax 0.4.34