
Fix the JAX Subtensor and IncSubtensor dispatcher #1338

Merged — 14 commits merged into aesara-devs:main from jax-fix-subtensor on Dec 14, 2022

Conversation

rlouf (Member) commented Dec 7, 2022

In this PR I fix the outstanding issues with the Subtensor and IncSubtensor dispatchers in the JAX backend, and a few other things that came up in #1202.

  • Boolean mask arrays are not supported in JAX, except for logic that can be re-expressed in a JIT-compatible way, which we should support via JAX-specific rewrites.
  • Constant indexing should be hard-coded in the transpiled JAX function.
  • Dynamic indexing is allowed but should use jax.lax.dynamic_slice (see the sketch below).
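
To illustrate the constraint, a minimal sketch (the exact exception type varies across JAX versions):

import jax
import jax.numpy as jnp

x = jnp.arange(10)

@jax.jit
def boolean_mask(x):
    # The output shape depends on the values in x, so this cannot be traced
    return x[x > 3]

@jax.jit
def dynamic(x, start):
    # `start` may be a traced value, but the slice length must stay static
    return jax.lax.dynamic_slice(x, (start,), (3,))

try:
    boolean_mask(x)
except Exception as e:
    print(type(e).__name__)  # NonConcreteBooleanIndexError

print(dynamic(x, 2))
# [2 3 4]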

Progress

  • Raise when using boolean mask arrays
  • Rewrites for boolean logic that can be expressed in a JIT-compilable way (waiting for Fix the JAX RandomVariable dispatcher #1284)
  • Hardcode constant start, stop, step in ARange, raise if dynamic (see the arange sketch after this list)
  • Raise in Subtensor if slicing with dynamic length
  • Dispatch Elemwise operations on scalar values to Python operators
  • Fix IncSubtensor
  • Use jax.numpy.copy directly
  • Implement TensorFromScalar as a pass-through
    This one requires more thought on scalars in JAX vs scalars in Aesara.
  • Refactor the Reshape implementation to make sure the shape parameter is passed concrete values (waiting for Fix the JAX RandomVariable dispatcher #1284)
  • Allow combining concrete values with Python operators in shape and size arguments (we thought we might need a rewrite and custom Ops like JAXPythonAdd to make this easier, but it turns out they already are combined this way)
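
As context for the ARange item above, a minimal demonstration of why start/stop/step must be concrete under jax.jit (error wording varies by JAX version):

import jax
import jax.numpy as jnp

@jax.jit
def dynamic_arange(n):
    # The output length depends on a traced value, so JAX raises at trace time
    return jnp.arange(n)

try:
    dynamic_arange(5)
except Exception as e:
    print(type(e).__name__)  # ConcretizationTypeError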

This is a spinoff of #1202.

@rlouf added the labels enhancement (New feature or request) and JAX (Involves JAX transpilation) on Dec 7, 2022
@rlouf force-pushed the jax-fix-subtensor branch 3 times, most recently from b53c174 to 7fb0948 on December 7, 2022 14:39
codecov bot commented Dec 7, 2022

Codecov Report

Merging #1338 (78d2dbf) into main (bfcfe4b) will increase coverage by 0.29%.
The diff coverage is 91.82%.


@@            Coverage Diff             @@
##             main    #1338      +/-   ##
==========================================
+ Coverage   74.36%   74.66%   +0.29%     
==========================================
  Files         177      177              
  Lines       49066    49050      -16     
  Branches    10379    10400      +21     
==========================================
+ Hits        36488    36623     +135     
+ Misses      10285    10131     -154     
- Partials     2293     2296       +3     
Impacted Files Coverage Δ
aesara/link/jax/dispatch/elemwise.py 80.59% <50.00%> (ø)
aesara/link/jax/dispatch/shape.py 94.82% <75.00%> (+6.36%) ⬆️
aesara/tensor/rewriting/jax.py 86.44% <86.44%> (ø)
aesara/link/jax/dispatch/scalar.py 96.72% <95.74%> (-0.69%) ⬇️
aesara/link/jax/dispatch/basic.py 92.59% <100.00%> (+8.72%) ⬆️
aesara/link/jax/dispatch/subtensor.py 100.00% <100.00%> (+32.07%) ⬆️
aesara/link/jax/dispatch/tensor_basic.py 97.22% <100.00%> (+5.15%) ⬆️
aesara/tensor/basic.py 89.88% <0.00%> (-0.07%) ⬇️
tests/link/jax/test_subtensor.py
... and 2 more

@rlouf force-pushed the jax-fix-subtensor branch 4 times, most recently from 721e087 to 8814219 on December 8, 2022 08:31
rlouf (Member, Author) commented Dec 8, 2022

Here is another issue I would like to fix in this PR. When operating on scalars, jax.lax.X returns a traced array while the corresponding Python operator does not:

import jax
import jax.numpy as jnp

@jax.jit
def fn():
    a = 3 + 2
    b = jnp.ones(3) + jnp.ones(3)
    print(a)
    print(b)
    return a

@jax.jit
def fn_lax_add():
    a = jax.lax.add(3, 2)
    print(a)
    return a

fn()
# 5
# Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>

fn_lax_add()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

import jax

@jax.jit
def fn():
    a = 3 * 2
    print(a)
    return a

@jax.jit
def fn_lax_mul():
    a = jax.lax.mul(3, 2)
    print(a)
    return a

fn()
# 6

fn_lax_mul()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

import jax

@jax.jit
def fn():
    a = 3 / 2
    print(a)
    return a

@jax.jit
def fn_lax_div():
    a = jax.lax.div(3, 2)
    print(a)
    return a

fn()
# 1.5

fn_lax_div()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

Since this can cause problems down the line when the result of the operation is passed as a size or shape parameter, we should handle these cases explicitly in the backend. The problem was originally observed in #1202.
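
A minimal sketch of the workaround with a hypothetical helper (not the PR's actual code), assuming we can detect plain scalars at dispatch time:

import operator

import jax.numpy as jnp

def scalar_aware_add(x, y):
    # Hypothetical helper: combine plain Python scalars with the Python
    # operator so the result stays a concrete number under jit; defer to
    # jax.numpy for everything else
    if isinstance(x, (int, float)) and isinstance(y, (int, float)):
        return operator.add(x, y)
    return jnp.add(x, y)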

rlouf (Member, Author) commented Dec 8, 2022

The problem is illustrated by the following MWE:

import aesara
import aesara.tensor as at
import numpy as np

x = at.matrix('x')
shape = x.shape[0] + x.shape[1]
out = at.ones(shape)

fn = aesara.function((x,), out, mode="JAX")
try:
    fn(np.ones((2,3)))
except Exception as e:
    print(e)
# Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).
# If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
# Apply node that caused the error: Alloc(TensorConstant{1.0}, Elemwise{Add}[(0, 0)].0)
# Toposort index: 3
# Inputs types: [TensorType(float64, ()), TensorType(int64, ())]
# Inputs shapes: [(2, 3)]
# Inputs strides: [(24, 8)]
# Inputs values: ['not shown']
# Outputs clients: [['output']]

The equivalent JAX implementation, meanwhile, works fine:

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def fn(x):
    shape = x.shape[0] + x.shape[1]
    return jnp.ones(shape)

print(fn(np.ones((2,3))))
# [1. 1. 1. 1. 1.]
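
The difference is that, under jit, x.shape is a tuple of plain Python ints, so the addition happens at trace time and jnp.ones receives a concrete value. A minimal check:

import jax
import numpy as np

@jax.jit
def show_shape_type(x):
    # Inside jit, static shapes are ordinary Python ints
    print(type(x.shape[0]))  # <class 'int'>
    return x.shape[0] + x.shape[1]

show_shape_type(np.ones((2, 3)))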

@brandonwillard previously approved these changes on Dec 8, 2022
rlouf (Member, Author) commented Dec 12, 2022

I have now fixed all the issues I identified in #1202 that were not related to the Scan implementation.

There is one last error that I do not understand in test_jax_Reshape_concrete_shape. The compiled JAX function returns the expected result, but Aesara fails with the following error message:

self = <aesara.tensor.shape.Reshape object at 0x7fbb71e8e9e0>
node = Reshape{2}(a, JAXShapeTuple.0), inp = [array([1., 2., 3., 4.]), None]
out_ = [[None]], params = Params(ndim:int32:2)

    def perform(self, node, inp, out_, params):
        x, shp = inp
        (out,) = out_
>       if len(shp) != self.ndim:
E       TypeError: object of type 'NoneType' has no len()

Do you see what might be the issue @brandonwillard ?

brandonwillard (Member) commented

> There is one last error that I do not understand in test_jax_Reshape_concrete_shape. [...] Do you see what might be the issue @brandonwillard?

Yeah, there shouldn't be a JAXShapeTuple in a graph being evaluated in Python. Looks like we need to use separate optimization queries for each testing mode.
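
A hypothetical sketch of that suggestion (module paths and the "jax" tag are assumptions about how the JAX rewrites are registered):

from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker

# Give each testing mode its own rewrite query so JAX-only rewrites
# (like the one that introduces JAXShapeTuple) never run for graphs
# evaluated with the Python linker
jax_mode = Mode(JAXLinker(), RewriteDatabaseQuery(include=["jax"]))
py_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=["jax"]))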

@rlouf force-pushed the jax-fix-subtensor branch 4 times, most recently from 06382af to 14bb49c on December 14, 2022 13:16
@rlouf force-pushed the jax-fix-subtensor branch 2 times, most recently from bd697dc to 78d2dbf on December 14, 2022 14:15
rlouf (Member, Author) commented Dec 14, 2022

Ready for review.

@rlouf merged commit 5b165ea into aesara-devs:main on Dec 14, 2022 and deleted the jax-fix-subtensor branch on December 14, 2022 16:14