Performance issue with SDE solver #517

Open
pierreguilmin opened this issue Oct 21, 2024 · 16 comments

@pierreguilmin

Hello,

When solving the (trivial) SDE $d y_t = -y_t\ dt + 0.2\ dW_t$, the Diffrax Euler solver is more than 200x slower than a naive jax.lax.scan loop (timings below). Am I doing something wrong? The gap is consistent across various SDEs, solvers, time steps dt and numbers of trajectories, and it appears to be specific to SDE solvers.
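For reference, both implementations below compute the fixed-step Euler-Maruyama discretisation of this SDE, with dt = (t1 - t0)/(ndt - 1) = 0.01 in the code:

```latex
y_{n+1} = y_n - y_n\,\Delta t + 0.2\,\Delta W_n,
\qquad \Delta W_n = W_{t_{n+1}} - W_{t_n} \sim \mathcal{N}(0, \Delta t),
\qquad \Delta t = \frac{t_1 - t_0}{\mathrm{ndt} - 1} = 0.01 .
```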

import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt

# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2

# === diffrax euler
brownian_motion = dx.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(ts=jnp.linspace(t0, t1, ndt))

@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat).ys

# === homemade euler
@jax.jit
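def homemade_simu():
    # Brownian increments dW ~ N(0, dt), drawn in one batched call
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))

    def step(y, dW):
        # Euler-Maruyama update; the current y is emitted so scan stacks the path
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y

    return jax.lax.scan(step, y0, dWs)[-1]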

# === plot a single trajectory
y = diffrax_simu()
plt.plot(y)
y = homemade_simu()
plt.plot(y)

# === benchmark
%timeit diffrax_simu().block_until_ready()
%timeit homemade_simu().block_until_ready()
5.39 ms ± 261 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 μs ± 899 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
@lockwo
Contributor

lockwo commented Oct 21, 2024

I get them to be a lot closer by using UnsafeBrownianPath, which has less overhead than VBT. Diffrax is still a bit slower with this change on my machine, but the difference is smaller (and probably due to other overheads that diffrax does to enable more features).

There are also some risky (but often useful) changes to UBP that we've made internally and that I've been meaning to put in the fork, so you can definitely do a fair amount by modifying UBP (enough to get around all 3 of its stated requirements).

@patrick-kidger
Owner

Yup, VBT is often the cause of poor SDE performance. Really we need some kind of LRU caching to make it behave properly, but that doesn't seem to be easy in JAX -- I'm pretty sure it'd require both a new primitive ('cached_call_p') and a new transform. That's a fairly advanced project for someone to take on!
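To illustrate where the per-query cost of a virtual Brownian tree comes from, here is a simplified, hypothetical sketch of the Brownian-bridge bisection idea behind it. It is not Diffrax's actual VirtualBrownianTree (which is more involved, e.g. it can also produce Lévy area), and virtual_brownian_eval is a made-up name. Each evaluation of W(t) walks about log2((t1 - t0)/tol) levels from the root, drawing one normal per level; with tol=1e-3 as in the original post that is 10 levels in this sketch, whereas an unsafe path draws a single normal per increment.

```python
import math
import jax
import jax.numpy as jnp

def virtual_brownian_eval(key, t0, t1, tol, t):
    """Hypothetical sketch: W(t) of a scalar Brownian motion with W(t0) = 0,
    rebuilt on demand from `key` by Brownian-bridge bisection down to width tol."""
    depth = math.ceil(math.log2((t1 - t0) / tol))  # static number of levels
    root_key, key = jax.random.split(key)
    a, b = jnp.float32(t0), jnp.float32(t1)
    w_a = jnp.zeros((), jnp.float32)
    w_b = jnp.sqrt(b - a) * jax.random.normal(root_key, ())  # W(t1) ~ N(0, t1 - t0)

    def level(_, state):
        a, b, w_a, w_b, key = state
        m = 0.5 * (a + b)
        # Brownian bridge: W(m) | W(a), W(b) ~ N((W(a) + W(b)) / 2, (b - a) / 4)
        noise = jax.random.normal(jax.random.fold_in(key, 2), ())
        w_m = 0.5 * (w_a + w_b) + 0.5 * jnp.sqrt(b - a) * noise
        go_left = t < m
        # Descend into the half containing t; child keys are derived
        # deterministically, so the same t always reproduces the same value.
        key = jax.random.fold_in(key, jnp.where(go_left, 0, 1))
        a, b = jnp.where(go_left, a, m), jnp.where(go_left, m, b)
        w_a, w_b = jnp.where(go_left, w_a, w_m), jnp.where(go_left, w_m, w_b)
        return a, b, w_a, w_b, key

    a, b, w_a, w_b, _ = jax.lax.fori_loop(0, depth, level, (a, b, w_a, w_b, key))
    # Linear interpolation inside the final interval (width <= tol).
    return w_a + (t - a) / (b - a) * (w_b - w_a)
```

Neighbouring queries share most of their root-to-leaf path, which is the kind of repeated work an LRU cache could in principle avoid.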

In the meantime I recommend UBP as the go-to for these kinds of normal 'just solve an SDE' applications.

@lockwo
Copy link
Contributor

lockwo commented Oct 21, 2024

I think a lot of people get turned off by the Unsafe in the name, maybe worth adding a sentence like this to the docs ("In the meantime I recommend UBP as the go-to for these kinds of normal 'just solve an SDE' applications.").

@gautierronan

Thanks. Indeed using UBP does help but I understand it's quite restricted in terms of usage.

Quoting @lockwo: "Diffrax is still a bit slower with this change on my machine, but the difference is smaller (and probably due to other overheads that diffrax does to enable more features)."

It seems there is still a factor of ~10-20 difference (irrespective of the number of time steps) between the homemade solver and Diffrax with UBP. I would have naively thought that any irrelevant computation would be jitted away. Could you elaborate on what Diffrax with UBP does compared to the naive solver?

Diffrax (VBT): 7.51 ms ± 18.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Diffrax (UBP): 637 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Naive:         28.5 µs ± 147 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

@lockwo
Contributor

lockwo commented Oct 22, 2024

Diffrax has a lot more checking/shaping/logging than the homemade implementation. You can see it reflected in the jaxprs:

diffrax
let _where = { lambda ; a:bool[] b:i32[] c:i32[]. let
    d:i32[] = select_n a c b
  in (d,) } in
let _where1 = { lambda ; e:bool[] f:f32[] g:f32[]. let
    h:f32[] = select_n e g f
  in (h,) } in
let _where2 = { lambda ; i:bool[] j:f32[] k:f32[]. let
    l:f32[] = select_n i k j
  in (l,) } in
let _where3 = { lambda ; m:bool[] n:i32[] o:f32[]. let
    p:f32[] = convert_element_type[new_dtype=float32 weak_type=False] n
    q:f32[] = select_n m o p
  in (q,) } in
{ lambda ; . let
    r:f32[4096] = pjit[
      name=diffrax_simu
      jaxpr={ lambda s:u32[2]; . let
          _:i32[] = add 1 1
          _:i32[] _:f32[] _:f32[] _:f32[4096] t:f32[4096] _:i32[] _:i32[] _:i32[]
            _:i32[] = pjit[
            name=diffeqsolve
            jaxpr={ lambda u:bool[] v:bool[] w:bool[] x:bool[]; y:u32[2]. let
                _:i32[] = add 1 1
                _:i32[] = pjit[
                  name=branched_error_if_impl
                  jaxpr={ lambda ; z:f32[]. let  in (0,) }
                ] 0.009999999776482582
                ba:bool[] = lt 0.0 1.0
                bb:i32[] = pjit[name=_where jaxpr=_where] ba 1 -1
                bc:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                bd:f32[] = mul 0.0 bc
                be:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                bf:f32[] = mul 1.0 be
                bg:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                bh:f32[] = mul 0.009999999776482582 bg
                bi:f32[] = add bd bh
                bj:f32[] = min bi bf
                bk:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=True
                ] bb
                bl:f32[] = mul bk inf
                bm:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bl
                bn:f32[4096] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(4096,)
                ] bm
                bo:f32[4096] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(4096,)
                ] inf
                bp:f32[] = copy 1.0
                bq:f32[] = copy bd
                br:f32[] = copy bj
                bs:f32[] = copy bh
                bt:f32[4096] = copy bn
                bu:f32[4096] = copy bo
                bv:bool[] = lt bq bf
                bw:bool[] = and bv u
                bx:bool[] = copy bw
                _:i32[] _:bool[] _:bool[] by:f32[] bz:f32[] ca:f32[] _:bool[] cb:f32[]
                  cc:i32[] cd:i32[] ce:i32[] cf:i32[] cg:i32[] ch:f32[4096] ci:f32[4096]
                  cj:i32[] = while[
                  body_jaxpr={ lambda ; ck:i32[] cl:u32[2] cm:f32[] cn:f32[] co:bool[]
                      cp:i32[] cq:bool[] cr:bool[] cs:f32[] ct:f32[] cu:f32[] cv:bool[]
                      cw:f32[] cx:i32[] cy:i32[] cz:i32[] da:i32[] db:i32[] dc:f32[4096]
                      dd:f32[4096] de:i32[]. let
                      df:bool[] = eq ck 1
                      dg:f32[] = neg cu
                      dh:f32[] = pjit[name=_where jaxpr=_where1] df ct dg
                      di:bool[] = eq ck 1
                      dj:f32[] = neg ct
                      dk:f32[] = pjit[name=_where jaxpr=_where1] di cu dj
                      dl:f32[] = sub dk dh
                      dm:f32[] = convert_element_type[
                        new_dtype=float32
                        weak_type=False
                      ] ck
                      dn:f32[] = mul dm dl
                      do:bool[] = eq ck 1
                      dp:f32[] = neg cu
                      dq:f32[] = pjit[name=_where jaxpr=_where1] do ct dp
                      dr:bool[] = eq ck 1
                      ds:f32[] = neg ct
                      dt:f32[] = pjit[name=_where jaxpr=_where1] dr cu ds
                      _:i32[] = add 1 1
                      _:i32[] du:f32[] = pjit[
                        name=evaluate
                        jaxpr={ lambda ; dv:u32[2] dw:f32[] dx:f32[]. let
                            dy:f32[] = custom_jvp_call[
                              call_jaxpr={ lambda ; dz:f32[]. let  in (dz,) }
                              jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7ef9d5f12440>
                              num_consts=0
                              symbolic_zeros=False
                            ] dw
                            ea:f32[] = custom_jvp_call[
                              call_jaxpr={ lambda ; eb:f32[]. let  in (eb,) }
                              jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7ef9d5f125f0>
                              num_consts=0
                              symbolic_zeros=False
                            ] dx
                            ec:i32[] = bitcast_convert_type[new_dtype=int32] dy
                            ed:i32[] = bitcast_convert_type[new_dtype=int32] ea
                            ee:key<fry>[] = random_wrap[impl=fry] dv
                            ef:u32[] = convert_element_type[
                              new_dtype=uint32
                              weak_type=False
                            ] ec
                            eg:key<fry>[] = random_fold_in ee ef
                            eh:u32[2] = random_unwrap eg
                            ei:key<fry>[] = random_wrap[impl=fry] eh
                            ej:u32[] = convert_element_type[
                              new_dtype=uint32
                              weak_type=False
                            ] ed
                            ek:key<fry>[] = random_fold_in ei ej
                            el:u32[2] = random_unwrap ek
                            em:key<fry>[] = random_wrap[impl=fry] el
                            en:key<fry>[1] = random_split[shape=(1,)] em
                            eo:u32[1,2] = random_unwrap en
                            ep:u32[1,2] = slice[
                              limit_indices=(1, 2)
                              start_indices=(0, 0)
                              strides=(1, 1)
                            ] eo
                            eq:u32[2] = squeeze[dimensions=(0,)] ep
                            er:f32[] = sub ea dy
                            es:f32[] = sqrt er
                            _:f32[] = sub ea dy
                            et:key<fry>[] = random_wrap[impl=fry] eq
                            eu:f32[] = pjit[
                              name=_normal
                              jaxpr={ lambda ; ev:key<fry>[]. let
                                  ew:f32[] = pjit[
                                    name=_normal_real
                                    jaxpr={ lambda ; ex:key<fry>[]. let
                                        ey:f32[] = pjit[
                                          name=_uniform
                                          jaxpr={ lambda ; ez:key<fry>[] fa:f32[]
                                              fb:f32[]. let
                                              fc:u32[] = random_bits[
                                                bit_width=32
                                                shape=()
                                              ] ez
                                              fd:u32[] = shift_right_logical fc 9
                                              fe:u32[] = or fd 1065353216
                                              ff:f32[] = bitcast_convert_type[
                                                new_dtype=float32
                                              ] fe
                                              fg:f32[] = sub ff 1.0
                                              fh:f32[] = sub fb fa
                                              fi:f32[] = mul fg fh
                                              fj:f32[] = add fi fa
                                              fk:f32[] = reshape[
                                                dimensions=None
                                                new_sizes=()
                                              ] fj
                                              fl:f32[] = max fa fk
                                            in (fl,) }
                                        ] ex -0.9999999403953552 1.0
                                        fm:f32[] = erf_inv ey
                                        fn:f32[] = mul 1.4142135381698608 fm
                                      in (fn,) }
                                  ] ev
                                in (ew,) }
                            ] et
                            fo:f32[] = mul eu es
                          in (0, fo) }
                      ] cl dq dt
                      fp:f32[] = convert_element_type[
                        new_dtype=float32
                        weak_type=False
                      ] ck
                      fq:f32[] = mul fp du
                      fr:f32[] = convert_element_type[
                        new_dtype=float32
                        weak_type=False
                      ] ck
                      _:f32[] = mul ct fr
                      fs:f32[] = neg cs
                      ft:f32[] = mul dn fs
                      fu:f32[] = convert_element_type[
                        new_dtype=float32
                        weak_type=False
                      ] ck
                      _:f32[] = mul ct fu
                      fv:f32[] = dot_general[
                        dimension_numbers=(([], []), ([], []))
                        preferred_element_type=float32
                      ] 0.20000000298023224 fq
                      fw:f32[] = add ft fv
                      fx:f32[] = add cs fw
                      fy:f32[] = add cu cw
                      fz:f32[] = min cu cm
                      ga:f32[] = sub cm 9.999999974752427e-07
                      gb:bool[] = gt fy ga
                      gc:f32[] = sub cm fz
                      gd:f32[] = mul 0.5 gc
                      ge:f32[] = add fz gd
                      gf:f32[] = pjit[name=_where jaxpr=_where2] True cm ge
                      gg:f32[] = pjit[name=_where jaxpr=_where2] gb gf fy
                      gh:bool[] = eq cn cm
                      gi:f32[] = sub fz cn
                      gj:f32[] = pjit[name=_where jaxpr=_where3] gh 0 gi
                      gk:f32[] = sub cm cn
                      gl:f32[] = pjit[name=_where jaxpr=_where3] gh 1 gk
                      _:f32[] = div gj gl
                      gm:f32[] = pjit[name=_where jaxpr=_where2] True fx cs
                      gn:i32[] = add cy 1
                      go:i32[] = pjit[name=_where jaxpr=_where] True 1 0
                      gp:i32[] = add cz go
                      gq:i32[] = pjit[name=_where jaxpr=_where] True 0 1
                      gr:i32[] = add da gq
                      gs:bool[] = and True cq
                      gt:f32[] = copy fz
                      gu:f32[4096] = maybe_set[
                        i_static=None
                        i_treedef=PyTreeDef(*)
                        kwargs={}
                        makes_false_steps=False
                      ] gs dc gt de
                      gv:bool[] = and True cq
                      gw:f32[] = copy gm
                      gx:f32[4096] = maybe_set[
                        i_static=None
                        i_treedef=PyTreeDef(*)
                        kwargs={}
                        makes_false_steps=False
                      ] gv dd gw de
                      gy:i32[] = pjit[name=_where jaxpr=_where] True 1 0
                      gz:i32[] = add de gy
                      ha:f32[] = copy gm
                      hb:f32[] = copy fz
                      hc:f32[] = copy gg
                      hd:f32[] = copy cw
                      he:i32[] = copy gn
                      hf:i32[] = copy gp
                      hg:i32[] = copy gr
                      hh:i32[] = copy db
                      hi:f32[4096] = copy gu
                      _:bool[] = copy cq
                      hj:f32[4096] = copy gx
                      _:bool[] = copy cq
                      hk:i32[] = copy gz
                      hl:bool[] = copy cq
                      hm:f32[] = copy ha
                      hn:f32[] = copy cs
                      ho:f32[] = select_if_vmap hl hm hn
                      hp:bool[] = copy cq
                      hq:f32[] = copy hb
                      hr:f32[] = copy ct
                      hs:f32[] = select_if_vmap hp hq hr
                      ht:bool[] = copy cq
                      hu:f32[] = copy hc
                      hv:f32[] = copy cu
                      hw:f32[] = select_if_vmap ht hu hv
                      hx:bool[] = copy cq
                      hy:bool[] = copy False
                      hz:bool[] = copy cv
                      ia:bool[] = select_if_vmap hx hy hz
                      ib:bool[] = copy cq
                      ic:f32[] = copy hd
                      id:f32[] = copy cw
                      ie:f32[] = select_if_vmap ib ic id
                      if:bool[] = copy cq
                      ig:i32[] = copy 0
                      ih:i32[] = copy cx
                      ii:i32[] = select_if_vmap if ig ih
                      ij:bool[] = copy cq
                      ik:i32[] = copy he
                      il:i32[] = copy cy
                      im:i32[] = select_if_vmap ij ik il
                      in:bool[] = copy cq
                      io:i32[] = copy hf
                      ip:i32[] = copy cz
                      iq:i32[] = select_if_vmap in io ip
                      ir:bool[] = copy cq
                      is:i32[] = copy hg
                      it:i32[] = copy da
                      iu:i32[] = select_if_vmap ir is it
                      iv:bool[] = copy cq
                      iw:i32[] = copy hh
                      ix:i32[] = copy db
                      iy:i32[] = select_if_vmap iv iw ix
                      iz:bool[] = copy cq
                      ja:i32[] = copy hk
                      jb:i32[] = copy de
                      jc:i32[] = select_if_vmap iz ja jb
                      jd:i32[] = add cp 1
                      je:bool[] = lt hb cm
                      jf:bool[] = and je co
                      jg:bool[] = and cq jf
                    in (jd, jg, cq, ho, hs, hw, ia, ie, ii, im, iq, iu, iy, hi, hj,
                      jc) }
                  body_nconsts=5
                  cond_jaxpr={ lambda ; jh:f32[] ji:bool[] jj:i32[] jk:bool[] jl:bool[]
                      jm:f32[] jn:f32[] jo:f32[] jp:bool[] jq:f32[] jr:i32[] js:i32[]
                      jt:i32[] ju:i32[] jv:i32[] jw:f32[4096] jx:f32[4096] jy:i32[]. let
                      jz:bool[] = lt jn jh
                      ka:bool[] = and jz ji
                      kb:bool[] = unvmap_any ka
                      kc:bool[] = lt jj 4096
                      kd:bool[] = convert_element_type[
                        new_dtype=bool
                        weak_type=False
                      ] kc
                      ke:bool[] = and kb kd
                      kf:bool[] = nonbatchable[
                        allow_constant_across_batch=True
                        msg=Nonconstant batch. `equinox.internal.while_loop` has received a batch of values that were expected to be constant. This is probably an internal error in the library you are using.
                      ] ke
                    in (kf,) }
                  cond_nconsts=2
                ] bf v bb y bf bd w 0 bx True bp bq br False bs 0 0 0 0 0 bt bu 0
                kg:bool[] = lt bz bf
                kh:bool[] = and kg x
                ki:i32[] = pjit[
                  name=_where
                  jaxpr={ lambda ; kj:bool[] kk:i32[] kl:i32[]. let
                      km:i32[] = select_n kj kl kk
                    in (km,) }
                ] kh 1 cc
                _:f32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] by
                _:f32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] bz
                _:f32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] ca
                _:f32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] cb
                kn:i32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] ki
                ko:i32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] cd
                kp:i32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] ce
                kq:i32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] cf
                _:i32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] cg
                kr:f32[4096] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] ch
                ks:f32[4096] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] ci
                _:i32[] = nondifferentiable_backward[
                  msg=Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`.
                  symbolic=True
                ] cj
                kt:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                ku:f32[4096] = mul kr kt
                kv:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                kw:f32[] = mul bd kv
                kx:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bb
                ky:f32[] = mul bf kx
                kz:bool[] = eq kn 0
                la:bool[] = eq kn 8
                lb:bool[] = or kz la
                lc:bool[] = not lb
                _:i32[] = add 1 1
                _:i32[] ld:f32[] le:f32[] lf:f32[4096] lg:f32[4096] lh:i32[] li:i32[]
                  lj:i32[] lk:i32[] = pjit[
                  name=branched_error_if_impl
                  jaxpr={ lambda ; ll:f32[] lm:f32[] ln:f32[4096] lo:f32[4096] lp:i32[]
                      lq:i32[] lr:i32[] ls:i32[] lt:bool[] lu:i32[]. let
                      lv:bool[] = unvmap_any lt
                      lw:i32[] = unvmap_max lu
                      lx:f32[] ly:f32[] lz:f32[4096] ma:f32[4096] mb:i32[] mc:i32[]
                        md:i32[] me:i32[] = custom_jvp_call[
                        call_jaxpr={ lambda ; mf:f32[] mg:f32[] mh:f32[4096] mi:f32[4096]
                            mj:i32[] mk:i32[] ml:i32[] mm:i32[] mn:bool[] mo:i32[]. let
                            mp:i32[] = convert_element_type[
                              new_dtype=int32
                              weak_type=False
                            ] mn
                            mq:f32[] mr:f32[] ms:f32[4096] mt:f32[4096] mu:i32[]
                              mv:i32[] mw:i32[] mx:i32[] = cond[
                              branches=(
                                { lambda ; my_:i32[] mz:f32[] na:f32[] nb:f32[4096]
                                    nc:f32[4096] nd:i32[] ne:i32[] nf:i32[] ng:i32[]. let
                                    
                                  in (mz, na, nb, nc, nd, ne, nf, ng) }
                                { lambda ; nh:i32[] ni_:f32[] nj_:f32[] nk_:f32[4096]
                                    nl_:f32[4096] nm_:i32[] nn_:i32[] no_:i32[] np_:i32[]. let
                                    nq:f32[] nr:f32[] ns:f32[4096] nt:f32[4096] nu:i32[]
                                      nv:i32[] nw:i32[] nx:i32[] = pure_callback[
                                      callback=_FlatCallback(callback_func=<function _error.<locals>.raises at 0x7ef9d5b48160>, in_tree=PyTreeDef(((*,), {})))
                                      result_avals=(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[4096]), ShapedArray(float32[4096]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]))
                                      sharding=None
                                      vectorized=False
                                    ] nh
                                    ny:f32[] nz:f32[] oa:f32[4096] ob:f32[4096] oc:i32[]
                                      od:i32[] oe:i32[] of:i32[] = pure_callback[
                                      callback=_FlatCallback(callback_func=<function _error.<locals>.tpu_msg at 0x7ef9d5b481f0>, in_tree=PyTreeDef(((CustomNode(Solution[('t0', 't1', 'ts', 'ys', 'interpolation', 'stats', 'result', 'solver_state', 'controller_state', 'made_jump', 'event_mask'), (), ()], [*, *, *, *, None, {'max_steps': None, 'num_accepted_steps': *, 'num_rejected_steps': *, 'num_steps': *}, CustomNode(EnumerationItem[('_value',), ('_enumeration',), (<class 'diffrax._solution.RESULTS'>,)], [*]), None, None, None, None]), *), {})))
                                      result_avals=(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[4096]), ShapedArray(float32[4096]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]))
                                      sharding=None
                                      vectorized=False
                                    ] nq nr ns nt nu nv nw nx nh
                                  in (ny, nz, oa, ob, oc, od, oe, of) }
                              )
                            ] mp mo mf mg mh mi mj mk ml mm
                          in (mq, mr, ms, mt, mu, mv, mw, mx) }
                        jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7ef9d5b48790>
                        num_consts=0
                        symbolic_zeros=True
                      ] ll lm ln lo lp lq lr ls lv lw
                    in (0, lx, ly, lz, ma, mb, mc, md, me) }
                ] kw ky ku ks kp kq ko kn lc kn
              in (0, ld, le, lf, lg, lh, li, lj, lk) }
          ] s
        in (t,) }
    ] 
  in (r,) }
pure jax
{ lambda ; . let
    a:f32[101] = pjit[
      name=homemade_simu
      jaxpr={ lambda b:u32[2]; . let
          c:f32[] = sqrt 0.01
          d:key<fry>[] = random_wrap[impl=fry] b
          e:f32[101] = pjit[
            name=_normal
            jaxpr={ lambda ; f:key<fry>[]. let
                g:f32[101] = pjit[
                  name=_normal_real
                  jaxpr={ lambda ; h:key<fry>[]. let
                      i:f32[101] = pjit[
                        name=_uniform
                        jaxpr={ lambda ; j:key<fry>[] k:f32[] l:f32[]. let
                            m:f32[1] = broadcast_in_dim[
                              broadcast_dimensions=()
                              shape=(1,)
                            ] k
                            n:f32[1] = broadcast_in_dim[
                              broadcast_dimensions=()
                              shape=(1,)
                            ] l
                            o:u32[101] = random_bits[bit_width=32 shape=(101,)] j
                            p:u32[101] = shift_right_logical o 9
                            q:u32[101] = or p 1065353216
                            r:f32[101] = bitcast_convert_type[new_dtype=float32] q
                            s:f32[101] = sub r 1.0
                            t:f32[1] = sub n m
                            u:f32[101] = mul s t
                            v:f32[101] = add u m
                            w:f32[101] = max m v
                          in (w,) }
                      ] h -0.9999999403953552 1.0
                      x:f32[101] = erf_inv i
                      y:f32[101] = mul 1.4142135381698608 x
                    in (y,) }
                ] f
              in (g,) }
          ] d
          z:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
          ba:f32[101] = mul z e
          _:f32[] bb:f32[101] = scan[
            _split_transpose=False
            jaxpr={ lambda ; bc:f32[] bd:f32[]. let
                be:f32[] = neg bc
                bf:f32[] = mul be 0.01
                bg:f32[] = mul 0.20000000298023224 bd
                bh:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bf
                bi:f32[] = add bh bg
                bj:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] bc
                bk:f32[] = add bj bi
              in (bk, bc) }
            length=101
            linear=(False, False)
            num_carry=1
            num_consts=0
            reverse=False
            unroll=1
          ] 1.0 ba
        in (bb,) }
    ] 
  in (a,) }
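Dumps like the two above can be produced with jax.make_jaxpr. As a rough proxy for how much work each version traces, one can also count equations recursively through the nested jaxprs. count_eqns below is an ad hoc helper, not a JAX API, and it assumes sub-jaxprs live in equation params (as they do for the scan, while, cond and pjit entries in the dumps above).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
dt, ndt = 0.01, 101

def homemade_simu(y0):
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))
    def step(y, dW):
        return y - y * dt + 0.2 * dW, y
    return jax.lax.scan(step, y0, dWs)[1]

def count_eqns(jaxpr):
    """Ad hoc helper: count equations, descending into sub-jaxprs found in params."""
    total = 0
    for eqn in jaxpr.eqns:
        total += 1
        for param in eqn.params.values():
            subs = param if isinstance(param, (tuple, list)) else (param,)
            for sub in subs:
                inner = getattr(sub, "jaxpr", sub)  # ClosedJaxpr -> Jaxpr
                if hasattr(inner, "eqns"):
                    total += count_eqns(inner)
    return total

closed = jax.make_jaxpr(homemade_simu)(jnp.float32(1.0))
print(closed)                    # the same kind of dump as above
print(count_eqns(closed.jaxpr))  # equations, including those inside the scan body
```

Applying the same helper to jax.make_jaxpr(diffrax_simu)().jaxpr should give the corresponding count for the Diffrax version.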

I believe most of this comes from the UBP, since if I do

@jax.jit
def homemade_simu():
    ts = jnp.linspace(t0, t1, ndt)

    def step(y, t):
        dw = brownian_motion.evaluate(t, t + dt)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return y + dy, y

    return jax.lax.scan(step, y0, ts)[-1]

I see the timings are pretty much the same as with Diffrax. Perhaps this indicates that there is room for cutting down the UBP-related overhead.
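The per-step pattern can be isolated in plain JAX. The evaluate section of the Diffrax jaxpr above folds the bit patterns of a step's two endpoint times into the key and then draws one normal, once per step inside the loop; the homemade solver instead draws all increments in one batched call before the loop. The sketch below loosely mimics both; it is not Diffrax's code, and per_step_increment, solve_per_step and solve_batched are made-up names.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
ts = jnp.linspace(0.0, 1.0, 101)

def per_step_increment(key, t_start, t_end):
    """Loosely mimics the per-step key derivation seen in the UBP jaxpr."""
    k = jax.random.fold_in(key, jax.lax.bitcast_convert_type(t_start, jnp.int32))
    k = jax.random.fold_in(k, jax.lax.bitcast_convert_type(t_end, jnp.int32))
    return jnp.sqrt(t_end - t_start) * jax.random.normal(k, ())

@jax.jit
def solve_per_step(y0):
    # Two key derivations + one scalar normal draw inside every loop iteration.
    def step(y, t_pair):
        t_start, t_end = t_pair
        dW = per_step_increment(key, t_start, t_end)
        dt = t_end - t_start
        return y - y * dt + 0.2 * dW, y
    return jax.lax.scan(step, y0, (ts[:-1], ts[1:]))[1]

@jax.jit
def solve_batched(y0):
    # All 100 increments drawn at once, outside the sequential loop.
    dts = jnp.diff(ts)
    dWs = jnp.sqrt(dts) * jax.random.normal(key, dts.shape)
    def step(y, inputs):
        dt, dW = inputs
        return y - y * dt + 0.2 * dW, y
    return jax.lax.scan(step, y0, (dts, dWs))[1]
```

Both return an Euler-Maruyama path of the same shape (with different random draws); timing them with %timeit on your own machine is the way to check how much of the gap the per-step draws account for.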

@patrick-kidger
Owner

FWIW I think the speed difference here does seem unacceptably large. This seems like it should be improved.

Starting with the low-hanging fruit, to be sure we're making a fairer comparison: can you try setting EQX_ON_ERROR=nan and passing throw=False to diffeqsolve, to disable all error checks? Those are fairly slow.

Also, can you try using stepsize_controller=StepTo(...). By default Diffrax does not recompile if the number of steps changes (e.g. because t1 changes), but a lax.scan implementation does. Diffrax pays a small amount of runtime cost for this generality. Using StepTo instead bakes in the discretisation in the same way as a lax.scan.
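The recompilation point can be checked in plain JAX: a Python side effect inside a jitted function runs only while JAX traces it, so counting those runs shows that a lax.scan solver is retraced (and recompiled) whenever the number of steps, i.e. the shape of its input, changes. trace_count and scan_euler are illustrative names.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scan_euler(y0, dWs, dt):
    global trace_count
    trace_count += 1  # Python side effect: executes only while tracing
    def step(y, dW):
        return y - y * dt + 0.2 * dW, y
    return jax.lax.scan(step, y0, dWs)[1]

y0 = jnp.float32(1.0)
scan_euler(y0, jnp.zeros(100), 0.01)   # first call: trace + compile
scan_euler(y0, jnp.zeros(100), 0.01)   # same shapes: cached, no retrace
scan_euler(y0, jnp.zeros(200), 0.005)  # new number of steps: retrace + recompile
print(trace_count)  # 2
```

By contrast, the Diffrax jaxpr above shows its loop as a while bounded by max_steps (the 4096-sized buffers), whose trip count is decided at runtime.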

@lockwo
Contributor

lockwo commented Oct 22, 2024

With throw=False, EQX_ON_ERROR=nan and StepTo, this is what I see:

import os

os.environ["EQX_ON_ERROR"] = "nan"
import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt

# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2
steps = jnp.linspace(t0, t1, ndt)

brownian_motion = dx.UnsafeBrownianPath(shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(steps=True)

@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(terms, solver, t0, t1, dt0=None, y0=y0, saveat=saveat, adjoint=dx.DirectAdjoint(), throw=False, stepsize_controller=dx.StepTo(ts=steps)).ys

@jax.jit
def homemade_simu():
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))

    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y

    return jax.lax.scan(step, 1.0, dWs)[-1]


y = diffrax_simu().block_until_ready()
plt.plot(y)
y = homemade_simu().block_until_ready()
plt.plot(y)
plt.show()


%timeit _ = diffrax_simu().block_until_ready()
%timeit _ = homemade_simu().block_until_ready()

(diffrax top, custom bottom)

2.18 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
109 µs ± 25.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(and without any of those changes, i.e. what I had before):

2.43 ms ± 666 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
110 µs ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(all on CPU, just a slower one, but the 20-30x slowdown seems to be of the same scale)

@patrick-kidger
Owner

patrick-kidger commented Oct 22, 2024

So you definitely don't want DirectAdjoint: this is actually really slow and should be avoided if possible. (It exists to handle some autodiff edge cases, I'd love to remove it sometime...) Use the default instead.

Make sure you include an argument (say y0) to both jitted functions -- XLA may have different behavior around constant folding.
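Here is a minimal sketch of that suggestion for the homemade version, assuming the drift/diffusion from the original post. `y0` and the PRNG key become arguments of the jitted function instead of closed-over constants. It also draws `ndt - 1` increments rather than `ndt`, so that the path ends at `t1`. The Diffrax side would be changed in the same way, i.e. a `diffrax_simu(y0)` that passes `y0=y0` to `diffeqsolve`.

```python
import jax
import jax.numpy as jnp

t0, t1, ndt = 0.0, 1.0, 101
dt = (t1 - t0) / (ndt - 1)


def drift(t, y, args):
    return -y


def diffusion(t, y, args):
    return 0.2


@jax.jit
def homemade_simu(y0, key):
    # y0 and key are runtime inputs, so XLA cannot treat the initial
    # condition or the Brownian draws as compile-time constants.
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt - 1,))

    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y  # emit the state *before* the step

    return jax.lax.scan(step, jnp.asarray(y0, jnp.float32), dWs)[-1]


ys = homemade_simu(1.0, jax.random.PRNGKey(42))
```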

I'd also try with and without SaveAt(steps=True). (And adjusting the scan appropriately.) I think this should be equivalent either way but I'm not 100% certain.
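To make the comparison like-for-like, the scan can emit either the state before each step (as the homemade versions above do) or the state after each step. The sketch below shows both variants. It assumes that `SaveAt(steps=True)` records the state at the end of each step, which would correspond to the post-step variant. The two variants agree on the overlapping states and differ only in whether `y_0` or `y_N` is included.

```python
import jax
import jax.numpy as jnp

dt = 0.01
num_steps = 100


def euler_step(y, dW):
    # One Euler-Maruyama step of dy = -y dt + 0.2 dW.
    return y - y * dt + 0.2 * dW


def simu_pre_step(y0, dWs):
    # Emits the state *before* each step: y_0, ..., y_{N-1}.
    def step(y, dW):
        return euler_step(y, dW), y

    return jax.lax.scan(step, y0, dWs)[1]


def simu_post_step(y0, dWs):
    # Emits the state *after* each step: y_1, ..., y_N.
    def step(y, dW):
        y_next = euler_step(y, dW)
        return y_next, y_next

    return jax.lax.scan(step, y0, dWs)[1]


dWs = jnp.sqrt(dt) * jax.random.normal(jax.random.PRNGKey(0), (num_steps,))
pre = jax.jit(simu_pre_step)(jnp.float32(1.0), dWs)
post = jax.jit(simu_post_step)(jnp.float32(1.0), dWs)
```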

With all of the above in, then at that point there shouldn't actually be that much difference between the two implementations. (And if there is then we should figure out what.)

@lockwo
Contributor

lockwo commented Oct 22, 2024

The default adjoint actually errors with UBP, which is why I changed to DirectAdjoint:

ValueError: `adjoint=RecursiveCheckpointAdjoint()` does not support `UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` instead.

@patrick-kidger
Owner

patrick-kidger commented Oct 22, 2024

Ah, right. I've just checked and in the case of an unsafe SDE we do actually arrange for DirectAdjoint to do a scan so that should be fine:

if is_unsafe_sde(terms):

(In retrospect I think we could have arranged for the default adjoint to also do the same thing, that might be a small usability improvement.)

Anyway, that's everything off the top of my head -- I might be forgetting something but with these settings then I think Diffrax should be doing something similar to the simple lax.scan implementation. But clearly we're missing something!

(EDIT: we still have one discrepancy I have just noticed: generating the Brownian samples in advance vs on-the-fly.)

If you'd like to dig into this, then it might be time to stare at some jaxprs or HLO for the two programs. If you want to do this at the jaxpr level, you might find eqxi.finalise_jaxpr (and friends) to be a useful set of tools here:

https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_finalise_jaxpr.py

Many primitives exist just to add e.g. an autodiff rule, so we can simplify our jaxprs down to what actually gets lowered by ignoring that and tracing through their impl rules instead.
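If you'd rather start with stock JAX, here is a sketch of the three views worth diffing between the two programs. It uses only `jax.make_jaxpr`, `.lower(...).as_text()` and `.lower(...).compile().as_text()`; the `eqxi.finalise_jaxpr` helpers are not used here. The example program is the homemade Euler scan, with `y0` and the key passed as arguments.

```python
import jax
import jax.numpy as jnp

dt = 0.01


def homemade_simu(y0, key):
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (100,))

    def step(y, dW):
        return y - y * dt + 0.2 * dW, y

    return jax.lax.scan(step, y0, dWs)[1]


y0 = jnp.float32(1.0)
key = jax.random.PRNGKey(0)

# 1. Jaxpr: JAX's own IR before lowering; the scan body appears as a sub-jaxpr.
jaxpr_text = str(jax.make_jaxpr(homemade_simu)(y0, key))

# 2. StableHLO as handed to XLA (before XLA's optimisation passes).
lowered = jax.jit(homemade_simu).lower(y0, key)
hlo_text = lowered.as_text()

# 3. Optimised HLO after compilation (fusion, constant folding, etc.).
compiled_text = lowered.compile().as_text()

print(jaxpr_text)
```

Running the same three steps on the Diffrax function and diffing the outputs should show where the extra operations come from.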

@lockwo
Contributor

lockwo commented Oct 22, 2024

DirectAdjoint does slow things down, but it doesn't account for the whole gap. If I switch to a branch that allows UBP with RecursiveCheckpointAdjoint, it's faster, but there is still a ~4x gap. If I account for the fact that UBP has to split keys but the homemade version doesn't, the gap shrinks to ~1.1-1.2x (which maybe isn't ideal, but seems much more reasonable to me given there are probably some other if statements/logging in there).

from timeit import Timer

x = Timer(lambda: diffrax_simu(y0).block_until_ready())
print(x.timeit(number=100))  # total seconds for all 100 calls
x = Timer(lambda: homemade_simu(y0).block_until_ready())
print(x.timeit(number=100))

With all of the above (EQX_ON_ERROR=nan, steps, y0 as a function input, StepTo, max_steps, etc.) and DirectAdjoint:
0.002462916076183319
0.0005935421213507652

With RecursiveCheckpointAdjoint (on an internal branch with some UBP changes so that it works with checkpointing):
0.002062791958451271
0.0005716248415410519

With both versions splitting keys:
0.0019747079350054264
0.001669874880462885

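(code changed to:

@jax.jit
def homemade_simu(yy):

    def step(carry, t):
        y, k = carry
        # Split the key on every step, mimicking the per-step key handling in UBP.
        k, subkey = jax.random.split(k)
        dw = jnp.sqrt(dt) * jax.random.normal(subkey)
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dw
        return (y + dy, k), y

    # `steps` (the time grid) only fixes the length of the scan here.
    return jax.lax.scan(step, (yy, key), steps)[-1]

)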

@patrick-kidger
Owner

Aha, interesting! Good to have more-or-less gotten to the bottom of the cause of this.

So:

  1. I'd be curious to see what your version of RecursiveCheckpointAdjoint does, and how that compares to the unsafe-SDE branch of DirectAdjoint.
  2. Generating the Brownian samples in advance, rather than on the fly, is very plausibly much faster (although it will be more memory-intensive). Off the top of my head, I'm not sure how to arrange things so that, when using a constant step size controller together with an UnsafeBrownianPath, the samples can be precomputed.

On point 2, I suspect the solution may require allowing the control to have additional state. (Which is also what we'd need to make VBT faster.) Perhaps it's time to bite that bullet and allow for that to happen. Happy to hear suggestions on this one!
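On point 2, here is a small plain-JAX sketch showing that precomputing only changes the scheduling of the work. Given the same per-step keys, drawing each increment inside the scan and drawing all of them up front with `jax.vmap` produce the same path. The precomputed version materialises an array of `num_steps` increments, which is where the extra memory goes. On its own this says nothing about speed; it only shows that the two approaches are interchangeable.

```python
import jax
import jax.numpy as jnp

dt = 0.01
num_steps = 100


def euler_on_the_fly(y0, keys):
    # One key per step; each normal draw happens inside the scan body.
    def step(y, k):
        dW = jnp.sqrt(dt) * jax.random.normal(k)
        return y - y * dt + 0.2 * dW, y

    return jax.lax.scan(step, y0, keys)[1]


def euler_precomputed(y0, keys):
    # Same keys, but all draws are done up front in one vectorised call.
    dWs = jnp.sqrt(dt) * jax.vmap(jax.random.normal)(keys)

    def step(y, dW):
        return y - y * dt + 0.2 * dW, y

    return jax.lax.scan(step, y0, dWs)[1]


keys = jax.random.split(jax.random.PRNGKey(0), num_steps)
ys_fly = jax.jit(euler_on_the_fly)(jnp.float32(1.0), keys)
ys_pre = jax.jit(euler_precomputed)(jnp.float32(1.0), keys)
```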

@lockwo
Contributor

lockwo commented Oct 24, 2024

  1. That is something I want to investigate as well (and I'd also like to get more of it organized and pushed to a fork for others to check), although admittedly it will take me a little while to get to.
  2. Would it be possible to add a "precompute" flag (or something to that effect) to UBP, which would generate the noise ahead of time (with the size determined by max_steps or by user input) without requiring a stateful approach? If the scaling by dt is still done inside the loop, this might (?) also be compatible with adaptive solvers that don't reject steps ("previsible" ones, as I think James calls them).
  3. I am in general an advocate of stateful controls (also discussed in Why can't UBP be backproped? #490), although I haven't thought much more about it since the discussion in that issue (which is very similar to how my stateful UBP is implemented).
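Here is a rough sketch of the "precompute" idea under stated assumptions. The names `MAX_STEPS` and `euler_previsible` are hypothetical and not Diffrax API. Standard normals are drawn once into a buffer sized by a maximum step count, looked up by a step counter, and scaled by `sqrt(h)` inside the loop. Because the scaling happens per step, the step size may vary, provided steps are never rejected. The number of steps must not exceed `MAX_STEPS`, since JAX does not raise an error on out-of-bounds indices.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 256  # hypothetical buffer size, e.g. derived from max_steps


def euler_previsible(y0, step_sizes, key):
    # Draw all standard normals once, up front.
    z = jax.random.normal(key, (MAX_STEPS,))

    def step(carry, h):
        y, n = carry
        # Look up by step index and scale by sqrt(h) inside the loop, so the
        # step size may change from step to step (but steps are never rejected).
        dW = jnp.sqrt(h) * z[n]
        y_next = y - y * h + 0.2 * dW
        return (y_next, n + 1), y_next

    init = (jnp.asarray(y0, jnp.float32), jnp.int32(0))
    (_, n_used), ys = jax.lax.scan(step, init, step_sizes)
    return ys, n_used


# Step sizes known before each step is taken: 50 steps of 0.01, then 25 of 0.02.
hs = jnp.concatenate([jnp.full((50,), 0.01), jnp.full((25,), 0.02)])
ys, n_used = jax.jit(euler_previsible)(1.0, hs, jax.random.PRNGKey(0))
```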

@patrick-kidger
Copy link
Owner

  1. Okay, lmk what you find.
  2. I'm not sure. The way the controls are called at the moment is with the t, not the step index. We'd also have to have a way to pass the number of steps etc to the control. FWIW I'd probably lean towards not having a flag and just always doing this when possible.
  3. I think to do this 'properly' we might need to have AbstractSolver.step also accept the control state, and then pipe it through appropriately. Then also return the updated state. Unfortunately I think we're looking at a hard break to both the control and the solver APIs here, but c'est la vie.
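To make the shape of that API change concrete, here is a toy sketch in plain JAX. Everything in it (`StatefulBrownianControl`, its `init`/`evaluate` methods, and `euler_step`) is hypothetical and is not the Diffrax API. The control carries a PRNG key as its state, the solver step takes the control state in and returns the updated one, and the outer loop pipes that state through.

```python
import jax
import jax.numpy as jnp


class StatefulBrownianControl:
    """Hypothetical control whose state is a PRNG key."""

    def init(self, key):
        return key

    def evaluate(self, state, t0, t1):
        # Return the increment over [t0, t1] together with the new state.
        state, subkey = jax.random.split(state)
        dW = jnp.sqrt(t1 - t0) * jax.random.normal(subkey)
        return dW, state


def euler_step(drift, diffusion, control, t0, t1, y0, control_state):
    # Hypothetical solver step: takes the control state, returns the new one.
    dW, control_state = control.evaluate(control_state, t0, t1)
    y1 = y0 + drift(t0, y0) * (t1 - t0) + diffusion(t0, y0) * dW
    return y1, control_state


def solve(y0, ts, key):
    control = StatefulBrownianControl()
    drift = lambda t, y: -y
    diffusion = lambda t, y: 0.2

    def body(carry, t_pair):
        y, state = carry
        t_start, t_end = t_pair
        y, state = euler_step(drift, diffusion, control, t_start, t_end, y, state)
        return (y, state), y

    init = (jnp.asarray(y0, jnp.float32), control.init(key))
    _, ys = jax.lax.scan(body, init, (ts[:-1], ts[1:]))
    return ys


ts = jnp.linspace(0.0, 1.0, 101)
ys = jax.jit(solve)(1.0, ts, jax.random.PRNGKey(0))
```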

@lockwo
Contributor

lockwo commented Dec 6, 2024

I'm not sure. The way the controls are called at the moment is with the t, not the step index. We'd also have to have a way to pass the number of steps etc to the control. FWIW I'd probably lean towards not having a flag and just always doing this when possible.

Yes, looking at it more, this would probably need a change/add-on to support passing the "step" counter around. If this is an acceptable change, I don't think it would be too much for me to get a PR up.

I think to do this 'properly' we might need to have AbstractSolver.step also accept the control state, and then pipe it through appropriately. Then also return the updated state. Unfortunately I think we're looking at a hard break to both the control and the solver APIs here, but c'est la vie.

This was my conclusion as well. I started drafting a branch for this, but it would require a pretty noticeable breaking change (at least internally), and I figured Diffrax was more fait accompli than c'est la vie when it came to breaking changes of this size.

@patrick-kidger
Owner

It's true, I try to avoid breaking changes where possible! They're no fun for anyone. But the performance issues discussed here genuinely are quite severe, so I think they're actually strong enough to motivate making a breaking change of this nature.

If this is an acceptable change, I don't think it would be too much for me to get a PR up.

Awesome, I'm looking forward to it! Let's see if we can get the stateful controls done at the same time? I'd like to contain the breaking changes to a single release, ideally.
