VBT vs brownian path slowdown #489
Maybe this isn't even solver specific; I just noticed that using Heun instead, with dt = 0.1, VBT takes 27 s but UBP takes 3 s. So maybe the real question is twofold: is there any way to make VBT faster, and if not, is there any way to program around VBT so we can make adaptive-stepping solvers, or differentiable solvers, assuming we can maintain one of the three pillars of the UBP requirements?
So VBT is sensitive to the tolerance used: decreasing the tolerance increases the computational cost. This is the reason you're seeing scaling with the tolerance. Other than that, I think there is performance being left on the table in our VBT implementation:
These are things that I think will require some careful thought to fix, so they've never made it far enough up my to-do list. (What I really want is to add an LRU cache on the evaluations of the loop body.) FWIW I did benchmark all of this when I first wrote it, and found that whilst there was a slowdown, it wasn't as dramatic as you're seeing here. It might be that you're in a case where this is particularly pronounced, or something might have sneakily regressed without me noticing... (Semi-relatedly, we also have this benchmark: https://github.com/patrick-kidger/diffrax/blob/main/benchmarks/brownian_tree_times.py) Another thing I'd love to see an implementation of some time is the Brownian Interval from this paper, but again that's fairly fiddly in JAX's model of computation. (Not impossible though, I think.)
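For intuition on why the cost scales with the tolerance, here is a minimal NumPy sketch of the Virtual Brownian Tree idea (this is an illustration, not diffrax's actual implementation; the per-interval re-seeding is a stand-in for JAX's key splitting). To evaluate the path at a time `t`, descend a virtual binary tree, sampling each midpoint from the Brownian bridge conditioned on the interval's endpoints, until the bracketing interval is smaller than the tolerance. The descent depth, and hence the cost, grows as the tolerance shrinks:

```python
import numpy as np

def brownian_bridge_midpoint(rng, t0, w0, t1, w1):
    # Conditioned on the endpoints, W at the midpoint is Gaussian:
    # mean = (w0 + w1) / 2, variance = (t1 - t0) / 4.
    tm = 0.5 * (t0 + t1)
    mean = 0.5 * (w0 + w1)
    std = np.sqrt((t1 - t0) / 4.0)
    return tm, mean + std * rng.standard_normal()

def vbt_evaluate(t, t0=0.0, t1=1.0, tol=1e-2, seed=0):
    # Sample the terminal value, then descend the virtual tree until the
    # interval bracketing t is smaller than tol.
    rng_end = np.random.default_rng(seed)
    w0 = 0.0
    w1 = np.sqrt(t1 - t0) * rng_end.standard_normal()
    depth = 0
    while (t1 - t0) > tol:
        # Deterministic per-interval seed: the same interval always sees
        # the same midpoint sample, which is what makes the tree "virtual".
        key = hash((seed, round(t0, 12), round(t1, 12))) % (2**32)
        rng = np.random.default_rng(key)
        tm, wm = brownian_bridge_midpoint(rng, t0, w0, t1, w1)
        if t <= tm:
            t1, w1 = tm, wm
        else:
            t0, w0 = tm, wm
        depth += 1
    # Interpolate within the final (sub-tolerance) interval.
    return w0 + (w1 - w0) * (t - t0) / (t1 - t0), depth
```

Repeated queries at the same time return the same sample, and halving the tolerance adds roughly one level of descent per factor of two, so tightening the tolerance directly translates into more work per evaluation.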
When I compare the old to the new VBT, I didn't see a pronounced slowdown, so it doesn't seem like a regression (although I am generally supportive of speed regression tests). But one thing I did notice: if I make the (new) VBT over an array, rather than a pytree, e.g.

```python
# shape=jax.ShapeDtypeStruct((4,), dtype=jnp.float64),
shape={
    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
},
```

and just parse the array and convert to a dict in the solver

```python
if not isinstance(cont, dict):
    cont = {"dW": cont[:2], "dZ": cont[2:]}
```

then this is >2x faster than if I just have the VBT over a dictionary (from 11 s to 4 s). Maybe this is expected, since the tree is iterated over (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/tree.py#L323) and tree maps happen sequentially (I assume?). If this is the case, is there a reason not to just flatten? I could just be misreading it. I'm starting to have a lurking suspicion we will end up implementing a Brownian path object, then diff-check it against VBT to understand it better lol. I think we can definitely squeeze more performance from the VBT (and we can add those to the list of things to implement); it just feels like the slowdown shouldn't be this much even with a slightly suboptimal VBT.
Was there ever an implementation in something like torchsde?
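For what it's worth, the flatten-then-unflatten pattern described above has a ready-made helper in JAX: `jax.flatten_util.ravel_pytree` returns both the flat vector and a function to rebuild the original structure. A minimal sketch, with a hypothetical `cont` pytree standing in for the control:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical control pytree, like the dict-shaped VBT output above.
cont = {
    "dW": jnp.arange(2, dtype=jnp.float32),
    "dZ": jnp.arange(2, 4, dtype=jnp.float32),
}

flat, unflatten = ravel_pytree(cont)  # flat has shape (4,)

# Run the computation once over the single flat array...
flat = flat * 2.0

# ...then rebuild the original structure for the solver.
cont2 = unflatten(flat)
```

This keeps the solver-facing API a pytree while letting the expensive inner computation act on one contiguous array.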
I think the iteration over tree leaves is indeed a mistake performance-wise. FWIW, in principle JAX should be able to parallelize each call, but in practice it seems that it is not doing that... I think you're right: the better approach would be to have a single loop that acts over a PyTree. I'll tag @andyElking on this issue too.
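A single loop acting over a PyTree is directly expressible, since `jax.lax.while_loop` accepts an arbitrary pytree as its carry. A minimal sketch (the `body` here is a hypothetical stand-in for one level of tree descent, not diffrax's actual loop):

```python
import jax.numpy as jnp
from jax import lax

# The carry is a pytree: one loop refines every leaf per iteration,
# instead of running a separate loop per leaf.
init = {"dW": jnp.zeros(2), "dZ": jnp.zeros(2), "depth": 0}

def cond(c):
    return c["depth"] < 5

def body(c):
    # Stand-in for one descent step applied to all leaves at once.
    return {
        "dW": c["dW"] + 1.0,
        "dZ": c["dZ"] + 2.0,
        "depth": c["depth"] + 1,
    }

out = lax.while_loop(cond, body, init)
```

Because the whole pytree moves through one loop, the per-leaf work is traced once and can be fused, rather than compiled as several sequential loops.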
Yup. (In practice the implementation is really rather complicated, and has a couple of footguns of its own!)
I completely agree. In fact I've been eyeing that split by PyTree for a while now and am intending to refactor it soonish. In addition I am intending to add an LRU cache to _evaluate, but that can be a separate edit, since it also requires changes to diffeqsolve.
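As a plain-Python illustration of the caching idea (not tied to diffrax's `_evaluate`, and ignoring the jit/tracer complications that make the real change also require edits to `diffeqsolve`): when a solver queries the same increment more than once, a memoizing cache lets the expensive tree descent run only on the first query.

```python
import functools

calls = {"n": 0}  # count how often the expensive body actually runs

@functools.lru_cache(maxsize=128)
def evaluate(t0, t1):
    # Hypothetical stand-in for an expensive evaluation over [t0, t1],
    # e.g. a Brownian-tree descent. The cache serves repeat queries.
    calls["n"] += 1
    return (t0 + t1) / 2

evaluate(0.0, 0.1)
evaluate(0.0, 0.1)  # served from the cache; the body ran once
```

Inside `jit` this is subtler, since traced arguments are not stable hashable keys across calls, which is presumably why the real change needs coordination with the solve loop.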
Yeah, it's weird, since the path is (probably) parallelizing over the same pytree (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/path.py#L126). Not sure why the inner map of VBT isn't.
Is the reason for implementing the VBT over the BI just fewer footguns? Or that it's easier in JAX?
Primarily just that the VBT is easier in JAX. The BI algorithm involves dynamically creating a tree, so to do that in JAX you'd have to preallocate a buffer and then use that to store pointers (indices) into itself.
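The preallocate-and-index pattern can be sketched like this (hypothetical names, NumPy for clarity; a real JAX version would need static capacity and functional `.at[...].set(...)` updates on the buffers):

```python
import numpy as np

class IntervalTree:
    """Sketch of a dynamically-grown tree stored in preallocated arrays:
    nodes "point" at each other via integer indices into the buffers,
    rather than via Python object references."""

    def __init__(self, capacity):
        self.left = np.full(capacity, -1, dtype=np.int32)   # child index or -1
        self.right = np.full(capacity, -1, dtype=np.int32)
        self.t0 = np.zeros(capacity)                        # interval endpoints
        self.t1 = np.zeros(capacity)
        self.size = 0                                       # next free slot

    def alloc(self, t0, t1):
        i = self.size          # "allocate" by taking the next slot
        self.t0[i], self.t1[i] = t0, t1
        self.size += 1
        return i

    def split(self, i):
        # Split node i at its midpoint; children are stored as indices.
        tm = 0.5 * (self.t0[i] + self.t1[i])
        self.left[i] = self.alloc(self.t0[i], tm)
        self.right[i] = self.alloc(tm, self.t1[i])
        return self.left[i], self.right[i]

tree = IntervalTree(capacity=16)
root = tree.alloc(0.0, 1.0)
l, r = tree.split(root)
```

The fiddly parts in JAX would be exactly what this glosses over: bounding the capacity ahead of time and expressing the in-place writes as functional updates under `jit`.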
While implementing some weak solver schemes, I noticed that when I used the VBT as opposed to the UBP (unsafe Brownian path) it was substantially slower (like ~5 mins vs ~10 seconds). Using a UBP is fine in this case for us (since it's a fixed-step solver, and we aren't differentiating through the equation), but in the future it is not ideal. Below is a MVC, but in summary, the slowdown seems to come from the line

```python
u += g1 @ (_dW + chi1)
```

Specifically, if I comment that out, I see the runtime drop from ~11 s to 3 s. But with UBP it only goes from about 1.1 s to 0.8 s (so it isn't just the elimination of a matmul creating that whole speed gap). All of this is a bit surprising, since I just call the diffusion control once. Is there a way of using VBTs, or integrating them into new solvers, that avoids this slowdown, or am I just making some mistake in my usage of the VBT?
Here is the full code:
yields

```
9.37 s ± 281 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

for VBT and

```
1.03 s ± 8.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

for UBP.