
Fix the JAX RandomVariable dispatcher #1284

Merged (8 commits, Dec 9, 2022)

Conversation

@rlouf rlouf commented Nov 3, 2022

The JAX dispatcher for RandomVariable Ops is currently broken, testing is lacking, and shapes are not handled at all.

  • Refactor the dispatcher to handle the API differences in JAX;
  • Handle the shape when specified as a constant or as the output of a Shape operator (the two cases where the shape is a ConcreteArray); fail gracefully otherwise;
  • Register an implementation for StudentTRV;
  • Add a comprehensive test suite:
    • Every implementation needs to be tested;
    • Add a test where the size parameter is the (dynamic) shape of a tensor;
    • Add a test with an implementation for a custom RandomVariable definition;
  • Register an implementation for GumbelRV;
  • Register an implementation for StandardNormalRV;
  • Open an issue for each missing RandomVariable implementation;
  • Rewrite MakeVector and DimShuffle when passed as the size parameter.

@rlouf rlouf added the enhancement, testing, JAX, and random variables labels on Nov 3, 2022
@rlouf rlouf force-pushed the rewrite-jax-rv branch 5 times, most recently from 394853b to 61f7b50 on December 2, 2022 15:03

rlouf commented Dec 2, 2022

I have now individually registered a JAX implementation for every Aesara RandomVariable that has a counterpart in the JAX library, and added a minimal test for each.

I will now try to generalize the implementation and make the tests more exhaustive.

The current implementation is halfway there but not completely satisfactory: a user who has defined a custom RandomVariable would need to register jax_funcify_RandomVariable as well as make_sample_fn. Instead, I should register jax_funcify_RandomVariable for all RandomVariable Ops and fail in the generic make_sample_fn if the name of the distribution cannot be found in the jax.random namespace. One would then only need to import make_sample_fn and register their implementation in JAX. A sketch of that design is shown below.
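
A minimal sketch of what that could look like. This is illustrative, not the merged implementation: the RNG bookkeeping through a `"jax_state"` entry is simplified, and the parameter-passing conventions of `jax.random` functions are not uniform across distributions.

```python
import functools

import jax

from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.random.op import RandomVariable


@jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op, **kwargs):
    # One entry point for every RandomVariable Op; the sampler itself is
    # looked up separately so users only need to override that part.
    return make_sample_fn(op, **kwargs)


@functools.singledispatch
def make_sample_fn(op, **kwargs):
    # Generic fallback: look the distribution up by name in jax.random
    # and fail gracefully when it is missing.
    jax_sampler = getattr(jax.random, op.name, None)
    if jax_sampler is None:
        raise NotImplementedError(
            f"No JAX sampler is available for '{op.name}'; "
            "register a `make_sample_fn` implementation for this Op."
        )

    def sample_fn(rng, size, dtype, *parameters):
        # Split the PRNG key so consecutive draws are independent
        rng_key, sampling_key = jax.random.split(rng["jax_state"], 2)
        sample = jax_sampler(sampling_key, *parameters, shape=size)
        rng["jax_state"] = rng_key
        return (rng, sample)

    return sample_fn
```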

@rlouf rlouf force-pushed the rewrite-jax-rv branch 3 times, most recently from 0eafddd to 1a25bd0 on December 2, 2022 16:06
@rlouf rlouf marked this pull request as ready for review December 2, 2022 21:08
@rlouf rlouf marked this pull request as draft December 3, 2022 08:02
@rlouf rlouf closed this Dec 3, 2022
@rlouf rlouf reopened this Dec 3, 2022
@rlouf rlouf force-pushed the rewrite-jax-rv branch 4 times, most recently from 8eb9292 to caaabf2 on December 4, 2022 19:38
@rlouf rlouf marked this pull request as ready for review December 4, 2022 19:38
@rlouf rlouf force-pushed the rewrite-jax-rv branch 2 times, most recently from 21bb62f to fd7be3c on December 5, 2022 09:47

rlouf commented Dec 5, 2022

The transpilation process of RandomVariables now goes through two functions:

  • The first, registered for the generic RandomVariable, handles the size parameter: if it is known at compile time to be a constant, use that constant; if it is the output of a Shape operator, use the output of the previous node. These are the two situations in which JAX will not complain, so we fail gracefully otherwise.
  • The second, jax_sample_fn, registers the implementation of the sampling function for each RandomVariable subtype individually. This way users can register an implementation for the RandomVariables they define without having to worry about JAX's quirks; a sketch of that workflow follows the next paragraph.

We should use this pattern throughout the JAX backend to handle shape arguments, which are one big pain point with this backend. Along with powerful shape inference on our side, it should be enough to handle most problematic cases gracefully.
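
For example, a user who defines a custom RandomVariable could then register a sampler along these lines. This is a hedged sketch: SignedExponentialRV is a hypothetical Op invented for illustration, and the exact sample_fn signature and RNG handling may differ from the final implementation.

```python
import jax

from aesara.link.jax.dispatch.random import jax_sample_fn
from aesara.tensor.random.op import RandomVariable


class SignedExponentialRV(RandomVariable):
    """A hypothetical user-defined RandomVariable."""

    name = "signed_exponential"
    ndim_supp = 0
    ndims_params = [0]
    dtype = "floatX"


@jax_sample_fn.register(SignedExponentialRV)
def jax_sample_fn_signed_exponential(op):
    def sample_fn(rng, size, dtype, *parameters):
        # Thread the JAX PRNG key through explicitly
        rng_key, sampling_key = jax.random.split(rng["jax_state"], 2)
        (scale,) = parameters
        # Draw from a negated exponential distribution
        sample = -scale * jax.random.exponential(sampling_key, shape=size, dtype=dtype)
        rng["jax_state"] = rng_key
        return (rng, sample)

    return sample_fn
```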

There is one case I still need to handle. The following fails to compile, while the corresponding JAX code is perfectly valid:

```python
import aesara
import aesara.tensor as at

jax_mode = "JAX"  # Aesara's predefined JAX compilation mode

x_at = at.dmatrix()
f = at.random.normal(0, 1, size=(x_at.shape[0],))
g_fn = aesara.function([x_at], f, mode=jax_mode)

aesara.dprint(f.owner.inputs[1])
# MakeVector{dtype='int64'} [id A]
# |Subtensor{int64} [id B]
#   |Shape [id C]
#   | |<TensorType(float64, (?, ?))> [id D]
#   |ScalarConstant{0} [id E]
```

This is due to the way MakeVector is implemented: calling jnp.array(x) on its traced inputs creates a TracedArray, which cannot be used as a shape. This example works if I disable the check in funcify_RandomVariable and change the MakeVector implementation to:

```python
@jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op, **kwargs):
    def makevector(*x):
        return tuple(x)

    return makevector
```

I am not sure what MakeVector is used for besides this case, but we may need to use a different Op here on the JAX side to differentiate between cases. I will also need to update my primitive isinstance(node.inputs[1], Shape) test to take this case into account; this type of check will become routine in the backend.

Finally, this is a subtlety, but while the following is perfectly valid Aesara code:

```python
import aesara
import aesara.tensor as at

x_at = at.dmatrix()
f = at.random.normal(0, 1, size=x_at.shape[0])
g_fn = aesara.function([x_at], f, mode=jax_mode)
```

its equivalent will fail to compile in JAX with `TypeError: jax.core.NamedShape() argument after * must be an iterable, not int`.
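
A JAX-only reproduction of that failure mode, with the JAX version in use at the time (newer releases may word the error differently):

```python
import jax

key = jax.random.PRNGKey(0)

jax.random.normal(key, shape=(3,))  # fine: the shape is a tuple
jax.random.normal(key, shape=3)     # TypeError: the shape must be iterable
```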

Static arguments

There's a third case I did not consider for now, since it involves deeper changes in the backend, but that we should eventually support:

```python
import aesara
import aesara.tensor as at

size = at.iscalar()
x_rv = at.random.normal(0, 1, size=size)

fn = aesara.function([size], x_rv)
```

In this case, JAX will happily JIT-compile the function if size is in the static_argnums parameters. We would need a simple system that tracks the parameters that should be passed as static to jax.jit, and uses this information at the end of the transpilation process, when the function is JIT-compiled. Alternatively, we can defer JIT compilation to the user and let them mark this parameter as static should they want to. The choice here matters, because JAX will recompile the function each time a new value is passed for a parameter marked as static. The JAX-level behavior is illustrated below.
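
For reference, here is the JAX-level behavior being described, independent of Aesara:

```python
from functools import partial

import jax


# Marking `size` as static lets JAX use it to build shapes inside the
# jitted function; the trade-off is one recompilation per distinct value.
@partial(jax.jit, static_argnums=(1,))
def sample(key, size):
    return jax.random.normal(key, shape=(size,))


key = jax.random.PRNGKey(0)
print(sample(key, 3).shape)  # (3,)
print(sample(key, 5).shape)  # (5,) -- triggers a recompilation
```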

codecov bot commented Dec 5, 2022

Codecov Report

Merging #1284 (e4090a9) into main (ae182f0) will increase coverage by 0.07%.
The diff coverage is 99.38%.

```
@@            Coverage Diff             @@
##             main    #1284      +/-   ##
==========================================
+ Coverage   74.27%   74.35%   +0.07%     
==========================================
  Files         175      177       +2     
  Lines       48887    49036     +149     
  Branches    10375    10379       +4     
==========================================
+ Hits        36312    36461     +149     
- Misses      10282    10283       +1     
+ Partials     2293     2292       -1     
```
| Impacted Files | Coverage Δ |
| --- | --- |
| aesara/compile/mode.py | 84.47% <ø> (ø) |
| aesara/tensor/random/rewriting/basic.py | 94.16% <ø> (ø) |
| aesara/link/jax/dispatch/shape.py | 88.46% <93.33%> (+1.97%) ⬆️ |
| aesara/link/jax/dispatch/random.py | 100.00% <100.00%> (+2.63%) ⬆️ |
| aesara/tensor/random/rewriting/__init__.py | 100.00% <100.00%> (ø) |
| aesara/tensor/random/rewriting/jax.py | 100.00% <100.00%> (ø) |

rlouf commented Dec 5, 2022

Currently, the following cases fail to compile to JAX but shouldn't. I reproduce the graph of the size parameter in each case:

  • The size argument is an InplaceDimShuffle, because Aesara converts scalar size arguments to vectors:

```python
import numpy as np

import aesara
import aesara.tensor as at
from aesara import shared

rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)

aesara.dprint(out.owner.inputs[1])
# InplaceDimShuffle{x} [id A]
# |Subtensor{int64} [id B]
#   |Shape [id C]
#   | |<TensorType(float64, (?, ?))> [id D]
#   |ScalarConstant{1} [id E]
```

  • The size argument is a MakeVector, because the user passes the shape as a tuple:

```python
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)

aesara.dprint(out.owner.inputs[1])
# MakeVector{dtype='int64'} [id A]
# |Subtensor{int64} [id B]
#   |Shape [id C]
#   | |<TensorType(float64, (?, ?))> [id D]
#   |ScalarConstant{0} [id E]
```

These two cases can be handled with a rewrite that replaces MakeVector and DimShuffle((), ('x',)), when found as the size argument of a RandomVariable, with a new ShapeTuple Op that returns its inputs as a tuple. A sketch of such a rewrite follows.
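
A hedged sketch of what such a rewrite could look like, using Aesara's node_rewriter API. JAXShapeTuple stands for the dummy Op described above (that is the name used later in this PR), assumed here to live in aesara.link.jax.dispatch.shape per the coverage report; the matching conditions and registration mechanics may differ from the final implementation.

```python
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.basic import MakeVector
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.op import RandomVariable


@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
    """Replace the vector-building Op in a RandomVariable's `size` input
    with a dummy Op whose JAX transpilation returns a plain tuple."""
    # Imported lazily so this rewrite module does not pull in `jax`
    # at import time (see the discussion further down this thread)
    from aesara.link.jax.dispatch.shape import JAXShapeTuple

    size = node.inputs[1]
    size_node = size.owner
    if size_node is None:
        return None

    # Match MakeVector, or a DimShuffle that lifts a scalar to a 1d vector
    if isinstance(size_node.op, MakeVector) or (
        isinstance(size_node.op, DimShuffle)
        and size_node.op.input_broadcastable == ()
        and size_node.op.new_order == ("x",)
    ):
        new_size = JAXShapeTuple()(*size_node.inputs)
        new_inputs = list(node.inputs)
        new_inputs[1] = new_size
        return node.clone_with_new_inputs(new_inputs).outputs
```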

@rlouf rlouf force-pushed the rewrite-jax-rv branch 2 times, most recently from ce0eaf8 to 88f71b9 on December 6, 2022 13:44

rlouf commented Dec 6, 2022

I added a rewrite that replaces MakeVector and DimShuffle (when used to convert a scalar into a 1d vector) with a dummy JAXShapeTuple Op when found as a size input to a RandomVariable. This is my first time writing rewrites, and I need help both to make this work and to better organize the code.

How to make this work?

The most pressing problem is that the linker does not seem to transpile the new JAXShapeTuple Op, even though I registered it. Take the following example:

```python
import numpy as np

import aesara.tensor as at
from aesara import function, shared

jax_mode = "JAX"  # Aesara's predefined JAX compilation mode

rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
jax_fn = function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2,)
```

The FunctionGraph instance passed to jax_funcify in aesara.link.jax.linker does contain the JAXShapeTuple Op instead of the original MakeVector Op:

```
# normal_rv{0, (0, 0), floatX, False}.1 [id A] 3
#  |RandomStateSharedVariable(<RandomState(MT19937) at 0x7FDD0F5C7140>) [id B]
#  |JAXShapeTuple [id C] 2
#  | |Subtensor{int64} [id D] 1
#  |   |Shape [id E] 0
#  |   | |<TensorType(float64, (?, ?))> [id F]
#  |   |ScalarConstant{0} [id G]
#  |TensorConstant{11} [id H]
#  |TensorConstant{0} [id I]
#  |TensorConstant{1} [id J]
```

But it seems to be skipped during transpilation:

```python
def jax_funcified_fgraph(tensor_variable, random_state_shared_variable):
    # Shape(<TensorType(float64, (?, ?))>)
    tensor_variable_1 = shape(tensor_variable)
    # Subtensor{int64}(Shape.0, ScalarConstant{0})
    tensor_variable_2 = subtensor(tensor_variable_1, scalar_constant)
    # MakeVector{dtype='int64'}(Subtensor{int64}.0)
    tensor_variable_3 = makevector(tensor_variable_2)
    # normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7F3BE3FC7140>), JAXShapeTuple.0, TensorConstant{11}, TensorConstant{0}, TensorConstant{1})
    variable, tensor_variable_5 = sample_fn(random_state_shared_variable, tensor_variable_4, tensor_constant, tensor_constant_1, tensor_constant_2)
```

How to better organize the code?

The rewrite currently lives in aesara.link.jax.linker as a standalone rewrite, but this is obviously not ideal. What is the recommended way to store and apply rewrites like this one?

@brandonwillard brandonwillard left a comment

I pushed some changes that should allow the new JAX-specific rewrite to run in "JAX" mode.

The import order/circularity was a challenge, which was addressed by moving the rewrite into a sub-package for the relevant rewrites (i.e. aesara.tensor.random.rewriting). In this case, the import question shifts toward the use of JAXShapeTuple, which is defined in a module that requires an import of jax; that was answered with a "lazy" import inside the rewrite itself. This doesn't seem like the best approach, but it's not particularly bad either. Regardless, we should revisit this situation, because I can see us adding more backend-specific rewrite customizations like this, and we should have a good approach for that sooner rather than later.

For instance, we could define JAXShapeTuple somewhere else, but the question is "Where?". Anyway, we can still move forward as is.
