Fix the JAX `RandomVariable` dispatcher #1284
(force-pushed from 394853b to 61f7b50)
I have now registered the JAX implementation for every Aesara `RandomVariable`. I will now try to generalize the implementation and improve the tests / make them more exhaustive. The current implementation is halfway there, but not completely satisfactory: a user that has defined a custom `RandomVariable` […]
(force-pushed from 0eafddd to 1a25bd0)
(force-pushed from 8eb9292 to caaabf2)
(force-pushed from 21bb62f to fd7be3c)
The transpilation process of […] We should use this pattern throughout the JAX backend to handle the […]. There is one case I still need to handle: the following fails to compile, while the corresponding JAX code is perfectly valid:

```python
import aesara
import aesara.tensor as at

jax_mode = "JAX"  # assumed definition; the comment uses `jax_mode` without showing it

x_at = at.dmatrix()
f = at.random.normal(0, 1, size=(x_at.shape[0],))
g_fn = aesara.function([x_at], f, mode=jax_mode)
aesara.dprint(f.owner.inputs[1])
# MakeVector{dtype='int64'} [id A]
# |Subtensor{int64} [id B]
# |Shape [id C]
# | |<TensorType(float64, (?, ?))> [id D]
# |ScalarConstant{0} [id E]
```

This is due to the way `MakeVector` is transpiled:

```python
@jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op, **kwargs):
    def makevector(*x):
        return tuple(x)

    return makevector
```

I am not sure what […].
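To make the constraint concrete, here is a minimal standalone JAX sketch (my own illustration, not code from this PR): under `jit`, a shape built from concrete Python ints is fine, while one derived from a traced value is not.

```python
import jax
import jax.numpy as jnp


def concrete_shape(x):
    # `x.shape[0]` is a plain Python int even under `jit`,
    # so `(x.shape[0],)` is a valid static shape
    return jnp.zeros((x.shape[0],))


def traced_shape(n):
    # under `jit`, `n` is an abstract tracer; using it as a shape
    # raises an error because output shapes must be static
    return jnp.zeros((n,))


x = jnp.ones((2, 3))
assert jax.jit(concrete_shape)(x).shape == (2,)
# jax.jit(traced_shape)(2)  # fails: the shape is not concrete
```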
Finally, this is a subtlety, but while the following is perfectly valid Aesara code:

```python
import aesara
import aesara.tensor as at
x_at = at.dmatrix()
f = at.random.normal(0, 1, size=x_at.shape[0])
g_fn = aesara.function([x_at], f, mode=jax_mode)
```

its equivalent will fail to compile in JAX with […].

**Static arguments**

There's a third case I did not consider for now, since it involves deeper changes in the backend, but that we should eventually support:

```python
import aesara
import aesara.tensor as at
size = at.iscalar()
x_rv = at.random.normal(0, 1, size=size)
fn = aesara.function([size], x_rv)
```

In this case, JAX will happily JIT-compile the function if `size` is marked as a static argument.
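For reference, a minimal JAX-only sketch (my illustration, not the PR's code) of the static-argument mechanism: marking `size` as static lets `jit` specialize on its value.

```python
import jax


def sample(key, size):
    return jax.random.normal(key, shape=(size,))


key = jax.random.PRNGKey(0)

# jax.jit(sample)(key, 3)  # fails: `size` is traced, so the shape is dynamic
samples = jax.jit(sample, static_argnums=1)(key, 3)  # `size` is now concrete
assert samples.shape == (3,)
```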
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1284      +/-   ##
==========================================
+ Coverage   74.27%   74.35%   +0.07%
==========================================
  Files         175      177       +2
  Lines       48887    49036     +149
  Branches    10375    10379      +4
==========================================
+ Hits        36312    36461     +149
- Misses      10282    10283      +1
+ Partials     2293     2292      -1
```
Currently the following fail to compile to JAX but shouldn't. I reproduce the graph of the `size` argument in each case:

```python
import numpy as np

import aesara
import aesara.tensor as at
from aesara import shared

rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
aesara.dprint(out.owner.inputs[1])
# InplaceDimShuffle{x} [id A]
# |Subtensor{int64} [id B]
# |Shape [id C]
# | |<TensorType(float64, (?, ?))> [id D]
# |ScalarConstant{1} [id E]

x_at = at.dmatrix()
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
aesara.dprint(out.owner.inputs[1])
# MakeVector{dtype='int64'} [id A]
# |Subtensor{int64} [id B]
# |Shape [id C]
# | |<TensorType(float64, (?, ?))> [id D]
# |ScalarConstant{0} [id E]
```

These two cases can be handled with a rewrite that replaces `MakeVector` and `DimShuffle` when they appear in the `size` argument.
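On the JAX side, the replacement `Op`'s implementation only needs to repack the dimensions into a plain Python tuple, which JAX accepts as a static shape. A self-contained sketch (a hypothetical `ShapeTupleOp` of my own, not this PR's code):

```python
import numpy as np

import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class ShapeTupleOp(Op):
    """Hypothetical Op that packs scalar dimensions into a shape."""

    def make_node(self, *dims):
        dims = [at.as_tensor_variable(d) for d in dims]
        return Apply(self, dims, [at.lvector()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = np.asarray(inputs, dtype="int64")


@jax_funcify.register(ShapeTupleOp)
def jax_funcify_ShapeTupleOp(op, **kwargs):
    def shape_tuple(*dims):
        # A Python tuple, unlike a traced array, can be used as a
        # static shape by the JAX samplers
        return tuple(dims)

    return shape_tuple
```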
(force-pushed from ce0eaf8 to 88f71b9)
I added a rewrite that replaces `MakeVector` and `DimShuffle` in the `size` argument with a new `JAXShapeTuple` `Op`.

**How to make this work?**

The most pressing problem I have is that the linker does not seem to transpile the new `Op`:

```python
import numpy as np

import aesara.tensor as at
from aesara import function, shared

jax_mode = "JAX"  # assumed definition, as above

rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
jax_fn = function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2,)
```

The rewrite is applied, and the rewritten graph contains the new `Op`:

```python
# normal_rv{0, (0, 0), floatX, False}.1 [id A] 3
# |RandomStateSharedVariable(<RandomState(MT19937) at 0x7FDD0F5C7140>) [id B]
# |JAXShapeTuple [id C] 2
# | |Subtensor{int64} [id D] 1
# | |Shape [id E] 0
# | | |<TensorType(float64, (?, ?))> [id F]
# | |ScalarConstant{0} [id G]
# |TensorConstant{11} [id H]
# |TensorConstant{0} [id I]
# |TensorConstant{1} [id J] But it seems to be skipped during transpilation: def jax_funcified_fgraph(tensor_variable, random_state_shared_variable):
    # Shape(<TensorType(float64, (?, ?))>)
    tensor_variable_1 = shape(tensor_variable)
    # Subtensor{int64}(Shape.0, ScalarConstant{0})
    tensor_variable_2 = subtensor(tensor_variable_1, scalar_constant)
    # MakeVector{dtype='int64'}(Subtensor{int64}.0)
    tensor_variable_3 = makevector(tensor_variable_2)
    # normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7F3BE3FC7140>), JAXShapeTuple.0, TensorConstant{11}, TensorConstant{0}, TensorConstant{1})
    variable, tensor_variable_5 = sample_fn(random_state_shared_variable, tensor_variable_4, tensor_constant, tensor_constant_1, tensor_constant_2)
```

**How to better organize the code?**

The rewrite currently lives in […]
(force-pushed from 3ce5e00 to 07329ed)
(force-pushed from e459df7 to f592345)
I pushed some changes that should allow the new JAX-specific rewrite to run in "JAX" mode.

The import order/circularity was a challenge that was addressed by moving the rewrite into a sub-package for the relevant rewrites (i.e. `aesara.tensor.random` rewrites). In this case, the import question shifts toward the use of `JAXShapeTuple`, which is defined in a module that requires an import of `jax`, but that was answered with a "lazy" import inside the rewrite itself. This doesn't seem like the best approach, but it's not particularly bad; regardless, we should revisit this situation, because I can see us adding more backend-specific rewrite customizations like this, and we should have a good approach for that sooner rather than later.

For instance, we could define `JAXShapeTuple` somewhere else, but the question is "Where?". Anyway, we can still move forward as is.
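For concreteness, a sketch of the pattern described above (the module paths and the rewrite body are my assumptions, not the exact code):

```python
# e.g. in the aesara.tensor.random rewrites sub-package
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.basic import MakeVector
from aesara.tensor.random.op import RandomVariable


@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
    # "Lazy" import: the module defining JAXShapeTuple imports `jax`,
    # so deferring the import keeps `jax` optional for other backends.
    from aesara.link.jax.dispatch.shape import JAXShapeTuple  # path assumed

    rng, size, dtype, *dist_params = node.inputs
    if size.owner and isinstance(size.owner.op, MakeVector):
        new_size = JAXShapeTuple()(*size.owner.inputs)
        return node.op.make_node(rng, new_size, dtype, *dist_params).outputs

    return False
```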
The JAX dispatcher for `RandomVariable` `Op`s is currently broken, testing is lacking, and the shapes are not handled at all. This PR addresses:

- the `Shape` operator (two cases where the shape is a `ConcreteArray`), failing gracefully otherwise;
- `StudentTRV`, whose `size` parameter is the (dynamic) shape of a tensor;
- the `RandomVariable` definition of `GumbelRV`;
- the `StandardNormalRV` `RandomVariable` implementation;
- `MakeVector` and `DimShuffle` when passed as the `size` parameter.