Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the JAX Scan dispatcher #1202

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
699 changes: 568 additions & 131 deletions aesara/link/jax/dispatch/scan.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion aesara/link/jax/dispatch/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def shape(x):


@jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op, **kwargs):
def jax_funcify_Shape_i(op, node, **kwargs):
i = op.i

def shape_i(x):
Expand Down
2 changes: 2 additions & 0 deletions aesara/link/jax/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def tensor_from_scalar(x):
@jax_funcify.register(ScalarFromTensor)
def jax_funcify_ScalarFromTensor(op, **kwargs):
def scalar_from_tensor(x):
if isinstance(x, (float, int)):
return x
return jnp.array(x).flatten()[0]

return scalar_from_tensor
6 changes: 5 additions & 1 deletion aesara/link/jax/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
)

return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
fgraph,
input_storage=input_storage,
storage_map=storage_map,
global_fgraph=fgraph,
**kwargs,
)

def jit_compile(self, fn):
Expand Down
39 changes: 17 additions & 22 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def set_aesara_flags():

def compare_jax_and_py(
fgraph: FunctionGraph,
test_inputs: Iterable,
inputs: Iterable,
assert_fn: Optional[Callable] = None,
must_be_device_array: bool = True,
jax_mode=jax_mode,
updates=None,
):
"""Function to compare python graph output and jax compiled output for testing equality

Expand All @@ -56,34 +57,28 @@ def compare_jax_and_py(
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
must_be_device_array: Bool
Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes
if this device array is found it indicates if the result was computed by jax

Returns
-------
jax_res
updates
Updates to be passed to `aesara.function`.

"""
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
jax_res = aesara_jax_fn(*test_inputs)
if isinstance(fgraph, tuple):
fn_inputs, fn_outputs = fgraph
else:
fn_inputs = fgraph.inputs
fn_outputs = fgraph.outputs

fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]

if must_be_device_array:
if isinstance(jax_res, list):
assert all(
isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
aesara_py_fn = function(fn_inputs, fn_outputs, mode=py_mode, updates=updates)
py_res = aesara_py_fn(*inputs)

aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = aesara_py_fn(*test_inputs)
aesara_jax_fn = function(fn_inputs, fn_outputs, mode=jax_mode, updates=updates)
jax_res = aesara_jax_fn(*inputs)

if len(fgraph.outputs) > 1:
if len(fn_outputs) > 1:
for j, p in zip(jax_res, py_res):
assert_fn(j, p)
else:
Expand Down
4 changes: 1 addition & 3 deletions tests/link/jax/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def test_extra_ops():
indices = np.arange(np.product((3, 4)))
out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
fgraph = FunctionGraph([], out)
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])


@pytest.mark.parametrize(
Expand Down
Loading