Skip to content

Commit

Permalink
Rewrite minimal Scan dispatch for JAX
Browse files Browse the repository at this point in the history
Passes the first `xit_xot_types` test taken from the Numba test suite.
  • Loading branch information
rlouf committed Oct 14, 2022
1 parent 1ab4c69 commit d1326b8
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 146 deletions.
212 changes: 66 additions & 146 deletions aesara/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
@@ -1,159 +1,79 @@
import jax
import jax.numpy as jnp

from aesara.graph.fg import FunctionGraph
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.scan.op import Scan
from aesara.scan.utils import ScanArgs


@jax_funcify.register(Scan)
def jax_funcify_Scan(op, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
def jax_funcify_Scan(op, node, **kwargs):
scan_inner_func = jax_funcify(op.fgraph)

def scan(*outer_inputs):
scan_args = ScanArgs(
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
)

# `outer_inputs` is a list with the following composite form:
# [n_steps]
# + outer_in_seqs
# + outer_in_mit_mot
# + outer_in_mit_sot
# + outer_in_sit_sot
# + outer_in_shared
# + outer_in_nit_sot
# + outer_in_non_seqs
n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs

# TODO: mit_mots
mit_mot_in_slices = []

mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0]
pos_taps = [abs(t) for t in tap if t > 0]
max_neg = max(neg_taps) if neg_taps else 0
max_pos = max(pos_taps) if pos_taps else 0
init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice)

sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]

init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)

def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared
# terms
(
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
) = carry

# `x` contains the in_seqs
inner_in_seqs = x

# `inner_scan_inputs` is a list with the following composite form:
# inner_in_seqs
# + sum(inner_in_mit_mot, [])
# + sum(inner_in_mit_sot, [])
# + inner_in_sit_sot
# + inner_in_shared
# + inner_in_non_seqs
inner_in_mit_sot_flatten = []
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])

inner_scan_inputs = sum(
[
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)

return inner_scan_inputs

def inner_scan_outs_to_jax_outs(
op,
old_carry,
inner_scan_outs,
):
(
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
) = old_carry

def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)

inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
]

# This should contain all inner-output taps, non_seqs, and shared
# terms
if not inner_in_sit_sot:
inner_out_sit_sot = []
n_steps = outer_inputs[0]
outer_in_seqs = list(op.outer_seqs(outer_inputs))
outer_in_mit_mot = list(op.outer_mitmot(outer_inputs))
outer_in_mit_sot = list(op.outer_mitsot(outer_inputs))
outer_in_sit_sot = list(op.outer_sitsot(outer_inputs))
outer_in_nit_sot = list(op.outer_nitsot(outer_inputs))
outer_in_shared = list(op.outer_shared(outer_inputs))
outer_in_non_seqs = list(op.outer_non_seqs(outer_inputs))
if len(outer_in_mit_mot):
raise NotImplementedError("mit-mot not supported")
if len(outer_in_mit_sot):
raise NotImplementedError("mit-sot not supported")
if len(outer_in_sit_sot):
raise NotImplementedError("sit-sot not supported")
if len(outer_in_shared):
raise NotImplementedError("shared variables not supported")
if len(outer_in_non_seqs):
raise NotImplementedError("non sequence are not supported")

init_carry = outer_in_nit_sot
sequences = outer_in_seqs

def scan_inner_in_args(carry, x, is_dummy_sit_sot=True):
"""Create an inner-input expression.
Inner-inputs are ordered as follows:
- sequences
- mit-mot inputs
- mit-sot inputs
- sit-sot inputs
- shared-inputs
- non-sequences
"""

inner_in_seqs = x
if is_dummy_sit_sot:
inner_in_sit_sot = []
else:
inner_out_sit_sot = inner_scan_outs
new_carry = (
inner_in_mit_mot,
inner_out_mit_sot,
inner_out_sit_sot,
inner_in_shared,
inner_in_non_seqs,
)

return new_carry

def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = list(jax_at_inner_func(*inner_args))
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs

_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)

if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]

if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
inner_in_sit_sot = carry
return sum([inner_in_seqs, inner_in_sit_sot], [])

def scan_new_carry(inner_outputs):
"""Create a new carry expression
Inner-outputs are ordered as follow:
- mit-mot-outputs
- mit-sot-outputs
- sit-sot-outputs
- nit-sots
- shared-outputs
[+ while-condition]
"""
carry = list(inner_outputs)
return carry

def body_fn(carry, x):
inner_in_args = scan_inner_in_args(carry, x)
inner_outputs = scan_inner_func(*inner_in_args)
carry = scan_new_carry(inner_outputs)
return carry, *inner_outputs

_, results = jax.lax.scan(body_fn, init_carry, sequences, length=n_steps)

return results

return scan
69 changes: 69 additions & 0 deletions tests/link/jax/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,85 @@
from packaging.version import parse as version_parse

import aesara.tensor as at
from aesara import function
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker
from aesara.scan.basic import scan
from aesara.scan.op import Scan
from aesara.tensor.math import gammaln, log
from aesara.tensor.type import ivector, lscalar, scalar
from tests.link.jax.test_basic import compare_jax_and_py


jax = pytest.importorskip("jax")

# Disable all optimizations
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)


@pytest.mark.parametrize(
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
[
# sequences
(
lambda a_t: 2 * a_t,
[at.dvector("a")],
[{}],
[],
None,
[np.arange(10)],
None,
lambda op: op.info.n_seqs > 0,
),
],
)
def test_xit_xot_types(
fn,
sequences,
outputs_info,
non_sequences,
n_steps,
input_vals,
output_vals,
op_check,
):
"""Test basic xit-xot configurations."""
res, updates = scan(
fn,
sequences=sequences,
outputs_info=outputs_info,
non_sequences=non_sequences,
n_steps=n_steps,
strict=True,
mode=Mode(linker="py", optimizer=None),
)

if not isinstance(res, list):
res = [res]

# Get rid of any `Subtensor` indexing on the `Scan` outputs
res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]

scan_op = res[0].owner.op
assert isinstance(scan_op, Scan)

_ = op_check(scan_op)

if output_vals is None:
compare_jax_and_py(
((sequences + non_sequences), res), input_vals, updates=updates
)
else:
jax_fn = function(
(sequences + non_sequences), res, mode=jax_mode, updates=updates
)
res_vals = jax_fn(*input_vals)
assert np.allclose(res_vals, output_vals)


@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
Expand Down

0 comments on commit d1326b8

Please sign in to comment.