VBT vs brownian path slowdown #489

Open
lockwo opened this issue Aug 16, 2024 · 7 comments

@lockwo
Contributor

lockwo commented Aug 16, 2024

While implementing some weak solver schemes, I noticed that when I used the VBT as opposed to the UBP (unsafe Brownian path) it was substantially slower (roughly 5 minutes vs. 10 seconds). Using a UBP is fine in this case for us (since it's a fixed-step solver and we aren't differentiating through the equation), but in the future this is not ideal. Below is a minimal example, but in summary:

  • This is not the real solver; I ripped out almost everything just to make the code smaller.
  • These numbers seem a bit small/microbenchmark-y, but I have verified them on some larger problems as well; I am just using this small problem for speed and demonstration.
  • VBT is 10x slower here, and the time seems dominated by the line u += g1 @ (_dW + chi1). Specifically, if I comment that line out, the runtime drops from ~11 s to ~3 s, but with UBP it only goes from ~1.1 s to ~0.8 s (so it isn't just the elimination of a matmul that accounts for the whole gap).
  • The surprising thing is not that VBT is slower (I figured it would come with some overhead), but how it scales. Specifically, if I decrease dt, the extra cost is not just a constant overhead but seems to grow. Maybe this is expected, but even for small problems with large dts this becomes prohibitive (see the original 5 minutes vs. 10 seconds).

All of this is a bit surprising since I only call the diffusion control once per step. Is there a way of using VBTs, or integrating them into new solvers, that avoids this slowdown, or am I just making some mistake in my usage of the VBT?

Here is the full code:

import jax
from jax import numpy as jnp
import diffrax
from typing import ClassVar

jax.config.update("jax_enable_x64", True)  # the float64 shapes below need x64 enabled

# Phi^{-1}(1/6), the one-sixth quantile of the standard normal.
_NORMAL_ONESIX_QUANTILE = -0.9674215661017014


def calc_threepoint_random(x):
    # Three-point discretization of a standard normal: +/-1 with probability
    # 1/6 each, and 0 with probability 2/3.
    return jnp.where(
        jnp.abs(x) > -_NORMAL_ONESIX_QUANTILE,
        jnp.where(x < _NORMAL_ONESIX_QUANTILE, -1.0, 1.0),
        0.0,
    )


def calc_twopoint_random(x):
    # Two-point discretization: +/-1 with probability 1/2 each.
    return jnp.where(x > 0, 1.0, -1.0)


class Solver(diffrax.AbstractSolver):

    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)

    def init(self, terms, t0, t1, y0, args):
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        drift = terms.terms[0]
        diffusion = terms.terms[1]
        cont = diffusion.contr(t0, t1)
        dt = t1 - t0
        dW_scaled = cont["dW"] / jnp.sqrt(dt)
        sq3dt = jnp.sqrt(3 * dt)
        _dW = sq3dt * calc_threepoint_random(dW_scaled)
        dZ_scaled = cont["dZ"]
        _dZ = calc_twopoint_random(dZ_scaled)
        xi = jnp.sqrt(dt) * _dZ[0]
        chi1 = (_dW**2 / xi - xi) / 2
        k1 = drift.vf(t0, y0, args)
        g1 = diffusion.vf(t0, y0, args)
        H02 = y0 + k1 * dt + g1 @ _dW
        k2 = drift.vf(t0, H02, args)
        H03 = y0 + k2 * dt + k1 * dt + g1 @ _dW
        k3 = drift.vf(t0, H03, args)
        u = y0 + k1 * dt + k2 * dt + k3 * dt
        u += g1 @ (_dW + chi1)
        dense_info = dict(y0=y0, y1=u)

        return u, None, dense_info, None, diffrax.RESULTS.successful

def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 / 160 * y1 - (-785 / 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])


def diffusion(t, X, args):
    y1, y2 = X
    g11 = 1 / 4 * y1
    g12 = 1 / 16 * y1
    g21 = (1 - 2 * jnp.sqrt(2)) / 4 * y1
    g22 = 1 / 10 * y1 + 1 / 16 * y2

    return jnp.array([[g11, g12], [g21, g22]])


t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])


def solve_wrapper(dt, num_samples, use_tree):
    keys = jax.random.split(jax.random.key(42), num_samples)
    solver = Solver()
    saveat = diffrax.SaveAt(t1=True)

    def solve(key):
        if not use_tree:
            tree = diffrax.UnsafeBrownianPath(
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms,
                solver,
                t0,
                t1,
                dt0=dt,
                y0=y0,
                saveat=saveat,
                adjoint=diffrax.DirectAdjoint(),
            )
        else:
            tree = diffrax.VirtualBrownianTree(
                t0,
                t1,
                tol=dt / 2,
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat
            )
            # (with adjoint=diffrax.DirectAdjoint() this is 2x slower)

    return jax.jit(jax.vmap(solve))(keys).ys.squeeze(axis=1)

and then, in a separate cell:

%%timeit
_ = solve_wrapper(1.0, 20 * 100_000, True).block_until_ready()

This yields 9.37 s ± 281 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for the VBT and 1.03 s ± 8.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for the UBP.

@lockwo
Contributor Author

lockwo commented Aug 16, 2024

Maybe this isn't even solver specific. I just noticed that using Heun instead, with dt = 0.1, VBT takes 27 s but UBP takes 3 s. So maybe the real question is twofold: is there any way to make the VBT faster, and if not, is there any way to program around the VBT so we can make adaptive-stepping solvers, or differentiable solvers, if we give up one of the three pillars of the UBP requirements?
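For reference, a rough sketch of the Heun comparison (reusing drift, diffusion, t0, t1 and y0 from the MWE above, with a plain (2,)-shaped Brownian control instead of the dW/dZ dict; this is not the exact code I benchmarked):

solver = diffrax.Heun()
bm = diffrax.VirtualBrownianTree(
    t0,
    t1,
    tol=0.1 / 2,
    shape=jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
    key=jax.random.key(0),
)
terms = diffrax.MultiTerm(
    diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, bm)
)
sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0=0.1, y0=y0)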

@patrick-kidger
Owner

patrick-kidger commented Aug 17, 2024

So the VBT is sensitive to the tolerance used: decreasing the tolerance increases the computational cost. This is the reason you're seeing scaling with dt.
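As a back-of-envelope sketch (this mirrors the bisection structure, not the exact implementation): each evaluation descends from [t0, t1] down to intervals of width tol, so the loop depth grows like log2((t1 - t0) / tol). With tol = dt / 2, halving dt both doubles the number of solver steps and deepens every evaluation:

import math

t0, t1 = 0.0, 3.0
for dt in (1.0, 0.1, 0.01):
    tol = dt / 2
    depth = math.ceil(math.log2((t1 - t0) / tol))  # bisection levels per evaluation
    n_steps = math.ceil((t1 - t0) / dt)            # number of solver steps
    print(f"dt={dt}: depth={depth}, total loop iterations ~{n_steps * depth}")

so the total work grows superlinearly as dt shrinks, rather than as a constant overhead.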

Other than that, I think there is performance being left on the table in our VBT implementation:

  • On every sample it goes through its entire while loop again, despite the fact that most of those iterations are probably identical to the last sample.
  • Across two adjacent steps [t0, t1] and [t1, t2], we end up evaluating vbt(t1) - vbt(t0) followed by vbt(t2) - vbt(t1) -- so we actually evaluate vbt(t1) twice! This doubles the computational work (see the sketch below).
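To make the second point concrete (a toy illustration; the shape, times, and tolerance here are arbitrary):

import jax
import jax.numpy as jnp
import diffrax

vbt = diffrax.VirtualBrownianTree(
    t0=0.0, t1=1.0, tol=1e-2,
    shape=jax.ShapeDtypeStruct((), jnp.float32),
    key=jax.random.key(0),
)
t0, t1, t2 = 0.0, 0.5, 1.0
w_01 = vbt.evaluate(t0, t1)  # internally descends the tree to t0 and to t1
w_12 = vbt.evaluate(t1, t2)  # descends to t1 *again*, then to t2

An LRU cache over the single-point evaluations would halve the tree traversals here.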

These are things that I think will require some careful thought to fix, so they've never made it far enough up my to-do list. (What I really want is an LRU cache on the evaluations of the loop body.) FWIW I did benchmark all of this when I first wrote it, and found that whilst there was a slowdown, it wasn't as dramatic as you're seeing here. It might be that you're in a case where this is particularly pronounced, or something might have sneakily regressed without me noticing...

(Semi-relatedly, we also have this benchmark: https://github.com/patrick-kidger/diffrax/blob/main/benchmarks/brownian_tree_times.py)

Another thing I'd love to see an implementation of some time is the Brownian Interval from this paper, but again that's fairly fiddly in JAX's model of computation. (Not impossible though I think.)

@lockwo
Contributor Author

lockwo commented Aug 17, 2024

When I compare the old to the new VBT, I didn't see a pronounced slowdown, so it doesn't seem like a regression (although I am generally supportive of speed regression tests). But one thing I did notice: if I make the (new) VBT over an array, rather than a pytree, e.g.

              shape=jax.ShapeDtypeStruct((4,), dtype=jnp.float64),
              # shape={
              #     "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
              #     "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
              # },

and just parse the array back into a dict in the solver:

if not isinstance(cont, dict):
    cont = {"dW": cont[:2], "dZ": cont[2:]}

then this is >2x faster than just having the VBT over a dictionary (from 11 s to 4 s). Maybe this is expected, since the tree is iterated over leaf by leaf (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/tree.py#L323) and the tree maps happen sequentially (I assume?). If that is the case, is there a reason not to just flatten? I could just be misreading it.
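For concreteness, the full workaround looks like this (reusing the names from the MWE above; the length-4 array is just dW and dZ concatenated):

tree = diffrax.VirtualBrownianTree(
    t0,
    t1,
    tol=dt / 2,
    shape=jax.ShapeDtypeStruct((4,), dtype=jnp.float64),
    key=key,
)

# ...and then at the top of Solver.step:
cont = diffusion.contr(t0, t1)
if not isinstance(cont, dict):
    cont = {"dW": cont[:2], "dZ": cont[2:]}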

I'm starting to have a lurking suspicion that we will end up implementing a Brownian path object ourselves, then diff-checking it against the VBT to understand it better lol.

I think we can definitely squeeze more performance out of the VBT (and we can add those to the list of things to implement); it just feels like the slowdown shouldn't be this large even with a slightly suboptimal VBT.

> Another thing I'd love to see an implementation of some time is the Brownian Interval from this paper, but again that's fairly fiddly in JAX's model of computation. (Not impossible though I think.)

Was there ever an implementation in something like torchsde?

@patrick-kidger
Owner

I think the iteration over tree leaves is indeed a mistake performance-wise. FWIW, in principle JAX should be able to parallelize each call, but in practice it seems that it is not doing so...

I think you're right, the better approach would be to have a single loop that acts over a PyTree.
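Schematically, the difference is something like this (a toy sketch, not the actual VBT code; `body` stands in for one bisection step and `depth` for the number of levels):

import jax
import jax.numpy as jnp

def body(x):
    return jnp.sin(x)  # stand-in for one bisection step on one leaf

def per_leaf(tree, depth):
    # Current approach: one loop per leaf, laid out sequentially in the jaxpr.
    run = lambda leaf: jax.lax.fori_loop(0, depth, lambda _, x: body(x), leaf)
    return jax.tree_util.tree_map(run, tree)

def single_loop(tree, depth):
    # Proposed: one loop whose body maps over every leaf at once, so XLA sees
    # a single loop and can fuse the per-leaf work inside it.
    step = lambda _, t: jax.tree_util.tree_map(body, t)
    return jax.lax.fori_loop(0, depth, step, tree)

example = {"dW": jnp.ones(2), "dZ": jnp.ones(2)}
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(jnp.allclose, per_leaf(example, 5), single_loop(example, 5))
)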

I'll tag @andyElking on this issue too.

> Was there ever an implementation in something like torchsde?

Yup. torchsde is where I wrote the original canonical Brownian Interval implementation.

(In practice the implementation is really rather complicated, and has a couple of footguns of its own!)

@andyElking
Contributor

I completely agree. In fact, I've been eyeing that split by PyTree for a while now and am intending to refactor it soonish. In addition, I am intending to add an LRU cache to _evaluate, but that can be a separate change since it also requires changes to diffeqsolve.

@lockwo
Contributor Author

lockwo commented Aug 21, 2024

> FWIW in principle JAX should be able to parallelize each call, but in practice it seems that it is not doing that...

Yeah, it's weird, since the path is (probably) parallelizing over the same pytree (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/path.py#L126). Not sure why the inner map of the VBT isn't.

> (In practice the implementation is really rather complicated, and has a couple of footguns of its own!)

Is the reason for implementing the VBT over the BI just fewer footguns? Or is it easier in JAX?

@patrick-kidger
Owner

patrick-kidger commented Aug 21, 2024

Primarily just that the VBT is easier in JAX. The BI algorithm involves dynamically creating a tree, so to do that in JAX you'd have to preallocate a buffer and then use that to store pointers (indices) into itself.
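To make that idea concrete, a purely illustrative sketch of the preallocation approach (none of this is the BI algorithm itself, and all names here are made up):

import jax.numpy as jnp

MAX_NODES = 16  # fixed capacity chosen up front

def empty_tree():
    return dict(
        payload=jnp.zeros(MAX_NODES),                    # per-node data
        left=jnp.full(MAX_NODES, -1, dtype=jnp.int32),   # child indices; -1 = none
        right=jnp.full(MAX_NODES, -1, dtype=jnp.int32),
        n=jnp.array(1, dtype=jnp.int32),                 # node 0 is the root
    )

def split(tree, parent, value):
    # "Dynamically" grow the tree by writing child indices into the buffer.
    l, r = tree["n"], tree["n"] + 1
    return dict(
        payload=tree["payload"].at[parent].set(value),
        left=tree["left"].at[parent].set(l),
        right=tree["right"].at[parent].set(r),
        n=tree["n"] + 2,
    )

tree = split(empty_tree(), 0, 0.5)  # root now points at nodes 1 and 2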
