JAX backend fails with simple Scan examples #924
@junpenglao Any ideas? :)
PS: The issue seems to be a simple case of a variable missing in an internal function, but I remember trying to add it and other things still failed. Hopefully it is still a simple fix!
Related to #710; it's safe to say the JAX Scan implementation is simply broken at the moment.
Still fails on Traceback (most recent call last):
File "<stdin>", line 14, in <module>
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/__init__.py", line 317, in function
fn = pfunc(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/pfunc.py", line 374, in pfunc
return orig_function(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1763, in orig_function
fn = m.create(defaults)
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1656, in create
_fn, _i, _o = self.linker.make_thunk(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/basic.py", line 254, in make_thunk
return self.make_all(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/basic.py", line 697, in make_all
thunks, nodes, jit_fn = self.create_jitable_thunk(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/basic.py", line 646, in create_jitable_thunk
converted_fgraph = self.fgraph_convert(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/jax/linker.py", line 13, in fgraph_convert
return jax_funcify(fgraph, **kwargs)
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 670, in jax_funcify_FunctionGraph
return fgraph_to_python(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/utils.py", line 741, in fgraph_to_python
compiled_func = op_conversion_fn(
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 407, in jax_funcify_Scan
inner_fg = FunctionGraph(op.inputs, op.outputs)
AttributeError: 'Scan' object has no attribute 'inputs'
This looks like the JAX dispatcher wasn't updated after the Scan refactor. I'll give it a shot, using the current Numba dispatcher as a reference.
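The traceback shows `jax_funcify` dispatching on the Op's class (`dispatch(args[0].__class__)`), so a converter registered for `Scan` keeps running even after the `Scan` class stops exposing the attribute it reads. A minimal, self-contained sketch of that failure mode using `functools.singledispatch`, assuming a hypothetical post-refactor `Scan` that stores its inner graph under `inner_inputs`/`inner_outputs`; the classes and converter names here are illustrative stand-ins, not Aesara's actual code.

```python
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a graph Op."""


class Scan(Op):
    # Hypothetical post-refactor layout: the inner graph lives under new
    # attribute names, and the old `inputs`/`outputs` attributes are gone.
    def __init__(self, inner_inputs, inner_outputs):
        self.inner_inputs = inner_inputs
        self.inner_outputs = inner_outputs


@singledispatch
def funcify(op):
    # Fallback for Op classes with no registered converter.
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


def convert_stale(op):
    # Stale converter: still reads the pre-refactor attribute names.
    return list(op.inputs), list(op.outputs)


def convert_updated(op):
    # Updated converter: reads the attributes the refactored class provides.
    return list(op.inner_inputs), list(op.inner_outputs)


op = Scan(["x"], ["y"])

# Dispatch is keyed on type(op), so the stale converter is still selected.
funcify.register(Scan, convert_stale)
try:
    funcify(op)
except AttributeError as err:
    stale_error = str(err)

# Re-registering for the same class replaces the old converter.
funcify.register(Scan, convert_updated)
result = funcify(op)
```

Here `stale_error` carries the same `'Scan' object has no attribute 'inputs'` message as the traceback, while the updated registration succeeds without any change to the caller.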
The documentation example fails, but that might be due to the variable `n_steps` and an intermediate `Op` that implies a dynamic shape (`TracerArrayConversionError` traceback).
But even if we remove the variable `n_steps`, it still fails. This was brought up in #710, but it was not clear that it also affected simple Scan cases.
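For comparison, `jax.lax.scan` fixes the number of iterations when the function is traced, taken from the leading axis of `xs` (or a static `length` argument), so a step count that is only available as a traced value cannot drive the loop. A minimal sketch of a simple cumulative-sum scan written directly against `lax.scan` under `jax.jit`, assuming a fixed-length input; `cumsum_scan` is a hypothetical name and this is not how Aesara's JAX backend translates `Scan`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumsum_scan(xs):
    # carry: running total; each step emits the updated total as its output.
    def step(carry, x):
        new_total = carry + x
        return new_total, new_total

    init = jnp.zeros((), dtype=xs.dtype)
    _, ys = lax.scan(step, init, xs)
    return ys


xs = jnp.arange(5.0)
# The number of steps (5) is fixed by xs.shape[0] at trace time.
ys = jax.jit(cumsum_scan)(xs)
```

Because the iteration count comes from a static shape, this version traces cleanly; a graph whose step count depends on runtime values needs that count made static before it can be expressed this way.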