[Regression] Slower integration of differential equations since jaxlib > 0.4.32.dev20240807 #518

Open
jdehning opened this issue Oct 25, 2024 · 4 comments

Comments

@jdehning

Between jaxlib==0.4.32.dev20240807 and jaxlib==0.4.32.dev20240812, I observe a significant performance regression when integrating differential equations with many solver steps (up to 8x slower). Minimal example:

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np
from diffrax import ODETerm, SaveAt, StepTo, Tsit5, diffeqsolve


def ODE(t, y, args):
    # Right-hand side of the scalar ODE dy/dt = cos(y).
    return jnp.cos(y)


t_out = np.linspace(0, 20, 10)  # times at which the solution is saved
t_steps = np.linspace(0, 20, 10000)  # fixed solver step grid (10000 points)

stepsize_controller = StepTo(ts=t_steps)
saveat = SaveAt(ts=t_out)
term = ODETerm(ODE)
y0 = 1.0


@jax.jit
def f(y0):
    return diffeqsolve(
        term, Tsit5(), t0=t_out[0], t1=t_out[-1], dt0=None, y0=y0,
        saveat=saveat, stepsize_controller=stepsize_controller,
        max_steps=len(t_steps),
    )


# Warm-up call so that JIT compilation is not included in the timing.
solution = jax.block_until_ready(f(y0))


def f_timer():
    jax.block_until_ready(f(1.0))


runtime = timeit.timeit(f_timer, number=100) / 100 * 1000  # mean ms per call
print(f"Runtime: {runtime:.3f} ms")
```

Tested on Ubuntu 22.04 with the CPU backend. On my PC, the runtime is 1.8 ms with the nightly jaxlib 20240807 and 14.8 ms with 20240812. The difference is largest when t_steps is large.
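The script above reports the mean of 100 calls. When comparing two builds, a slightly more noise-resistant variant is to warm up once (so compilation is excluded) and take the median over several batches. This is a minimal sketch using only the standard library plus, if installed, `jax.block_until_ready`; `median_runtime_ms` and `slowdown` are hypothetical helper names, not part of JAX or Diffrax.

```python
import statistics
import time

try:
    import jax

    _sync = jax.block_until_ready  # wait for JAX's asynchronous dispatch
except ImportError:  # allow the helper to run without JAX installed
    def _sync(x):
        return x


def median_runtime_ms(fn, *args, number=10, repeat=5):
    """Median per-call runtime of fn(*args), in milliseconds (sketch).

    fn is called once first so that JIT compilation is excluded, then
    `repeat` batches of `number` calls each are timed.
    """
    _sync(fn(*args))  # warm-up: trigger compilation outside the timed region
    per_call = []
    for _ in range(repeat):
        start = time.perf_counter()
        for _ in range(number):
            _sync(fn(*args))
        per_call.append((time.perf_counter() - start) / number * 1000)
    return statistics.median(per_call)


def slowdown(baseline_ms, candidate_ms):
    """Ratio of candidate runtime to baseline runtime."""
    return candidate_ms / baseline_ms


print(round(slowdown(1.8, 14.8), 1))  # 8.2
```

With the numbers reported above, `slowdown(1.8, 14.8)` is about 8.2, consistent with the "up to 8x slower" figure. To check how the gap depends on `t_steps`, one could rebuild `f` for several grid sizes and pass each to `median_runtime_ms`.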

To test it quickly, I used the following command (split across lines here), where test_diffrax.py is the script above. Only jaxlib changes between the two runs; jax stays at 0.4.32.dev20240807:

```shell
uv venv --python 3.12 && \
uv pip install diffrax numpy --pre jax==0.4.32.dev20240807 jaxlib==0.4.32.dev20240807 \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    --reinstall --exclude-newer 2024-09-20 && \
uv run test_diffrax.py && \
uv pip install diffrax numpy --pre jax==0.4.32.dev20240807 jaxlib==0.4.32.dev20240812 \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    --reinstall --exclude-newer 2024-09-20 && \
uv run test_diffrax.py
```

I'm not sure whether I should have opened this issue on the JAX GitHub instead; let me know if this isn't the right place.

@patrick-kidger
Owner

I'd definitely open this as an issue on the JAX GitHub! Probably this will be due to some change in the XLA compiler between those two builds, which unfortunately means there isn't much we can do about it from Diffrax.

If you're interested in digging into it, you might be able to locate the offending commit by bisecting through the XLA repo (https://github.com/openxla/xla/). I believe JAX itself hosts several benchmarks, so something could probably be added to those to prevent this regression from recurring.
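Bisecting is a binary search over an ordered history: each probe halves the range that can contain the regression, so N candidate commits need only about log2(N) benchmark runs. A minimal sketch, assuming a hypothetical `is_bad` callback that builds and benchmarks a given commit (for example by running the script above and comparing the runtime against a threshold); `first_bad` is an illustrative helper, not part of any XLA or JAX tooling. In practice, `git bisect run` automates the same loop over a repository's history.

```python
from typing import Callable, Sequence


def first_bad(builds: Sequence[str], is_bad: Callable[[str], bool]) -> str:
    """Binary-search for the first build on which is_bad() returns True.

    Assumes: builds are in chronological order, builds[0] is known good,
    builds[-1] is known bad, and once the regression appears it stays
    (the same monotonicity assumption that `git bisect` relies on).
    """
    lo, hi = 0, len(builds) - 1  # invariant: builds[lo] good, builds[hi] bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_bad(builds[mid]):
            hi = mid  # regression is at mid or earlier
        else:
            lo = mid  # regression is after mid
    return builds[hi]


# Toy predicate for illustration only: pretend the regression landed in commit3.
commits = [f"commit{i}" for i in range(6)]
print(first_bad(commits, lambda c: commits.index(c) >= 3))  # commit3
```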

@patrick-kidger
Owner

I've just seen a suggested fix over here: jax-ml/jax#24501

Give it a try?

@jdehning
Author

Yes, with the suggested fix, I observe similar runtimes again.

Somewhat unrelated: I can't test with the newest nightly JAX, because jax.core.ConcreteArray has been removed from JAX (jax-ml/jax@48f24b6) but Diffrax still uses ConcreteArray. I'm not sure whether you're already aware of this.

@patrick-kidger
Owner

Great that the source of the slowdown has been fixed.

Thanks for the heads-up on ConcreteArray. Judging from that PR, it looks like we probably want something like jax.core.is_concrete(jax.core.get_aval(x)) instead now. I'll delay our next release of Diffrax until after that JAX release, and then we can try to include a fix for this at the same time.
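For context, a value is "concrete" when its actual contents are known while JAX traces a function, as opposed to an abstract tracer inside `jax.jit`. The sketch below is a hypothetical, version-agnostic approximation that simply checks for `jax.core.Tracer`; it is not the replacement API discussed above, and it is stricter than the old `ConcreteArray` check, which in older JAX versions could also treat some tracers (for example under `jax.grad` without `jit`) as concrete.

```python
import jax
import jax.numpy as jnp


def looks_concrete(x) -> bool:
    """Conservative check for whether x is known at trace time.

    Hypothetical helper, not Diffrax or JAX API: anything that is not a
    JAX tracer (Python scalars, NumPy arrays, materialised jax.Arrays) is
    treated as concrete, and every tracer is treated as abstract.
    """
    return not isinstance(x, jax.core.Tracer)


observed = []


@jax.jit
def f(x):
    observed.append(looks_concrete(x))  # runs once, while tracing
    return jnp.cos(x)


print(looks_concrete(jnp.ones(3)))  # True: an already-computed array
f(jnp.ones(3))
print(observed)  # [False]: inside jit, x is an abstract tracer
```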
