Description
Hi,
Based on this tutorial, I tried to get started with JAX and neural ODEs: https://colab.research.google.com/drive/1ZlK36VgWy1vBjBNXjSUg6Cb-7zeoa3jh
However, I get a JaxStackTraceBeforeTransformation error (full error message below). I boiled the code down to a minimal working example (also below) and noticed that the error only occurs when the equation in test_func depends on its argument.
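One possible explanation for why only the variant that uses the argument fails (an assumption, not confirmed for diffrax): JAX runs a custom JVP rule only when a nonzero tangent actually reaches the function that defines it. With `dp_dt = y`, the solution no longer depends on `coeff`, so a faulty rule somewhere inside the solver loop may never be exercised with a real tangent. The toy sketch below shows this behavior; `double`, `calls` and both loss functions are hypothetical names, unrelated to diffrax or equinox.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented every time JAX traces the JVP rule below.
calls = {"jvp_rule": 0}


@jax.custom_jvp
def double(x):
    # Hypothetical toy function with a custom derivative rule.
    return 2.0 * x


@double.defjvp
def double_jvp(primals, tangents):
    calls["jvp_rule"] += 1  # Python side effect, runs only when the rule is used
    (x,), (x_dot,) = primals, tangents
    return double(x), 2.0 * x_dot


def loss_without_dependence(c):
    # `double` only sees a constant, so no tangent flows into it.
    return double(jnp.float32(3.0)) + 0.0 * c


def loss_with_dependence(c):
    # Here the input of `double` depends on `c`, so its JVP rule is needed.
    return double(3.0 * c)


grad_without = jax.grad(loss_without_dependence)(1.0)
calls_after_first = calls["jvp_rule"]
grad_with = jax.grad(loss_with_dependence)(1.0)
calls_after_second = calls["jvp_rule"]
print(grad_without, calls_after_first, grad_with, calls_after_second)
```

If the rule in `double_jvp` were broken, only the second `jax.grad` call would surface the error, which mirrors the "works" / "doesn't work" comments in the example below.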
Since this issue seemed similar to one raised in an earlier post (jax-ml/jax#13629), I tried downgrading jax to version 0.4.23. I also tried setting up a fresh Python environment with only the necessary packages installed. Nothing has helped so far. I'd appreciate your help :)
(Even though it is reported as a JaxStack... error, @dfm pointed out in jax-ml/jax#24253 that it might actually be a problem with diffrax: "The error reported here is actually a TypeError being raised because of an issue with the return types in a jax.custom_jvp. It's hard to see from this error report exactly which custom_jvp is the culprit, but it seems like it must be something within diffrax or equinox, so I'd recommend opening the issue on the https://github.com/patrick-kidger/diffrax issue tracker.")
Working example:
```python
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
import jax
import jax.numpy as jnp


def f(t, y, _):
    dp_dt = 0.9 * y
    return dp_dt


b0 = 2  # init condition
data_ts = jnp.linspace(0, 20, 100)
data_sol = diffeqsolve(ODETerm(f), Tsit5(), t0=0, t1=20, dt0=0.01,
                       y0=(b0), saveat=SaveAt(ts=data_ts))


def fwd_test(coeff):
    num_ts = 100

    def test_func(t, y, _coeff):
        dp_dt = y * _coeff  # doesn't work
        # dp_dt = y  # works
        return dp_dt

    b0 = 2
    model_ts = jnp.linspace(0, 20, num_ts)
    # Note: larger dt0 so that it runs faster; this is about as large as it can go
    model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
                            y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
    model_b = model_sol.ys
    data_b = data_sol.ys
    return jnp.sum((model_b - data_b)**2)


coeff = 1.
grads = jax.grad(fwd_test)(coeff)
```
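As a cross-check that does not involve diffrax at all: for `dp_dt = y * coeff` with `y(0) = b0`, the exact solution is `b0 * exp(coeff * t)`, and the reference data follow `b0 * exp(0.9 * t)`. Plugging these closed forms into the same squared-error loss gives a function whose gradient JAX can compute directly. This is only a sketch under the assumption that the ODE solves are accurate enough for the diffrax loss to be close to this closed-form loss (the coarse `dt0=0.5` may cause a noticeable difference), so the two gradients should agree only approximately once the original code runs. `closed_form_loss` and `closed_form_grad` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp

B0 = 2.0                       # initial condition, as in the example above
TS = jnp.linspace(0, 20, 100)  # same save times as data_ts / model_ts


def closed_form_loss(coeff):
    # Exact solutions of dy/dt = coeff * y and dy/dt = 0.9 * y with y(0) = B0.
    model_b = B0 * jnp.exp(coeff * TS)
    data_b = B0 * jnp.exp(0.9 * TS)
    return jnp.sum((model_b - data_b) ** 2)


def closed_form_grad(coeff):
    # Hand-derived d(loss)/d(coeff): sum of 2 * (model - data) * t * model.
    model_b = B0 * jnp.exp(coeff * TS)
    data_b = B0 * jnp.exp(0.9 * TS)
    return jnp.sum(2.0 * (model_b - data_b) * TS * model_b)


coeff = 1.0
reference_grad = jax.grad(closed_form_loss)(coeff)
print(reference_grad, closed_form_grad(coeff))
```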
Error message:
```
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel_launcher.py:18
16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\traitlets\config\application.py:1075, in launch_instance()
1074 app.initialize(argv)
-> 1075 app.start()
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelapp.py:739, in start()
738 try:
--> 739 self.io_loop.start()
740 except KeyboardInterrupt:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\tornado\platform\asyncio.py:205, in start()
204 def start(self) -> None:
--> 205 self.asyncio_loop.run_forever()
File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:607, in run_forever()
606 while True:
--> 607 self._run_once()
608 if self._stopping:
File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:1919, in _run_once()
1918 else:
-> 1919 handle._run()
1920 handle = None
File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\events.py:80, in _run()
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:545, in dispatch_queue()
544 try:
--> 545 await self.process_one()
546 except Exception:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:534, in process_one()
533 return
--> 534 await dispatch(*args)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:437, in dispatch_shell()
436 if inspect.isawaitable(result):
--> 437 await result
438 except Exception:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:362, in execute_request()
361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:778, in execute_request()
777 if inspect.isawaitable(reply_content):
--> 778 reply_content = await reply_content
780 # Flush output before sending the reply.
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:449, in do_execute()
448 if accepts_params["cell_id"]:
--> 449 res = shell.run_cell(
450 code,
451 store_history=store_history,
452 silent=silent,
453 cell_id=cell_id,
454 )
455 else:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\zmqshell.py:549, in run_cell()
548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3075, in run_cell()
3074 try:
-> 3075 result = self._run_cell(
3076 raw_cell, store_history, silent, shell_futures, cell_id
3077 )
3078 finally:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3130, in _run_cell()
3129 try:
-> 3130 result = runner(coro)
3131 except BaseException as e:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\async_helpers.py:129, in _pseudo_sync_runner()
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3334, in run_cell_async()
3331 interactivity = "none" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3335 interactivity=interactivity, compiler=compiler, result=result)
3337 self.last_execution_succeeded = not has_raised
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3517, in run_ast_nodes()
3516 asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
3518 return True
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3577, in run_code()
3576 else:
-> 3577 exec(code_obj, self.user_global_ns, self.user_ns)
3578 finally:
3579 # Reset our crash handler in place
Cell In[1], line 32
31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
33 # print(grads)
Cell In[1], line 24, in fwd_test()
23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
25 y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
26 model_b = model_sol.ys
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:823, in diffeqsolve()
819 #
820 # Main loop
821 #
--> 823 final_state, aux_stats = adjoint.loop(
824 args=args,
825 terms=terms,
826 solver=solver,
827 stepsize_controller=stepsize_controller,
828 discrete_terminating_event=discrete_terminating_event,
829 saveat=saveat,
830 t0=t0,
831 t1=t1,
832 dt0=dt0,
833 max_steps=max_steps,
834 init_state=init_state,
835 throw=throw,
836 passed_solver_state=passed_solver_state,
837 passed_controller_state=passed_controller_state,
838 )
840 #
841 # Finish up
842 #
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\adjoint.py:286, in loop()
285 msg = None
--> 286 final_state = self._loop(
287 terms=terms,
288 saveat=saveat,
289 init_state=init_state,
290 max_steps=max_steps,
291 inner_while_loop=inner_while_loop,
292 outer_while_loop=outer_while_loop,
293 **kwargs,
294 )
295 if msg is not None:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:429, in loop()
427 del filter_state
--> 429 final_state = outer_while_loop(
430 cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
431 )
433 def _save_t1(subsaveat, save_state):
File ~\AppData\Local\Programs\Python\Python311\Lib\contextlib.py:81, in inner()
80 with self._recreate_cm():
---> 81 return func(*args, **kwds)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
110 init_val,
111 max_steps=max_steps,
112 buffers=buffers,
113 checkpoints=checkpoints,
114 )
115 elif kind == "bounded":
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:247, in checkpointed_while_loop()
246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
248 vjp_arg = (init_val_, body_fun_)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\common.py:463, in new_body_fun()
462 buffer_val = _wrap_buffers(val, pred, tag)
--> 463 buffer_val2 = body_fun(buffer_val)
464 # Needed to work with `disable_jit`, as then we lose the automatic
465 # ArrayLike->Array cast provided by JAX's while loops.
466 # The input `val` is already cast to Array below, so this matches that.
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:219, in body_fun()
214 #
215 # Actually do some differential equation solving! Make numerical steps, adapt
216 # step sizes, all that jazz.
217 #
--> 219 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
220 terms,
221 state.tprev,
222 state.tnext,
223 state.y,
224 args,
225 state.solver_state,
226 state.made_jump,
227 )
229 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
230 # we get a negative value for y, and then get a NaN vector field. (And then
231 # everything breaks.) See #143.
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\solver\runge_kutta.py:1041, in step()
1035 # Needs to be an `eqxi.while_loop` as:
1036 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
1037 # more stage on the first step.
1038 # (b) to work around a limitation of JAX's autodiff being unable to express
1039 # "triangular computations" (every stage depends on all previous stages)
1040 # without spurious copies.
-> 1041 final_val = eqxi.while_loop(
1042 cond_stage,
1043 rk_stage,
1044 init_val,
1045 max_steps=num_stages,
1046 buffers=buffers,
1047 kind="checkpointed" if self.scan_kind is None else self.scan_kind,
1048 checkpoints=num_stages,
1049 base=num_stages,
1050 )
1051 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
110 init_val,
111 max_steps=max_steps,
112 buffers=buffers,
113 checkpoints=checkpoints,
114 )
115 elif kind == "bounded":
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:252, in checkpointed_while_loop()
249 final_val_ = _checkpointed_while_loop(
250 vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
251 )
--> 252 _, _, _, final_val = _stop_gradient_on_unperturbed(init_val_, final_val_, body_fun_)
253 return final_val
JaxStackTraceBeforeTransformation: TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Cell In[1], line 32
28 return jnp.sum((model_b - data_b)**2)
31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
33 # print(grads)
[... skipping hidden 10 frame]
Cell In[1], line 24, in fwd_test(coeff)
22 model_ts = jnp.linspace(0, 20, num_ts)
23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
25 y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
26 model_b = model_sol.ys
27 data_b = data_sol.ys
[... skipping hidden 27 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1272, in _stop_gradient_on_unperturbed_jvp(***failed resolving arguments***)
1268 del primals, tangents
1269 perturb_val, perturb_body_fun = jtu.tree_map(
1270 lambda _, t: t is not None, (init_val, body_fun), (t_init_val, t_body_fun)
1271 )
-> 1272 perturb_val = _resolve_perturb_val(
1273 init_val, body_fun, perturb_val, perturb_body_fun
1274 )
1275 t_final_val = jtu.tree_map(
1276 _perturb_to_tang, t_final_val, perturb_val, is_leaf=_is_none
1277 )
1278 return final_val, t_final_val
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1241, in _resolve_perturb_val(final_val, body_fun, perturb_final_val, perturb_body_fun)
1238 else:
1239 perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)
-> 1241 perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
1242 return perturb_val
[... skipping hidden 12 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1214, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl()
1211 return _out
1213 # Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
-> 1214 jax.linearize(_to_linearize, dynamic)
1215 if new_perturb_val is sentinel:
1216 # `_dynamic_out` in `_to_linearize` had no JVP tracers at all, despite
1217 # `_dynamic` having them. Presumably the user's `_body_fun` has no
1218 # differentiable dependency whatsoever.
1219 # This can happen if all the autograd is happening through
1220 # `perturb_body_fun`.
1221 return Static(perturb_val)
[... skipping hidden 5 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1207, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl.<locals>._to_linearize(_dynamic)
1205 def _to_linearize(_dynamic):
1206 _body_fun, _val = combine(_dynamic, static)
-> 1207 _out = _body_fun(_val)
1208 _dynamic_out, _static_out = partition(_out, is_inexact_array)
1209 _dynamic_out = _record_symbolic_zeros(_dynamic_out)
[... skipping hidden 10 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
344 msg = ("Custom JVP rule must produce primal and tangent outputs with "
345 "corresponding shapes and dtypes, but got:\n{}")
346 disagreements = (
347 f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
348 for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
349 if av_et != av_t)
--> 351 raise TypeError(msg.format('\n'.join(disagreements)))
352 yield primals_out + tangents_out, (out_tree, primal_avals)
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
```
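The traceback above hides several frames (`[... skipping hidden 27 frame]` and similar), and those are the most likely place for the culprit `custom_jvp` to appear. Assuming at least some of the hidden frames are JAX-internal frames that get filtered for IPython, disabling JAX's traceback filtering before rerunning the cell may expose them. `jax_traceback_filtering` is a standard JAX config option (the environment variable `JAX_TRACEBACK_FILTERING=off` is equivalent); whether this reveals the exact frame in this case is untested.

```python
import jax

# Show every frame, including JAX internals, in subsequent tracebacks.
# Run this before the failing jax.grad call, e.g. in the first notebook cell.
jax.config.update("jax_traceback_filtering", "off")
```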
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.4.34
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.11.1 (tags/v3.11.1:a7a450f, Dec 6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', release='10', version='10.0.19044', machine='AMD64')
jupyterlab: 4.2.2
diffrax: 0.4.1
```
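The list above includes jax, jaxlib and diffrax but not equinox, even though every frame in the failing part of the traceback runs through `equinox.internal._loop`. A small helper that uses only the standard library to report all four installed versions when following up on the issue (`installed_versions` is a hypothetical name, not part of any of these packages):

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "equinox", "diffrax")):
    """Return {package name: installed version string, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version if version is not None else 'not installed'}")
```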