-
-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the JAX Scan
dispatcher
#1202
base: main
Are you sure you want to change the base?
Conversation
We can actually work around a lot of the dynamic indexing issues with Then I'll keep making my way through testing more and more scan features. |
e63a54a
to
6c61c95
Compare
6c61c95
to
4df1d5a
Compare
4df1d5a
to
d1326b8
Compare
After a lot of messing around I decided to go for a full rewrite and follow the Numba implementation. I have a minimal version that passes the first 3 That JAX easily complains about dynamic slicing may be a blessing in disguise as it highlights some gaps in Aesara's rewrites, e.g. with #1257 and others. Workarounds that I have currently had to implement could be easily avoided using the adequate rewrites at compile time. I also switched to run the test without rewrites, and I should probably start gathering a set of rewrites that would help with transpilation. How would we go about having backend-specific rewrites?
|
3cf33a6
to
6f5d668
Compare
Numba mode already specializes its rewrites, so check out its definition in |
4bd71bb
to
a1b7b5c
Compare
a1b7b5c
to
4bd71bb
Compare
4bd71bb
to
1212a1c
Compare
This is turning into a much bigger PR than expected as I am also trying to fix any issue that prevents me from running the
While I'm at it I am going to fix as many known issues with the JAX dispatcher as possible (issues and tests marked as |
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #1202 +/- ##
==========================================
+ Coverage 74.35% 74.49% +0.14%
==========================================
Files 177 173 -4
Lines 49046 48658 -388
Branches 10379 10390 +11
==========================================
- Hits 36468 36250 -218
+ Misses 10285 10112 -173
- Partials 2293 2296 +3
|
a761818
to
0183921
Compare
The following test with a def test_nit_sot_shared():
res, updates = scan(
fn=lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
0, 1, name="a"
),
n_steps=3,
)
jax_fn = function((), res, updates=updates, mode="JAX")
jax_res = jax_fn()
assert jax_res.shape == (3,) The values are correct, but the 65 gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
66 state_keys = ["key", "pos"]
67
68 for key in gen_keys:
69 if key not in data:
70 raise TypeError()
71
72 for key in state_keys:
73 if key not in data["state"]:
74 raise TypeError()
75
76 state_key = data["state"]["key"]
77 if state_key.shape == (624,) and state_key.dtype == np.uint32:
78 # TODO: Add an option to convert to a `RandomState` instance?
79 return data Indeed, the shared state for random variables in the JAX backend also contains a The Plus I don't think we need to carry this state around in the JAX backend, isn't only |
0183921
to
558ade1
Compare
973ec08
to
3de5fb3
Compare
2232d76
to
c7097dd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the commit entitled "Return a scalar when the tensor values is a scalar", is there an associated MWE/test case?
Also, the commit description mentions that ScalarFromTensor
is being called on scalars, and I want to make sure that those input scalars are TensorType
s scalars, and not ScalarType
scalars. The latter would imply that we're missing a rewrite for useless ScalarFromTensor
s.
2e7b3a8
to
fd37b21
Compare
import aesara
import aesara.tensor as at
a = at.iscalar("a")
x = at.arange(3)
out = x[:a]
aesara.dprint(out)
# Subtensor{:int32:} [id A]
# |ARange{dtype='int64'} [id B]
# | |TensorConstant{0} [id C]
# | |TensorConstant{3} [id D]
# | |TensorConstant{1} [id E]
# |ScalarFromTensor [id F]
# |a [id G]
try:
fn = aesara.function((a,), out, mode="JAX")
fn(1)
except Exception as e:
print(f"\n{e}")
# Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
# Apply node that caused the error: DeepCopyOp(Subtensor{:int32:}.0)
# Toposort index: 2
# Inputs types: [TensorType(int64, (None,))]
# Inputs shapes: [()]
# Inputs strides: [()]
# Inputs values: [array(1, dtype=int32)]
# Outputs clients: [['output']] In this case there are two solutions:
|
As I recall, the trouble with using that is that it's limited to only the (outermost) graph inputs, and we can't compose |
We can always ask users to JIT-compile functions themselves if that's the case, and raise a warning at compilation ("JAX will only be able to JIT-compile your function if you specifiy the {input_position}-th argument ({variable_name}) as static"). Given the number of issues with the JAX backend this work is uncovering, I decided to break the changes down in several smaller PRs and fix the issues unrelated to |
d718bdd
to
e528e44
Compare
The following code fails: import aesara
import aesara.tensor as at
a_at = at.dvector("a")
res, updates = aesara.scan(
fn=lambda a_t: 2 * a_t,
sequences=a_at
)
jax_fn = aesara.function((a_at,), res, updates=updates, mode="JAX")
jax_fn([0, 1, 2, 3, 4])
# IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
# Apply node that caused the error: Elemwise{mul,no_inplace}(TensorConstant{(1,) of 2.0}, Subtensor{int64:int64:int8}.0) I print the associated function graph: aesara.dprint(jax_fn)
# Elemwise{mul,no_inplace} [id A] 5
# |TensorConstant{(1,) of 2.0} [id B]
# |Subtensor{int64:int64:int8} [id C] 4
# |a [id D]
# |ScalarFromTensor [id E] 3
# | |Elemwise{Composite{Switch(LE(i0, i1), i1, i2)}}[(0, 0)] [id F] 2
# | |Shape_i{0} [id G] 0
# | | |a [id D]
# | |TensorConstant{0} [id H]
# | |TensorConstant{0} [id I]
# |ScalarFromTensor [id J] 1
# | |Shape_i{0} [id G] 0
# |ScalarConstant{1} [id K]
jax_fn.maker.fgraph.toposort()[4].tag
# scratchpad{'imported_by': ['local_subtensor_merge']}
jax_fn.maker.fgraph.toposort()[2].tag
# scratchpad{'imported_by': ['inplace_elemwise_optimizer']}
# jax_fn.maker.fgraph.toposort()[1].tag
scratchpad{'imported_by': ['local_subtensor_merge']} Several remarks:
|
The following code also fails, because of an import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
res, updates = aesara.scan(
fn=lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
outputs_info=[
{"initial": at.as_tensor(1.0, dtype="floatX"), "taps": [-1]},
{"initial": at.as_tensor(0.5, dtype="floatX"), "taps": [-1]},
],
n_steps=10,
)
jax_fn = function((), res, updates=updates, mode=jax_mode)
aesara.dprint(jax_fn)
# Subtensor{int64::} [id A] 17
# |for{cpu,scan_fn}.0 [id B] 16
# | |TensorConstant{10} [id C]
# | |IncSubtensor{Set;:int64:} [id D] 15
# | | |AllocEmpty{dtype='float64'} [id E] 14
# | | | |Elemwise{add,no_inplace} [id F] 13
# | | | |TensorConstant{10} [id C]
# | | | |Subtensor{int64} [id G] 11
# | | | |Shape [id H] 10
# | | | | |Unbroadcast{0} [id I] 9
# | | | | |InplaceDimShuffle{x} [id J] 8
# | | | | |TensorConstant{1.0} [id K]
# | | | |ScalarConstant{0} [id L]
# | | |Unbroadcast{0} [id I] 9
# | | |ScalarFromTensor [id M] 12
# | | |Subtensor{int64} [id G] 11
# | |IncSubtensor{Set;:int64:} [id N] 7
# | |AllocEmpty{dtype='float64'} [id O] 6
# | | |Elemwise{add,no_inplace} [id P] 5
# | | |TensorConstant{10} [id C]
# | | |Subtensor{int64} [id Q] 3
# | | |Shape [id R] 2
# | | | |Unbroadcast{0} [id S] 1
# | | | |InplaceDimShuffle{x} [id T] 0
# | | | |TensorConstant{0.5} [id U]
# | | |ScalarConstant{0} [id V]
# | |Unbroadcast{0} [id S] 1
# | |ScalarFromTensor [id W] 4
# | |Subtensor{int64} [id Q] 3
# |ScalarConstant{1} [id X]
# Subtensor{int64::} [id Y] 18
# |for{cpu,scan_fn}.1 [id B] 16
# |ScalarConstant{1} [id Z]
# Inner graphs:
# for{cpu,scan_fn}.0 [id B]
# >Elemwise{mul,no_inplace} [id BA]
# > |TensorConstant{2} [id BB]
# > |*0-<TensorType(float64, ())> [id BC] -> [id D]
# >Elemwise{mul,no_inplace} [id BD]
# > |TensorConstant{2} [id BE]
# > |*1-<TensorType(float64, ())> [id BF] -> [id N]
# for{cpu,scan_fn}.1 [id B]
# >Elemwise{mul,no_inplace} [id BA]
# >Elemwise{mul,no_inplace} [id BD] JAX indeed complains that the input to import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
res, updates = aesara.scan(
fn=lambda a_tm1: 2 * a_tm1,
outputs_info=[
{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}
],
n_steps=6,
)
jax_fn = function((), res, updates=updates, mode=jax_mode)
aesara.dprint(jax_fn)
# Subtensor{int64::} [id A] 8
# |for{cpu,scan_fn} [id B] 7
# | |TensorConstant{6} [id C]
# | |IncSubtensor{Set;:int64:} [id D] 6
# | |AllocEmpty{dtype='float64'} [id E] 5
# | | |Elemwise{add,no_inplace} [id F] 4
# | | |TensorConstant{6} [id C]
# | | |Subtensor{int64} [id G] 2
# | | |Shape [id H] 1
# | | | |Subtensor{:int64:} [id I] 0
# | | | |TensorConstant{[0. 1.]} [id J]
# | | | |ScalarConstant{2} [id K]
# | | |ScalarConstant{0} [id L]
# | |Subtensor{:int64:} [id I] 0
# | |ScalarFromTensor [id M] 3
# | |Subtensor{int64} [id G] 2
# |ScalarConstant{2} [id N]
# Inner graphs:
# for{cpu,scan_fn} [id B]
# >Elemwise{mul,no_inplace} [id O]
# > |TensorConstant{2} [id P]
# > |*0-<TensorType(float64, ())> [id Q] -> [id D]
# fn = function((), res, updates=updates)
# assert np.allclose(fn(), jax_fn()) |
e528e44
to
a088c4b
Compare
I am currently waiting for #1338 to be merged to see what else needs to be fixed in the backend to allow the tests to pass. |
5eafcd5
to
0932c8e
Compare
4912edc
to
b09a40e
Compare
This PR tries to address the issues observed in #710 and #924 with the transpilation of
Scan
operators. Most importantly, we increase the test coverage ofScan
's functionalities.