Skip to content

Commit

Permalink
Support mit-mots in the JAX backend
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 8, 2022
1 parent 344dbed commit a088c4b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
25 changes: 20 additions & 5 deletions aesara/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def assert_while_returns_last_output(fgraph, node):
def jax_funcify_Scan(op, node, **kwargs):
scan_inner_fn = jax_funcify(op.fgraph)
input_taps = {
"mit_mot": op.info.mit_mot_in_slices,
"mit_sot": op.info.mit_sot_in_slices,
"sit_sot": op.info.sit_sot_in_slices,
"nit_sot": op.info.sit_sot_in_slices,
Expand All @@ -76,9 +77,6 @@ def parse_outer_inputs(outer_inputs):
"shared": list(op.outer_shared(outer_inputs)),
"non_sequences": list(op.outer_non_seqs(outer_inputs)),
}
if len(outer_in["mit_mot"]) > 0:
raise NotImplementedError("mit-mot not supported")

return outer_in

if op.info.as_while:
Expand Down Expand Up @@ -364,7 +362,7 @@ def build_jax_scan_inputs(outer_in: Dict):
sequences = outer_in["sequences"]
init_carry = {
name: outer_in[name]
for name in ["mit_sot", "sit_sot", "shared", "non_sequences"]
for name in ["mit_mot", "mit_sot", "sit_sot", "shared", "non_sequences"]
}
init_carry["step"] = 0
return n_steps, sequences, init_carry
Expand All @@ -381,7 +379,7 @@ def build_inner_outputs_map(outer_in):
[+ while-condition]
"""
inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]
inner_outputs_names = ["mit_mot", "mit_sot", "sit_sot", "nit_sot", "shared"]

offset = 0
inner_output_idx = defaultdict(list)
Expand Down Expand Up @@ -456,6 +454,9 @@ def scan_inner_in_args(carry, x):
current_step = carry["step"]

inner_in_seqs = x
inner_in_mit_mot = from_carry_storage(
carry["mit_mot"], current_step, input_taps["mit_mot"]
)
inner_in_mit_sot = from_carry_storage(
carry["mit_sot"], current_step, input_taps["mit_sot"]
)
Expand All @@ -468,6 +469,7 @@ def scan_inner_in_args(carry, x):
return sum(
[
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
Expand All @@ -480,6 +482,7 @@ def scan_new_carry(carry, inner_outputs):
"""Create a new carry value from the values returned by the inner function (inner-outputs)."""
step = carry["step"]
new_carry = {
"mit_mot": [],
"mit_sot": [],
"sit_sot": [],
"shared": [],
Expand All @@ -493,6 +496,14 @@ def scan_new_carry(carry, inner_outputs):
]
new_carry["shared"] = shared_inner_outputs

if "mit_mot" in inner_output_idx:
mit_mot_inner_outputs = [
inner_outputs[idx] for idx in inner_output_idx["mit_mot"]
]
new_carry["mit_mot"] = to_carry_storage(
mit_mot_inner_outputs, carry["mit_mot"], step, input_taps["mit_mot"]
)

if "mit_sot" in inner_output_idx:
mit_sot_inner_outputs = [
inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
Expand Down Expand Up @@ -527,6 +538,10 @@ def scan_new_outputs(inner_outputs):
"""
outer_outputs = []
if "mit_mot" in inner_output_idx:
outer_outputs.append(
[inner_outputs[idx] for idx in inner_output_idx["mit_mot"]]
)
if "mit_sot" in inner_output_idx:
outer_outputs.append(
[inner_outputs[idx] for idx in inner_output_idx["mit_sot"]]
Expand Down
42 changes: 39 additions & 3 deletions tests/link/jax/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from packaging.version import parse as version_parse

import aesara.tensor as at
from aesara import function
from aesara import function, grad
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
Expand All @@ -27,11 +27,10 @@


def test_while_cannnot_use_all_outputs():
"""The JAX backend cannot return all the outputs of a while loop.
"""The JAX backend cannot use all the outputs of a while loop.
Indeed, JAX has fundamental limitations that prevent it from returning
all the intermediate results computed in a `jax.lax.while_loop` loop.
"""
res, updates = scan(
fn=lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
Expand Down Expand Up @@ -233,6 +232,16 @@ def test_sequence_opt():
# [],
# 3,
# [],
# lambda op: op.info.n_sit_sot > 0,
# ),
# # sit-sot, while
# (
# lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
# [],
# [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
# [],
# 3,
# [],
# None,
# lambda op: op.info.n_sit_sot > 0,
# ),
Expand Down Expand Up @@ -478,3 +487,30 @@ def power_step(prior_result, x):

for output_jax, output in zip(jax_res, res):
assert np.allclose(jax_res, res)


@pytest.mark.xfail(reason="Fails for reasons unrelated to `Scan`")
def test_mitmots_basic():

init_x = at.dvector()
seq = at.dvector()

def inner_fct(seq, state_old, state_current):
return state_old * 2 + state_current + seq

out, _ = scan(
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
)

g_outs = grad(out.sum(), [seq, init_x])

out_fg = FunctionGraph([seq, init_x], g_outs)

seq_val = np.arange(3)
init_x_val = np.r_[-2, -1]
(seq_val, init_x_val)

fn = function(out_fg.inputs, out_fg.outputs)
jax_fn = function(out_fg.inputs, out_fg.outputs, mode="JAX")
print(fn(seq_val, init_x_val))
print(jax_fn(seq_val, init_x_val))

0 comments on commit a088c4b

Please sign in to comment.