@jax_funcify.register(Scan)
def jax_funcify_Scan(op, node, **kwargs):
    """Create a JAX-compatible callable for a `Scan` `Op` using `jax.lax.scan`.

    Only `Scan`s whose outer inputs consist of ``n_steps``, sequences and
    nit-sot entries are handled; every other input kind makes the returned
    callable raise `NotImplementedError`.

    Parameters
    ----------
    op
        The `Scan` `Op`.  Its inner graph (``op.fgraph``) is converted once,
        when this dispatcher runs.
    node
        Part of the dispatch signature; not used here.
    **kwargs
        Part of the dispatch signature; not used here.

    Returns
    -------
    callable
        ``scan(*outer_inputs)``, which returns the stacked per-step inner
        outputs: a single value when the inner graph has exactly one output,
        otherwise a tuple with one entry per inner output.
    """
    scan_inner_func = jax_funcify(op.fgraph)

    def scan(*outer_inputs):
        """Run the loop on the outer inputs of the `Scan` node.

        Raises
        ------
        NotImplementedError
            If any mit-mot, mit-sot, sit-sot, shared or non-sequence outer
            input is present.
        """
        n_steps = outer_inputs[0]
        outer_in_seqs = list(op.outer_seqs(outer_inputs))
        outer_in_nit_sot = list(op.outer_nitsot(outer_inputs))

        # Reject, in a fixed order, every input kind that is not handled yet.
        if list(op.outer_mitmot(outer_inputs)):
            raise NotImplementedError("mit-mot not supported")
        if list(op.outer_mitsot(outer_inputs)):
            raise NotImplementedError("mit-sot not supported")
        if list(op.outer_sitsot(outer_inputs)):
            raise NotImplementedError("sit-sot not supported")
        if list(op.outer_shared(outer_inputs)):
            raise NotImplementedError("shared variables not supported")
        if list(op.outer_non_seqs(outer_inputs)):
            raise NotImplementedError("non sequence are not supported")

        # NOTE(review): `jax.lax.scan` expects the carry returned by the step
        # function to match `init_carry` in structure and type.  Here the
        # initial carry holds the nit-sot outer inputs while the new carry
        # holds the inner outputs -- confirm these always agree.
        init_carry = outer_in_nit_sot
        sequences = outer_in_seqs

        def scan_inner_in_args(carry, x, is_dummy_sit_sot=True):
            """Create the list of inner-graph inputs for one step.

            Inner-inputs are ordered as follows:
            - sequences
            - mit-mot inputs
            - mit-sot inputs
            - sit-sot inputs
            - shared-inputs
            - non-sequences

            Only the sequence slices `x` and, when `is_dummy_sit_sot` is
            false, the carry (used as the sit-sot inputs) are included.
            """
            inner_in_seqs = x
            inner_in_sit_sot = [] if is_dummy_sit_sot else carry
            return [*inner_in_seqs, *inner_in_sit_sot]

        def scan_new_carry(inner_outputs):
            """Create the new carry from the inner-graph outputs.

            Inner-outputs are ordered as follows:
            - mit-mot-outputs
            - mit-sot-outputs
            - sit-sot-outputs
            - nit-sots
            - shared-outputs
            [+ while-condition]
            """
            return list(inner_outputs)

        def body_fn(carry, x):
            """Step function handed to `jax.lax.scan`."""
            inner_in_args = scan_inner_in_args(carry, x)
            inner_outputs = tuple(scan_inner_func(*inner_in_args))
            # The step function must return exactly `(carry, y)`.  Splatting
            # the outputs into the return tuple only produced a pair when
            # there was a single inner output, so keep them grouped instead.
            return scan_new_carry(inner_outputs), inner_outputs

        _, results = jax.lax.scan(body_fn, init_carry, sequences, length=n_steps)

        # A single inner output is returned bare, as it was when the step
        # function returned that output directly as `y`.
        if len(results) == 1:
            return results[0]
        return results

    return scan
# Disable all optimizations
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)


@pytest.mark.parametrize(
    "fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
    [
        # sequences
        (
            lambda a_t: 2 * a_t,
            [at.dvector("a")],
            [{}],
            [],
            None,
            [np.arange(10)],
            None,
            lambda op: op.info.n_seqs > 0,
        ),
    ],
)
def test_xit_xot_types(
    fn,
    sequences,
    outputs_info,
    non_sequences,
    n_steps,
    input_vals,
    output_vals,
    op_check,
):
    """Test basic xit-xot configurations.

    Each parameter set builds a `Scan` graph with `scan`, checks with
    `op_check` that the resulting `Scan` `Op` has the intended configuration,
    and then checks the JAX results: against the Python backend through
    `compare_jax_and_py` when `output_vals` is ``None``, otherwise against
    `output_vals` using a function compiled with `jax_mode`.
    """
    res, updates = scan(
        fn,
        sequences=sequences,
        outputs_info=outputs_info,
        non_sequences=non_sequences,
        n_steps=n_steps,
        strict=True,
        mode=Mode(linker="py", optimizer=None),
    )

    if not isinstance(res, list):
        res = [res]

    # Get rid of any `Subtensor` indexing on the `Scan` outputs
    res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]

    scan_op = res[0].owner.op
    assert isinstance(scan_op, Scan)

    # The predicate's result used to be discarded, so a graph with the wrong
    # configuration could never fail the test; assert it instead.
    assert op_check(scan_op)

    if output_vals is None:
        compare_jax_and_py(
            ((sequences + non_sequences), res), input_vals, updates=updates
        )
    else:
        jax_fn = function(
            (sequences + non_sequences), res, mode=jax_mode, updates=updates
        )
        res_vals = jax_fn(*input_vals)
        assert np.allclose(res_vals, output_vals)