
JAX backend fails with simple Scan examples #924

Open
Tracked by #1425
ricardoV94 opened this issue Apr 22, 2022 · 5 comments
Labels
bug Something isn't working help wanted Extra attention is needed JAX Involves JAX transpilation Scan Involves the `Scan` `Op`

Comments

@ricardoV94
Contributor

ricardoV94 commented Apr 22, 2022

The documentation example fails, but that might be because `n_steps` is symbolic here: it introduces an intermediate Op (`AllocEmpty`) whose shape depends on the runtime value of `k`, i.e. a dynamic shape.

import aesara
import aesara.tensor as at

k = at.iscalar("k")
A = at.vector("A")

result, _ = aesara.scan(fn=lambda prior_result, A: prior_result * A,
                        outputs_info=at.ones_like(A),
                        non_sequences=A,
                        n_steps=k)

final_result = result[-1]

power = aesara.function(inputs=[A, k], outputs=final_result, mode="JAX")

print(power(range(10), 2))  # TracerArrayConversionError
print(power(range(10), 4))
TracerArrayConversionError traceback
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpy/core/fromnumeric.py in ndim(a)
   3159     try:
-> 3160         return a.ndim
   3161     except AttributeError:

AttributeError: 'tuple' object has no attribute 'ndim'

During handling of the above exception, another exception occurred:

TracerArrayConversionError                Traceback (most recent call last)
~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in streamline_default_f()
    201                 ):
--> 202                     thunk()
    203                     for old_s in old_storage:

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663         ):
--> 664             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    665 

    [... skipping hidden 14 frame]

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in jax_funcified_fgraph(A, k)
     33     # AllocEmpty{dtype='float64'}(Elemwise{Composite{(Switch(GT(i0, i1), (i1 + i0), (i1 - i0)) + i2)}}[(0, 1)].0, Shape_i{0}.0)
---> 34     auto_1742 = allocempty(auto_2745, auto_125)
     35     # IncSubtensor{InplaceSet;:int64:}(AllocEmpty{dtype='float64'}.0, Rebroadcast{(0, False)}.0, ScalarConstant{1})

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/jax/dispatch.py in allocempty(*shape)
    264     def allocempty(*shape):
--> 265         return jnp.empty(shape, dtype=op.dtype)
    266 

~/miniconda3/envs/pymc/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in zeros(shape, dtype)
   2622   lax_internal._check_user_dtype_supported(dtype, "zeros")
-> 2623   shape = canonicalize_shape((shape,) if ndim(shape) == 0 else shape)
   2624   return lax.full(shape, 0, _jnp_dtype(dtype))

<__array_function__ internals> in ndim(*args, **kwargs)

~/miniconda3/envs/pymc/lib/python3.9/site-packages/numpy/core/fromnumeric.py in ndim(a)
   3161     except AttributeError:
-> 3162         return asarray(a).ndim
   3163 

~/miniconda3/envs/pymc/lib/python3.9/site-packages/jax/core.py in __array__(self, *args, **kw)
    469   def __array__(self, *args, **kw):
--> 470     raise TracerArrayConversionError(self)
    471 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function jax_funcified_fgraph at /tmp/tmpzgxn26e5:2 for jit, this concrete value was not available in Python because it depends on the value of the argument 'k'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_123337/442771908.py in <module>
     15 power = aesara.function(inputs=[A, k], outputs=final_result, mode="JAX")
     16 
---> 17 print(power(range(10), 2))
     18 print(power(range(10), 4))

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    962         try:
    963             outputs = (
--> 964                 self.fn()
    965                 if output_subset is None
    966                 else self.fn(output_subset=output_subset)

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in streamline_default_f()
    204                         old_s[0] = None
    205             except Exception:
--> 206                 raise_with_op(fgraph, node, thunk)
    207 
    208         f = streamline_default_f

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    527 
    528     try:
--> 529         exc_value = exc_type(
    530             str(exc_value) + detailed_err_msg + "\n" + "\n".join(hints)
    531         )

~/miniconda3/envs/pymc/lib/python3.9/site-packages/jax/_src/errors.py in __init__(self, tracer)
    320     super().__init__(
    321         "The numpy.ndarray conversion method __array__() was called on "
--> 322         f"the JAX Tracer object {tracer}{tracer._origin_msg()}")
    323 
    324 

AttributeError: 'str' object has no attribute '_origin_msg'
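
The first failure comes from asking JAX for an array whose shape depends on a traced value. The following is a minimal sketch in plain JAX (no Aesara involved; `alloc_buffer`, `traced_shape_fails` and `power_last` are hypothetical names) that reproduces the problem along with two common workarounds: marking the length as static, or, since only `result[-1]` is needed here, keeping just the fixed-shape carry in a `jax.lax.fori_loop`. The exact exception class raised for a traced shape differs between JAX versions, so the check catches any exception.

```python
import jax
import jax.numpy as jnp


def alloc_buffer(k):
    # The output shape depends on the *value* of `k`, not just its type.
    return jnp.empty((k,))


def traced_shape_fails():
    """Return True if jitting `alloc_buffer` with a traced `k` raises."""
    try:
        jax.jit(alloc_buffer)(3)
    except Exception:  # exact error class differs between JAX versions
        return True
    return False


# Workaround 1: make `k` static. It is then a plain Python int at trace
# time, but every distinct value triggers a recompilation.
alloc_buffer_static = jax.jit(alloc_buffer, static_argnums=0)


# Workaround 2: if only the final step is needed (as with `result[-1]`),
# keep just the fixed-shape carry. `fori_loop` accepts a traced upper
# bound (it then lowers to a `while_loop`, which does not support
# reverse-mode differentiation), so no array shape depends on `k`.
@jax.jit
def power_last(A, k):
    return jax.lax.fori_loop(0, k, lambda i, acc: acc * A, jnp.ones_like(A))


print(traced_shape_fails())            # True
print(alloc_buffer_static(3).shape)    # (3,)
print(power_last(jnp.arange(10.0), 2))  # elementwise A ** 2
```

With `power_last`, calls with `k=2` and `k=4` reuse one compiled function, whereas `alloc_buffer_static` recompiles for every new length.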

But even with a constant n_steps, it still fails:

import aesara
import aesara.tensor as at

k = at.iscalar("k")
A = at.vector("A")

result, _ = aesara.scan(fn=lambda prior_result, A: prior_result * A,
                        outputs_info=at.ones_like(A),
                        non_sequences=A,
                        n_steps=2)

final_result = result[-1]

power = aesara.function(inputs=[A], outputs=final_result, mode="JAX")

print(power(range(10)))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in streamline_default_f()
    201                 ):
--> 202                     thunk()
    203                     for old_s in old_storage:

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663         ):
--> 664             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    665 

    [... skipping hidden 14 frame]

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in jax_funcified_fgraph(A)
     13     # forall_inplace,cpu,scan_fn}(TensorConstant{2}, IncSubtensor{InplaceSet;:int64:}.0, A)
---> 14     auto_9778 = scan(auto_8633, auto_9774, A)
     15     # Subtensor{int64}(forall_inplace,cpu,scan_fn}.0, ScalarConstant{1})

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/jax/dispatch.py in scan(*outer_inputs)
    420     def scan(*outer_inputs):
--> 421         scan_args = ScanArgs(
    422             list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info

TypeError: __init__() missing 1 required positional argument: 'as_while'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_123337/3574921839.py in <module>
     14 power = aesara.function(inputs=[A], outputs=final_result, mode="JAX")
     15 
---> 16 print(power(range(10)))

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    962         try:
    963             outputs = (
--> 964                 self.fn()
    965                 if output_subset is None
    966                 else self.fn(output_subset=output_subset)

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in streamline_default_f()
    204                         old_s[0] = None
    205             except Exception:
--> 206                 raise_with_op(fgraph, node, thunk)
    207 
    208         f = streamline_default_f

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    536         # Some exception need extra parameter in inputs. So forget the
    537         # extra long error message in that case.
--> 538     raise exc_value.with_traceback(exc_trace)
    539 
    540 

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in streamline_default_f()
    200                     thunks, order, post_thunk_old_storage
    201                 ):
--> 202                     thunk()
    203                     for old_s in old_storage:
    204                         old_s[0] = None

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    662             thunk_outputs=thunk_outputs,
    663         ):
--> 664             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    665 
    666             for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 14 frame]

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/utils.py in jax_funcified_fgraph(A)
     12     auto_9774 = incsubtensor(auto_9541, auto_8744, auto_9712)
     13     # forall_inplace,cpu,scan_fn}(TensorConstant{2}, IncSubtensor{InplaceSet;:int64:}.0, A)
---> 14     auto_9778 = scan(auto_8633, auto_9774, A)
     15     # Subtensor{int64}(forall_inplace,cpu,scan_fn}.0, ScalarConstant{1})
     16     auto_9760 = subtensor(auto_9778, auto_9712)

~/miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/link/jax/dispatch.py in scan(*outer_inputs)
    419 
    420     def scan(*outer_inputs):
--> 421         scan_args = ScanArgs(
    422             list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
    423         )

TypeError: __init__() missing 1 required positional argument: 'as_while'
Apply node that caused the error: Subtensor{int64}(forall_inplace,cpu,scan_fn}.0, ScalarConstant{1})
Toposort index: 6
Inputs types: [TensorType(float64, (None, None)), ScalarType(int64)]
Inputs shapes: [(10,)]
Inputs strides: [(8,)]
Inputs values: ['not shown']
Outputs clients: [['output']]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

This was brought up in #710, but it wasn't clear that it also affected simple Scan cases.
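
For comparison, here is a hand-written sketch of what the constant-`n_steps` example computes, expressed directly with `jax.lax.scan`. It is not the output of Aesara's JAX dispatcher, and having `power` take `n_steps` as a static argument is an assumption of the sketch. Where the Aesara graph preallocates the Scan output buffer (the `AllocEmpty`/`IncSubtensor` nodes in the tracebacks above), `lax.scan` stacks the per-step outputs itself, which is why `length` must be a concrete Python int.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def power(A, n_steps):
    def step(prior_result, _):
        new_result = prior_result * A
        # Return (next carry, per-step output); lax.scan stacks the outputs.
        return new_result, new_result

    # xs=None with an explicit, static `length` mirrors `n_steps=2`.
    _, result = jax.lax.scan(step, jnp.ones_like(A), xs=None, length=n_steps)
    return result[-1]  # same as `final_result = result[-1]`


print(power(jnp.arange(10.0), 2))  # elementwise A ** 2
```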

ricardoV94 added the JAX, Scan, and bug labels on Apr 22, 2022
@twiecki
Contributor

twiecki commented Apr 22, 2022

@junpenglao Any ideas? :)

@ricardoV94
Contributor Author

PS: The second failure seems to be a simple case of an argument (`as_while`) missing from an internal `ScanArgs` call, but I remember trying to add it and other things still failing. Hopefully it is still a simple fix!

ricardoV94 added the help wanted label on Apr 23, 2022
@ricardoV94
Contributor Author

ricardoV94 commented Jul 20, 2022

Related to #710. It's safe to say the JAX Scan implementation is simply broken at the moment.

@rlouf
Member

rlouf commented Sep 1, 2022

Still fails on aesara==2.7.9

Traceback (most recent call last):
  File "<stdin>", line 14, in <module>
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/__init__.py", line 317, in function
    fn = pfunc(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/pfunc.py", line 374, in pfunc
    return orig_function(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1763, in orig_function
    fn = m.create(defaults)
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/compile/function/types.py", line 1656, in create
    _fn, _i, _o = self.linker.make_thunk(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/basic.py", line 254, in make_thunk
    return self.make_all(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/basic.py", line 697, in make_all
    thunks, nodes, jit_fn = self.create_jitable_thunk(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/basic.py", line 646, in create_jitable_thunk
    converted_fgraph = self.fgraph_convert(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/jax/linker.py", line 13, in fgraph_convert
    return jax_funcify(fgraph, **kwargs)
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 670, in jax_funcify_FunctionGraph
    return fgraph_to_python(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/utils.py", line 741, in fgraph_to_python
    compiled_func = op_conversion_fn(
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/remil/.conda/envs/aesara-dev/lib/python3.10/site-packages/aesara/link/jax/dispatch.py", line 407, in jax_funcify_Scan
    inner_fg = FunctionGraph(op.inputs, op.outputs)
AttributeError: 'Scan' object has no attribute 'inputs'

rlouf self-assigned this on Sep 20, 2022
@rlouf
Member

rlouf commented Sep 20, 2022

This looks like the JAX dispatcher wasn't updated after a Scan refactor: `jax_funcify_Scan` still reads `op.inputs`, which `Scan` no longer has. I'll give it a shot using the current Numba dispatcher as a reference.
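
The `functools.py ... wrapper` frames in the traceback show that `jax_funcify` dispatches on the class of its argument via `functools.singledispatch`. The toy registry below (all class and function names are hypothetical, not Aesara's) is a sketch of why a stale converter fails while `aesara.function` is compiling rather than when the compiled function is called: the converter reads `Op` attributes during conversion, before anything is jitted.

```python
from functools import singledispatch

import jax.numpy as jnp


class Op:
    """Hypothetical stand-in for a graph Op (not Aesara's class)."""


class Mul(Op):
    pass


class Scale(Op):
    def __init__(self, factor):
        self.factor = factor


@singledispatch
def funcify(op, **kwargs):
    # Fallback when no converter is registered for this Op type.
    raise NotImplementedError(f"No JAX conversion for {type(op).__name__}")


@funcify.register(Mul)
def _(op, **kwargs):
    return lambda x, y: jnp.multiply(x, y)


@funcify.register(Scale)
def _(op, **kwargs):
    # Attributes are read at conversion time: if a refactor renamed
    # `factor`, this line would raise AttributeError here, before any
    # JAX tracing happens (like `'Scan' object has no attribute 'inputs'`).
    factor = op.factor
    return lambda x: x * factor


print(funcify(Mul())(2.0, 3.0))
print(funcify(Scale(2.0))(jnp.array(3.0)))
```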
