Fix the JAX `Subtensor` and `IncSubtensor` dispatchers (#1338)
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1338      +/-   ##
==========================================
+ Coverage   74.36%   74.66%   +0.29%
==========================================
  Files         177      177
  Lines       49066    49050      -16
  Branches    10379    10400      +21
==========================================
+ Hits        36488    36623     +135
+ Misses      10285    10131     -154
- Partials     2293     2296       +3
```
Here is another issue I would like to fix in this PR. When operating on scalars, Python operators are evaluated to concrete values inside `jit`, while the corresponding `jax.lax` ops return traced values:

```python
import jax
import jax.numpy as jnp

@jax.jit
def fn():
    a = 3 + 2
    b = jnp.ones(3) + jnp.ones(3)
    print(a)
    print(b)
    return a

@jax.jit
def fn_lax_add():
    a = jax.lax.add(3, 2)
    print(a)
    return a

fn()
# 5
# Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>
fn_lax_add()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
```

The same happens with multiplication:

```python
import jax

@jax.jit
def fn():
    a = 3 * 2
    print(a)
    return a

@jax.jit
def fn_lax_mul():
    a = jax.lax.mul(3, 2)
    print(a)
    return a

fn()
# 6
fn_lax_mul()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
```

And with division:

```python
import jax

@jax.jit
def fn():
    a = 3 / 2
    print(a)
    return a

@jax.jit
def fn_lax_div():
    a = jax.lax.div(3, 2)
    print(a)
    return a

fn()
# 1.5
fn_lax_div()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
```

This can cause problems down the line when the result of the operation is passed as a `shape` argument.
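One way to keep such scalar results concrete is to fall back to Python operators whenever all inputs are plain Python scalars. The following is only a minimal sketch of that idea, not the dispatcher's actual implementation; the names `SCALAR_OP_MAP` and `apply_scalar_op` are invented for this example:

```python
import operator

# Hypothetical mapping from scalar op names to Python operators. When both
# inputs are plain Python scalars, applying the Python operator keeps the
# result a concrete Python value instead of a traced JAX value.
SCALAR_OP_MAP = {
    "add": operator.add,
    "mul": operator.mul,
    "true_div": operator.truediv,
}

def apply_scalar_op(name, x, y):
    """Apply a scalar op via its Python operator when inputs are concrete."""
    if isinstance(x, (int, float)) and isinstance(y, (int, float)):
        return SCALAR_OP_MAP[name](x, y)
    # Non-concrete inputs would fall back to the corresponding JAX op.
    raise TypeError("non-concrete inputs; fall back to the JAX op")

print(apply_scalar_op("add", 3, 2))       # 5, stays a Python int
print(apply_scalar_op("true_div", 3, 2))  # 1.5
```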
The problem is illustrated by the following MWE:

```python
import aesara
import aesara.tensor as at
import numpy as np

x = at.matrix('x')
shape = x.shape[0] + x.shape[1]
out = at.ones(shape)
fn = aesara.function((x,), out, mode="JAX")

try:
    fn(np.ones((2, 3)))
except Exception as e:
    print(e)
# Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).
# If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
# Apply node that caused the error: Alloc(TensorConstant{1.0}, Elemwise{Add}[(0, 0)].0)
# Toposort index: 3
# Inputs types: [TensorType(float64, ()), TensorType(int64, ())]
# Inputs shapes: [(2, 3)]
# Inputs strides: [(24, 8)]
# Inputs values: ['not shown']
# Outputs clients: [['output']]
```

While the equivalent JAX implementation works:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def fn(x):
    shape = x.shape[0] + x.shape[1]
    return jnp.ones(shape)

print(fn(np.ones((2, 3))))
# [1. 1. 1. 1. 1.]
```
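The JAX version works because `x.shape` inside `jit` is a tuple of plain Python ints (static, trace-time information), so `x.shape[0] + x.shape[1]` is ordinary Python arithmetic and the result stays concrete. A NumPy-only sketch of that distinction, assuming nothing about Aesara's internals:

```python
import numpy as np

x = np.ones((2, 3))

# The shape of a concrete array is a tuple of plain Python ints, so
# arithmetic on its entries happens in Python and yields a concrete int...
shape = x.shape[0] + x.shape[1]
print(type(shape).__name__, shape)  # int 5

# ...which functions like np.ones / jnp.ones accept as a concrete size.
print(np.ones(shape))  # [1. 1. 1. 1. 1.]
```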
I have now fixed all the issues I identified in #1202 that were not related to the `RandomVariable` dispatcher. There is one last error that I do not understand in the `Reshape` tests:

```
self = <aesara.tensor.shape.Reshape object at 0x7fbb71e8e9e0>
node = Reshape{2}(a, JAXShapeTuple.0), inp = [array([1., 2., 3., 4.]), None]
out_ = [[None]], params = Params(ndim:int32:2)

    def perform(self, node, inp, out_, params):
        x, shp = inp
        (out,) = out_
>       if len(shp) != self.ndim:
E       TypeError: object of type 'NoneType' has no len()
```

Do you see what might be the issue @brandonwillard?
Yeah, there shouldn't be a `None` there.
Ready for review.
In this PR I fix the outstanding issues with the `Subtensor` and `IncSubtensor` dispatchers in the JAX backend, and a few other things that came up in #1202.

Progress

- `RandomVariable` dispatcher (#1284)
- Use `jax.lax.dynamic_slice` in `Subtensor`; raise if slicing with dynamic length
- Raise if `start`, `stop`, `step` in `ARange` are dynamic
- Map `Elemwise` operations on scalar values to Python operators
- `IncSubtensor`: use `jax.numpy.copy` directly
- `TensorFromScalar` as a pass-through. This one requires more thought on scalars in JAX vs scalars in Aesara.
- Fix the `Reshape` implementation to make sure the `shape` parameter is passed concrete values (waiting for "Fix the JAX `RandomVariable` dispatcher" #1284)
- ~~Allow combination of concrete values using Python operators as `shape` and `size` arguments (we may need a rewrite and custom `Op`s like `JAXPythonAdd` to make this easier).~~ They already are.

This is a spinoff of #1202.
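As a rough illustration of the "raise if dynamic" items above, here is a hypothetical helper that rejects non-concrete slice components at dispatch time. The name `assert_static_slice` is invented for this sketch; it only captures the idea that JAX can lower a slice only when its bounds are known at trace time:

```python
def assert_static_slice(*components):
    """Raise if any slice component is not a concrete Python integer.

    Hypothetical helper mirroring the 'raise if dynamic' behaviour:
    None is allowed (an open bound, as in x[:3]), plain ints are static,
    and anything else is treated as a dynamic (traced) value.
    """
    for c in components:
        if c is not None and not isinstance(c, int):
            raise NotImplementedError(
                "dynamic slice bounds are not supported in the JAX backend"
            )

assert_static_slice(0, 3, 1)      # fine: all bounds static
assert_static_slice(None, 3, None)  # fine: open bounds allowed

class Traced:  # stand-in for a traced JAX value
    pass

try:
    assert_static_slice(Traced(), 3, 1)
except NotImplementedError as e:
    print(e)  # dynamic slice bounds are not supported in the JAX backend
```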