However, I get a JaxStackTraceBeforeTransformation error (full traceback below). I boiled the code down to a small reproducible example (also below) and noticed that the error only occurs when the right-hand side in test_func actually uses the argument passed in via args (here _coeff).
Since this looked similar to an earlier issue (#13629), I tried downgrading JAX to version 0.4.23. I also tried a fresh Python environment with only the necessary packages installed. Neither helped so far. I'd appreciate your help :)
Minimal reproducible example:
```python
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
import jax
import jax.numpy as jnp


def f(t, y, _):
    dp_dt = 0.9 * y
    return dp_dt


b0 = 2  # init condition
data_ts = jnp.linspace(0, 20, 100)
data_sol = diffeqsolve(ODETerm(f), Tsit5(), t0=0, t1=20, dt0=0.01,
                       y0=(b0), saveat=SaveAt(ts=data_ts))


def fwd_test(coeff):
    num_ts = 100

    def test_func(t, y, _coeff):
        dp_dt = y * _coeff  # doesn't work
        # dp_dt = y  # works
        return dp_dt

    b0 = 2
    model_ts = jnp.linspace(0, 20, num_ts)
    # Note: larger dt0 so that it runs faster; this is about as large as it can go
    model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
                            y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
    model_b = model_sol.ys
    data_b = data_sol.ys
    return jnp.sum((model_b - data_b)**2)


coeff = 1.
grads = jax.grad(fwd_test)(coeff)
```
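Because dy/dt = coeff * y with y(0) = 2 has the closed-form solution y(t) = 2 exp(coeff * t), the loss in fwd_test and its derivative dL/dcoeff = sum_i 2 (y_model(t_i) - y_data(t_i)) * t_i * y_model(t_i) can also be computed by hand. The plain-Python sketch below (no JAX) gives a reference value to compare against jax.grad(fwd_test)(1.) once the solve runs. The function name analytic_loss_and_grad is hypothetical, and it assumes the "data" is the exact solution with coefficient 0.9 rather than the dt0=0.01 numerical solve, so small differences from the diffrax result are expected.

```python
import math


def analytic_loss_and_grad(coeff, true_coeff=0.9, y0=2.0, t1=20.0, num_ts=100):
    """Closed-form loss and d(loss)/d(coeff) for dy/dt = coeff * y (hypothetical helper).

    Mirrors fwd_test: the model uses `coeff`, the "data" uses `true_coeff`,
    and both are sampled at num_ts evenly spaced times in [0, t1].
    """
    # Same grid as jnp.linspace(0, t1, num_ts).
    ts = [t1 * i / (num_ts - 1) for i in range(num_ts)]
    loss = 0.0
    grad = 0.0
    for t in ts:
        model = y0 * math.exp(coeff * t)  # y(t) = y0 * exp(coeff * t)
        data = y0 * math.exp(true_coeff * t)  # assumed exact "data" solution
        resid = model - data
        loss += resid**2
        # d(resid**2)/d(coeff) = 2 * resid * d(model)/d(coeff) = 2 * resid * t * model
        grad += 2.0 * resid * t * model
    return loss, grad


loss, grad = analytic_loss_and_grad(1.0)
print(f"loss = {loss:.6e}, d(loss)/d(coeff) = {grad:.6e}")
```

With the commented-out dp_dt = y variant, the loss does not depend on coeff at all, so the expected gradient there is exactly zero.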
Error message:
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel_launcher.py:18
16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\traitlets\config\application.py:1075, in launch_instance()
1074 app.initialize(argv)
-> 1075 app.start()
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelapp.py:739, in start()
738 try:
--> 739 self.io_loop.start()
740 except KeyboardInterrupt:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\tornado\platform\asyncio.py:205, in start()
204 def start(self) -> None:
--> 205 self.asyncio_loop.run_forever()
File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:607, in run_forever()
606 while True:
--> 607 self._run_once()
608 if self._stopping:
File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:1919, in _run_once()
1918 else:
-> 1919 handle._run()
1920 handle = None
File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\events.py:80, in _run()
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:545, in dispatch_queue()
544 try:
--> 545 await self.process_one()
546 except Exception:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:534, in process_one()
533 return
--> 534 await dispatch(*args)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:437, in dispatch_shell()
436 if inspect.isawaitable(result):
--> 437 await result
438 except Exception:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:362, in execute_request()
361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:778, in execute_request()
777 if inspect.isawaitable(reply_content):
--> 778 reply_content = await reply_content
780 # Flush output before sending the reply.
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:449, in do_execute()
448 if accepts_params["cell_id"]:
--> 449 res = shell.run_cell(
450 code,
451 store_history=store_history,
452 silent=silent,
453 cell_id=cell_id,
454 )
455 else:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\zmqshell.py:549, in run_cell()
548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3075, in run_cell()
3074 try:
-> 3075 result = self._run_cell(
3076 raw_cell, store_history, silent, shell_futures, cell_id
3077 )
3078 finally:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3130, in _run_cell()
3129 try:
-> 3130 result = runner(coro)
3131 except BaseException as e:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\async_helpers.py:129, in _pseudo_sync_runner()
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3334, in run_cell_async()
3331 interactivity = "none" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3335 interactivity=interactivity, compiler=compiler, result=result)
3337 self.last_execution_succeeded = not has_raised
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3517, in run_ast_nodes()
3516 asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
3518 return True
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3577, in run_code()
3576 else:
-> 3577 exec(code_obj, self.user_global_ns, self.user_ns)
3578 finally:
3579 # Reset our crash handler in place
Cell In[1], line 32
31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
33 # print(grads)
Cell In[1], line 24, in fwd_test()
23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
25 y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
26 model_b = model_sol.ys
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:823, in diffeqsolve()
819 #
820 # Main loop
821 #
--> 823 final_state, aux_stats = adjoint.loop(
824 args=args,
825 terms=terms,
826 solver=solver,
827 stepsize_controller=stepsize_controller,
828 discrete_terminating_event=discrete_terminating_event,
829 saveat=saveat,
830 t0=t0,
831 t1=t1,
832 dt0=dt0,
833 max_steps=max_steps,
834 init_state=init_state,
835 throw=throw,
836 passed_solver_state=passed_solver_state,
837 passed_controller_state=passed_controller_state,
838 )
840 #
841 # Finish up
842 #
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\adjoint.py:286, in loop()
285 msg = None
--> 286 final_state = self._loop(
287 terms=terms,
288 saveat=saveat,
289 init_state=init_state,
290 max_steps=max_steps,
291 inner_while_loop=inner_while_loop,
292 outer_while_loop=outer_while_loop,
293 **kwargs,
294 )
295 if msg is not None:
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:429, in loop()
427 del filter_state
--> 429 final_state = outer_while_loop(
430 cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
431 )
433 def _save_t1(subsaveat, save_state):
File ~\AppData\Local\Programs\Python\Python311\Lib\contextlib.py:81, in inner()
80 with self._recreate_cm():
---> 81 return func(*args, **kwds)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
110 init_val,
111 max_steps=max_steps,
112 buffers=buffers,
113 checkpoints=checkpoints,
114 )
115 elif kind == "bounded":
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:247, in checkpointed_while_loop()
246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
248 vjp_arg = (init_val_, body_fun_)
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\common.py:463, in new_body_fun()
462 buffer_val = _wrap_buffers(val, pred, tag)
--> 463 buffer_val2 = body_fun(buffer_val)
464 # Needed to work with `disable_jit`, as then we lose the automatic
465 # ArrayLike->Array cast provided by JAX's while loops.
466 # The input `val` is already cast to Array below, so this matches that.
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:219, in body_fun()
214 #
215 # Actually do some differential equation solving! Make numerical steps, adapt
216 # step sizes, all that jazz.
217 #
--> 219 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
220 terms,
221 state.tprev,
222 state.tnext,
223 state.y,
224 args,
225 state.solver_state,
226 state.made_jump,
227 )
229 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
230 # we get a negative value for y, and then get a NaN vector field. (And then
231 # everything breaks.) See #143.
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\solver\runge_kutta.py:1041, in step()
1035 # Needs to be an `eqxi.while_loop` as:
1036 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
1037 # more stage on the first step.
1038 # (b) to work around a limitation of JAX's autodiff being unable to express
1039 # "triangular computations" (every stage depends on all previous stages)
1040 # without spurious copies.
-> 1041 final_val = eqxi.while_loop(
1042 cond_stage,
1043 rk_stage,
1044 init_val,
1045 max_steps=num_stages,
1046 buffers=buffers,
1047 kind="checkpointed" if self.scan_kind is None else self.scan_kind,
1048 checkpoints=num_stages,
1049 base=num_stages,
1050 )
1051 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
110 init_val,
111 max_steps=max_steps,
112 buffers=buffers,
113 checkpoints=checkpoints,
114 )
115 elif kind == "bounded":
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:252, in checkpointed_while_loop()
249 final_val_ = _checkpointed_while_loop(
250 vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
251 )
--> 252 _, _, _, final_val = _stop_gradient_on_unperturbed(init_val_, final_val_, body_fun_)
253 return final_val
JaxStackTraceBeforeTransformation: TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Cell In[1], line 32
28 return jnp.sum((model_b - data_b)**2)
31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
33 # print(grads)
[... skipping hidden 10 frame]
Cell In[1], line 24, in fwd_test(coeff)
22 model_ts = jnp.linspace(0, 20, num_ts)
23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
25 y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
26 model_b = model_sol.ys
27 data_b = data_sol.ys
[... skipping hidden 27 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1272, in _stop_gradient_on_unperturbed_jvp(***failed resolving arguments***)
1268 del primals, tangents
1269 perturb_val, perturb_body_fun = jtu.tree_map(
1270 lambda _, t: t is not None, (init_val, body_fun), (t_init_val, t_body_fun)
1271 )
-> 1272 perturb_val = _resolve_perturb_val(
1273 init_val, body_fun, perturb_val, perturb_body_fun
1274 )
1275 t_final_val = jtu.tree_map(
1276 _perturb_to_tang, t_final_val, perturb_val, is_leaf=_is_none
1277 )
1278 return final_val, t_final_val
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1241, in _resolve_perturb_val(final_val, body_fun, perturb_final_val, perturb_body_fun)
1238 else:
1239 perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)
-> 1241 perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
1242 return perturb_val
[... skipping hidden 12 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1214, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl()
1211 return _out
1213 # Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
-> 1214 jax.linearize(_to_linearize, dynamic)
1215 if new_perturb_val is sentinel:
1216 # `_dynamic_out` in `_to_linearize` had no JVP tracers at all, despite
1217 # `_dynamic` having them. Presumably the user's `_body_fun` has no
1218 # differentiable dependency whatsoever.
1219 # This can happen if all the autograd is happening through
1220 # `perturb_body_fun`.
1221 return Static(perturb_val)
[... skipping hidden 5 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1207, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl.<locals>._to_linearize(_dynamic)
1205 def _to_linearize(_dynamic):
1206 _body_fun, _val = combine(_dynamic, static)
-> 1207 _out = _body_fun(_val)
1208 _dynamic_out, _static_out = partition(_out, is_inexact_array)
1209 _dynamic_out = _record_symbolic_zeros(_dynamic_out)
[... skipping hidden 10 frame]
File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
344 msg = ("Custom JVP rule must produce primal and tangent outputs with "
345 "corresponding shapes and dtypes, but got:\n{}")
346 disagreements = (
347 f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
348 for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
349 if av_et != av_t)
--> 351 raise TypeError(msg.format('\n'.join(disagreements)))
352 yield primals_out + tangents_out, (out_tree, primal_avals)
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
The error reported here is actually a TypeError raised because a jax.custom_jvp rule returns tangents with the wrong dtypes: for its int32 and bool primal outputs JAX expects float0 tangents, but the rule returned int32 and bool tangents instead. It's hard to see from this error report exactly which custom_jvp is the culprit, but it seems like it must be something within diffrax or equinox, so I'd recommend opening an issue on the https://github.com/patrick-kidger/diffrax issue tracker.
I'm going to close this since it looks like the conversations in patrick-kidger/diffrax#513 are getting to the bottom of things. Please let me know if there's something I'm missing!
Description
Hi everyone,
Based on this tutorial, I tried to get started with JAX and neural ODEs: https://colab.research.google.com/drive/1ZlK36VgWy1vBjBNXjSUg6Cb-7zeoa3jh
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.11.1 (tags/v3.11.1:a7a450f, Dec 6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', release='10', version='10.0.19044', machine='AMD64')
jupyterlab: 4.2.2
diffrax: 0.4.1
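The list above appears to be the output of jax.print_environment_info() plus manually added jupyterlab and diffrax lines. It does not include the equinox version, which is relevant here because the failing frames are in equinox's internal loop code (equinox\internal\_loop\checkpointed.py). A small standard-library sketch for collecting all the relevant versions in one go; the helper name package_versions is hypothetical:

```python
from importlib import metadata


def package_versions(names=("jax", "jaxlib", "equinox", "diffrax", "numpy")):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
```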