Performance issue with SDE solver #517
I get them to be a lot closer by using There's also some risky (but often useful) changes to UBP we've made internally that I've been meaning to put in the fork, so you can definitely do a fair amount with modifications to UBP (being able to get through all 3 stated requirements). |
Yup, VBT is often the cause of poor SDE performance. Really we need some kind of LRU caching to make it behave properly, but that doesn't seem to be easy in JAX -- I'm pretty sure it'd require both a new primitive (' In the meantime I recommend UBP as the go-to for these kinds of normal 'just solve an SDE' applications. |
I think a lot of people get turned off by the |
Thanks. Indeed using UBP does help but I understand it's quite restricted in terms of usage.
It seems there is still a factor ~10-20 difference (irrespective of number of time steps) between the homemade solver and diffrax with UBP. I would have naively thought that any irrelevant computation would be jitted away. Could you elaborate on what diffrax with UBP does compared to the naive solver?
|
Diffrax has a lot more checking/shaping/logging than the default implementation. You can see it reflected in the jaxprs: diffrax
pure jax
I believe most of this comes from the UBP, since if I do

@jax.jit
def homemade_simu():
    ts = jnp.linspace(t0, t1, ndt)

    def step(y, t):
        dw = brownian_motion.evaluate(t, t + dt)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return y + dy, y

    return jax.lax.scan(step, 1.0, ts)[-1]

I see the times are pretty much the same. Perhaps this does indicate that there is room for cutting down the speed costs of the UBP-related overhead.
|
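For anyone wanting to reproduce the jaxpr comparison without Diffrax installed, here is a minimal pure-JAX sketch. It is not the code from this thread: `pregenerated_noise`, `on_the_fly_noise` and `count_eqns` are names made up for illustration, and the equation count is only a rough proxy for how much work gets traced, not a measure of runtime.

```python
import jax
import jax.numpy as jnp

t0, t1, ndt = 0.0, 1.0, 101
dt = (t1 - t0) / (ndt - 1)
key = jax.random.PRNGKey(42)


def pregenerated_noise(y0):
    # All Brownian increments are sampled once, before the loop.
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))

    def step(y, dW):
        return y + (-y) * dt + 0.2 * dW, y

    return jax.lax.scan(step, y0, dWs)[-1]


def on_the_fly_noise(y0):
    # A fresh subkey is split off inside every step.
    def step(carry, _):
        y, k = carry
        k, subkey = jax.random.split(k)
        dW = jnp.sqrt(dt) * jax.random.normal(subkey)
        return (y + (-y) * dt + 0.2 * dW, k), y

    return jax.lax.scan(step, (y0, key), None, length=ndt)[-1]


def count_eqns(jaxpr):
    # Count equations, recursing into nested jaxprs (scan bodies, inner jits).
    total = 0
    for eqn in jaxpr.eqns:
        total += 1
        for param in eqn.params.values():
            if hasattr(param, "jaxpr"):  # a ClosedJaxpr wrapping a Jaxpr
                total += count_eqns(param.jaxpr)
            elif hasattr(param, "eqns"):  # a bare Jaxpr
                total += count_eqns(param)
    return total


y0 = jnp.float32(1.0)
n_pre = count_eqns(jax.make_jaxpr(pregenerated_noise)(y0).jaxpr)
n_fly = count_eqns(jax.make_jaxpr(on_the_fly_noise)(y0).jaxpr)
print("pre-generated noise:", n_pre, "equations")
print("on-the-fly noise:   ", n_fly, "equations")
```

Printing `jax.make_jaxpr(f)(y0)` directly gives the full text if the counts alone are not informative enough.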
FWIW I think the speed difference here does seem unacceptably large. This seems like it should be improved. Starting with the low-hanging fruit to be sure we're doing more of an equal comparison: can you try setting Also, can you try using |
With throw=False, EQX_ON_ERROR=nan and StepTo, this is what I see:
import os
os.environ["EQX_ON_ERROR"] = "nan"
import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2
steps = jnp.linspace(t0, t1, ndt)
brownian_motion = dx.UnsafeBrownianPath(shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(steps=True)
@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(
        terms, solver, t0, t1, dt0=None, y0=y0, saveat=saveat,
        adjoint=dx.DirectAdjoint(), throw=False,
        stepsize_controller=dx.StepTo(ts=steps),
    ).ys

@jax.jit
def homemade_simu():
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))

    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y

    return jax.lax.scan(step, 1.0, dWs)[-1]
y = diffrax_simu().block_until_ready()
plt.plot(y)
y = homemade_simu().block_until_ready()
plt.plot(y)
plt.show()
%timeit _ = diffrax_simu().block_until_ready()
%timeit _ = homemade_simu().block_until_ready()
(diffrax top, custom bottom)
2.18 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(without any of those things I had):
2.43 ms ± 666 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(all on CPU, just a slower CPU, but the 20-30x slowdown seems of the same scale)
|
So you definitely don't want Make sure you include an argument (say I'd also try with and without With all of the above in, then at that point there shouldn't actually be that much difference between the two implementations. (And if there is then we should figure out what.) |
The default actually errors with UBP which is why I changed to direct adjoint
|
Ah, right. I've just checked and in the case of an unsafe SDE we do actually arrange for Line 352 in ada5229
(In retrospect I think we could have arranged for the default adjoint to also do the same thing, that might be a small usability improvement.) Anyway, that's everything off the top of my head -- I might be forgetting something but with these settings then I think Diffrax should be doing something similar to the simple (EDIT: we still have one discrepancy I have just noticed: generating the Brownian samples in advance vs on-the-fly.) If you'd like to dig into this then it might be time to stare at some jaxprs or HLO for the two programs. If you want to do this at the jaxpr level then you might find https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_finalise_jaxpr.py Many primitives exist just to add e.g. an autodiff rule, so we can simplify our jaxprs down to what actually gets lowered by ignoring that and tracing through their impl rules instead. |
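As a rough sketch of the "stare at HLO" route using only public JAX calls: `jax.jit(f).lower(...)` gives the program before XLA's optimisation passes, and `.compile()` gives it after. The function `euler_scan` below is a stand-in written for this example, not either of the programs benchmarked above, and line counts are only a crude way to see how much one program differs from another.

```python
import jax
import jax.numpy as jnp

dt = 0.01


def euler_scan(y0, dWs):
    # Plain Euler-Maruyama for dy = -y dt + 0.2 dW, increments supplied up front.
    def step(y, dW):
        return y + (-y) * dt + 0.2 * dW, y

    return jax.lax.scan(step, y0, dWs)[-1]


dWs = jnp.zeros(101, dtype=jnp.float32)
lowered = jax.jit(euler_scan).lower(jnp.float32(1.0), dWs)

stablehlo_text = lowered.as_text()            # what JAX hands to XLA
optimised_text = lowered.compile().as_text()  # what XLA produces after its passes

print(len(stablehlo_text.splitlines()), "lines before XLA optimisation")
print(len(optimised_text.splitlines()), "lines after XLA optimisation")
```

Diffing the two text dumps for the Diffrax and hand-written programs should show which operations survive compilation in one but not the other.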
DirectAdjoint does slow things down, but not all the way. If I switch to a branch that allows for UBP + recursive adjoint, it's faster, but there is still a ~4x gap. If I account for the fact that UBP has to split keys but the other doesn't, I get the gap to be around ~1.1-1.2x (which maybe isn't ideal, but seems much more reasonable to me given there's probably some other if statements/logging that might exist).

x = Timer(lambda: diffrax_simu(y0).block_until_ready())
print(x.timeit(number=100))
x = Timer(lambda: homemade_simu(y0).block_until_ready())
print(x.timeit(number=100))

with (above things, NAN, steps, function input, stepto, max steps, etc. all that) and direct adjoint:
w/ checkpoint adjoint (on an internal branch that had some UBP changes to work with checkpoint):
w/ both splitting keys:
(code changed to:

@jax.jit
def homemade_simu(yy):
    def step(y1, dW):
        y, k = y1
        k, subkey = jax.random.split(k)
        dw = jnp.sqrt(dt) * jax.random.normal(subkey)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return (y + dy, k), y

    return jax.lax.scan(step, (yy, key), steps)[-1]

)
|
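A note on methodology for anyone repeating these timings: the sketch below shows one way to keep compilation and asynchronous dispatch out of the measurement. The helper `bench` and the function `toy_simu` are hypothetical names introduced here, and the toy function is deterministic, so the numbers it produces say nothing about the SDE solvers themselves.

```python
import timeit

import jax
import jax.numpy as jnp


def bench(fn, *args, number=100):
    # The first call traces and compiles; run it once outside the timed region.
    jax.block_until_ready(fn(*args))
    # JAX dispatches asynchronously, so block until each result is ready.
    total = timeit.timeit(lambda: jax.block_until_ready(fn(*args)), number=number)
    return total / number  # mean seconds per call


@jax.jit
def toy_simu(y0):
    # Deterministic stand-in: 101 explicit Euler steps of dy = -y dt.
    def step(y, _):
        return y + (-y) * 0.01, y

    return jax.lax.scan(step, y0, None, length=101)[-1]


seconds = bench(toy_simu, jnp.float32(1.0), number=20)
print(f"{seconds * 1e6:.1f} µs per call")
```

Passing the initial condition as an argument, as done here and in the comment above, also stops XLA from constant-folding the whole simulation at compile time.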
Aha, interesting! Good to have more-or-less gotten to the bottom of the cause of this. So:
On point 2, I suspect the solution may require allowing the control to have additional state. (Which is also what we'd need to make VBT faster.) Perhaps it's time to bite that bullet and allow for that to happen. Happy to hear suggestions on this one! |
|
|
Yes, looking at it more, this would probably have to be a change/add-on to support passing the "step" counter around. If this is an acceptable change, I don't think it would be too much for me to get a PR up.
This was my conclusion as well, and I started drafting a branch for this, but figured it would require a pretty noticeable breaking change (at least internally), and I figured diffrax was more fait accompli than c'est la vie when it came to this level of breaking changes. |
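To make the "step counter" idea concrete, here is one possible reading of it in plain JAX. This is an assumption about the general approach, not Diffrax's API or the contents of any branch mentioned here: `increment` and `counter_simu` are hypothetical names, and the sketch derives each step's key from the counter with `jax.random.fold_in` instead of carrying and splitting a key.

```python
import jax
import jax.numpy as jnp

t0, t1, ndt = 0.0, 1.0, 101
dt = (t1 - t0) / (ndt - 1)


def increment(key, step_index):
    # Derive this step's key from the counter; no key is carried or split.
    step_key = jax.random.fold_in(key, step_index)
    return jnp.sqrt(dt) * jax.random.normal(step_key)


@jax.jit
def counter_simu(y0, key):
    def step(y, i):
        dW = increment(key, i)
        return y + (-y) * dt + 0.2 * dW, y

    # The scanned-over integers play the role of the step counter.
    return jax.lax.scan(step, y0, jnp.arange(ndt))[-1]


key = jax.random.PRNGKey(42)
ys = counter_simu(jnp.float32(1.0), key)
print(ys.shape)
```

Because each increment depends only on the key and the step index, revisiting a step reproduces the same sample. Whether this is actually cheaper than splitting would need to be measured.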
It's true, I try to avoid breaking changes where possible! They're no fun for anyone. But the performance issues discussed here genuinely are quite severe, so I think they're actually strong enough to motivate making a breaking change of this nature.
Awesome, I'm looking forward to it! Let's see if we can get the stateful controls done at the same time? I'd like to contain the breaking changes to a single release, ideally. |
Hello,
When solving the (trivial) SDE $d y_t = -y_t\ dt + 0.2\ dW_t$, the Diffrax Euler solver is ~200x slower than a naive for loop. Am I doing something wrong? The speed difference is consistent across various SDEs, solvers, time steps dt, and number of trajectories, and it appears to be specific to SDE solvers.
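For reference, the naive loop used as the baseline in this thread corresponds to the Euler-Maruyama update for this SDE on a uniform grid. The symbols $\Delta t$ and $\Delta W_n$ are notation introduced here for the step size and the Brownian increment over step $n$:

$$y_{n+1} = y_n - y_n\,\Delta t + 0.2\,\Delta W_n, \qquad \Delta W_n \sim \mathcal{N}(0, \Delta t).$$

In the code above, $\Delta W_n$ is sampled as `jnp.sqrt(dt) * jax.random.normal(...)`, which has the required variance $\Delta t$.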