
Computation hangs forever with jax > 0.4.31 #24219

Closed · mfschubert opened this issue Oct 9, 2024 · 3 comments
Labels: bug

mfschubert commented Oct 9, 2024

Description

JAX hangs forever on certain computations for versions > 0.4.31.

This problem is proving extremely difficult to diagnose and to reduce to a minimal repro, but it is popping up in several repos that I maintain, including fmmax and invrs_opt.

I will try to get a simplified reproduction, but here is some code that creates the problem (colab).

```python
import jax
import jax.numpy as jnp
from totypes import types

from invrs_opt.optimizers import lbfgsb


def optimization_with_vmap(steps):
    print("running", flush=True)

    def initial_params_fn(key):
        del key
        return types.Density2DArray(array=jnp.ones((3, 3)))

    keys = jax.random.split(jax.random.PRNGKey(0), num=2)
    opt = lbfgsb.density_lbfgsb(beta=2, maxcor=20)

    params = jax.vmap(initial_params_fn)(keys)
    print("vmap params initialized", flush=True)
    state = jax.vmap(opt.init)(params)
    print("vmap state initialized", flush=True)

    @jax.jit
    @jax.vmap
    def step_fn(state):
        params = opt.params(state)
        dummy_value = jnp.array(1.0, dtype=float)
        dummy_grad = jax.tree_util.tree_map(jnp.ones_like, params)
        state = opt.update(grad=dummy_grad, value=dummy_value, params=params, state=state)
        return state, dummy_value

    for i in range(steps):
        print(f"vmap step {i}", flush=True)
        state, value = step_fn(state)

    for k in keys:
        print("If we don't print the key and use jax > 0.4.30, the code hangs.", flush=True)
        # print(f"{k=}", flush=True)
        params = initial_params_fn(k)
        print("params initialized", flush=True)
        state = opt.init(params)
        print("state initialized", flush=True)
        for i in range(steps):
            print(f"step {i}", flush=True)
            params = opt.params(state)
            dummy_value = jnp.array(1.0, dtype=float)
            dummy_grad = jax.tree_util.tree_map(jnp.ones_like, params)
            state = opt.update(grad=dummy_grad, value=dummy_value, params=params, state=state)


optimization_with_vmap(steps=10)
```

With the print statement commented out, this code hangs on jax 0.4.34 and executes as expected on jax 0.4.30. Bizarrely, it runs in both cases if the print statement is uncommented. As I mentioned, the problem is popping up in several different places, so I don't believe anything very specific to this code is to blame. I have seen it locally, in GitHub workflows, and on Colab.

I'll try to create a smaller repro, but I want to get this issue filed first.

System info (python version, jaxlib version, accelerator, etc.)

Colab CPU, jax 0.4.34

mfschubert added the bug label Oct 9, 2024

mfschubert (Author) commented

I am finding that calling jax.block_until_ready on the arguments of a jax.pure_callback helps in at least one case. I will investigate a few other cases.
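
For concreteness, here is a minimal sketch of that workaround. The host function and shapes below are hypothetical stand-ins (in invrs_opt the callback wraps the scipy L-BFGS-B routine), not the library's actual code:

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_fn(x):
    # Hypothetical host-side computation standing in for the wrapped
    # scipy L-BFGS-B call.
    return np.asarray(x) * 2.0


def apply_callback(x):
    # The workaround: block on the callback arguments before handing them
    # to jax.pure_callback, so any pending async computation has finished.
    x = jax.block_until_ready(x)
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_fn, result_shape, x)


print(apply_callback(jnp.ones((3,))))  # [2. 2. 2.]
```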

mfschubert (Author) commented Oct 10, 2024

OK, here is a simpler reproduction.

```python
import jax
import jax.numpy as jnp
from invrs_opt.optimizers import lbfgsb
from totypes import types

def optimization_with_vmap(steps):
    params = types.Density2DArray(array=jnp.ones((2, 3, 3)))
    opt = lbfgsb.density_lbfgsb(beta=2, maxcor=20)
    state = jax.vmap(opt.init)(params)

    @jax.jit
    @jax.vmap
    def step_fn(state):
        params = opt.params(state)
        dummy_value = jnp.array(1.0, dtype=float)
        dummy_grad = jax.tree_util.tree_map(jnp.ones_like, params)
        state = opt.update(grad=dummy_grad, value=dummy_value, params=params, state=state)
        return state, dummy_value

    for i in range(steps):
        print(f"vmap step {i}", flush=True)
        state, value = step_fn(state)
        # Without this `block_until_ready`, this hangs after ~35 steps.
        jax.block_until_ready((state, value))


optimization_with_vmap(steps=100)
```

The block_until_ready call seems to be critical in avoiding the hang: when it is removed, the loop tends to hang after about 35 steps on a Colab CPU runtime.

mfschubert changed the title from "Computation hangs forever with jax > 0.4.30" to "Computation hangs forever with jax > 0.4.31" Oct 10, 2024
mfschubert (Author) commented

It seems this issue is related to the use of jax types inside functions called via jax.pure_callback. I will close this and file a new, more specific issue.
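
For illustration, a minimal sketch of the pattern being described, with a hypothetical callback rather than the library's actual code: the function handed to jax.pure_callback itself performs jax.numpy operations instead of operating purely on numpy arrays.

```python
import jax
import jax.numpy as jnp


def callback_with_jax_types(x):
    # Suspected problem pattern: using jax.numpy inside a host callback,
    # which launches new device computation from within the callback.
    # A numpy-only callback (e.g. np.sin) would avoid this.
    return jnp.sin(x)


@jax.jit
def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(callback_with_jax_types, result_shape, x)


# On affected versions, this kind of pattern reportedly can hang.
print(f(jnp.linspace(0.0, 1.0, 4)))
```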
