Skip to content

Commit

Permalink
Support while loops in the JAX dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 10, 2022
1 parent b78a011 commit 2232d76
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 34 deletions.
1 change: 1 addition & 0 deletions aesara/link/jax/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def jax_funcify_FunctionGraph(
fgraph_name="jax_funcified_fgraph",
**kwargs,
):

return fgraph_to_python(
fgraph,
jax_funcify,
Expand Down
270 changes: 259 additions & 11 deletions aesara/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Dict, List

import jax
import jax.numpy as jnp

from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.scan.op import Scan
Expand All @@ -14,6 +15,7 @@ def jax_funcify_Scan(op, node, **kwargs):
input_taps = {
"mit_sot": op.info.mit_sot_in_slices,
"sit_sot": op.info.sit_sot_in_slices,
"nit_sot": op.info.sit_sot_in_slices,
}

# Outer-inputs are the inputs to the `Scan` apply node, built from the
Expand All @@ -36,13 +38,258 @@ def parse_outer_inputs(outer_inputs):
return outer_in

if op.info.as_while:
raise NotImplementedError("While loops are not supported in the JAX backend.")
# We can only compile a `Scan` node that acts as a `while` loop to JAX
# if only the last computed value is ever used in the outer function.
# TODO: Determine if that's the case
# TODO: Rewrite the graph if that's the case
# TODO: Implement a simple `while` loop that returns the last step
return make_jax_while_fn(scan_inner_fn, parse_outer_inputs, input_taps)
else:
return make_jax_scan_fn(
scan_inner_fn,
parse_outer_inputs,
input_taps,
)
return make_jax_scan_fn(scan_inner_fn, parse_outer_inputs, input_taps)


def make_jax_while_fn(
    scan_inner_fn: Callable,
    parse_outer_inputs: Callable[[TensorVariable], Dict[str, List[TensorVariable]]],
    input_taps: Dict,
):
    """Create a `jax.lax.while_loop` function to perform `Scan` computations when it
    is used as while loop.

    `jax.lax.while_loop` iterates by passing a value `carry` to a `body_fun` that
    must return a value of the same type (Pytree structure, shape and dtype of
    the leaves). Before calling `body_fn`, it calls `cond_fn` which takes the
    current value and returns a boolean that indicates whether to keep iterating
    or not.

    The JAX `while_loop` needs to perform the following operations:

    1. Extract the inner-inputs;
    2. Build the initial carry value;
    3. Inside the loop:
        1. `carry` -> inner-inputs;
        2. inner-outputs -> `carry`
    4. Post-process the `carry` storage and return outputs

    Parameters
    ----------
    scan_inner_fn
        Callable implementing one step of the loop. It is called with the
        flattened inner-inputs and must return a sequence of inner-outputs
        whose last element is the stopping condition.
    parse_outer_inputs
        Callable that maps the tuple of outer-inputs to a dictionary keyed by
        input kind. The keys ``"mit_sot"``, ``"sit_sot"``, ``"nit_sot"``,
        ``"shared"``, ``"sequences"`` and ``"non_sequences"`` are read here.
    input_taps
        Dictionary holding the taps of the ``"mit_sot"`` and ``"sit_sot"``
        inputs (one collection of taps per storage array).

    Returns
    -------
    A function taking the outer-inputs as positional arguments and returning
    the final ``mit_sot``, ``sit_sot``, ``nit_sot`` and ``shared`` values as a
    flat list, or the single value itself when there is exactly one.

    """

    def build_while_carry(outer_in):
        """Build the initial carry of the `while` loop from the outer-inputs."""
        init_carry = {
            "mit_sot": [],
            "mit_sot_storage": outer_in["mit_sot"],
            "sit_sot": [],
            "sit_sot_storage": outer_in["sit_sot"],
            "shared": outer_in["shared"],
            "sequences": outer_in["sequences"],
            "non_sequences": outer_in["non_sequences"],
        }
        init_carry["step"] = 0
        init_carry["do_stop"] = False
        return init_carry

    def build_inner_outputs_map(outer_in):
        """Map the inner-output variables to their position in the tuple returned by the inner function.

        TODO: Copied from the scan builder

        Inner-outputs are ordered as follows:
        - mit-mot-outputs
        - mit-sot-outputs
        - sit-sot-outputs
        - nit-sots (no carry)
        - shared-outputs
        [+ while-condition]

        Only the names that have at least one output end up as keys of the
        returned mapping, which is what the membership tests in
        `carry_from_inner_outputs` rely on.

        """
        inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]

        offset = 0
        inner_output_idx = defaultdict(list)
        for name in inner_outputs_names:
            num_outputs = len(outer_in[name])
            for i in range(num_outputs):
                inner_output_idx[name].append(offset + i)
            offset += num_outputs

        return inner_output_idx

    def from_carry_storage(carry, step, input_taps):
        """Fetch the inner inputs from the values stored in the carry array.

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to read from this array
        the inputs that will be passed to the inner function.

        This mechanism is necessary because we handle multiple-input taps within
        the `scan` instead of letting users manage the memory in the use cases
        where this is necessary.

        TODO: Copied from the scan builder

        """

        def fetch(carry, step, offset):
            return carry[step + offset]

        inner_inputs = []
        for taps, carry_element in zip(input_taps, carry):
            # Taps are shifted by the most negative tap so that it maps to
            # the position `step` of the storage array.
            storage_size = -min(taps)
            offsets = [storage_size + tap for tap in taps]
            inner_inputs.append(
                [fetch(carry_element, step, offset) for offset in offsets]
            )

        return sum(inner_inputs, [])

    def to_carry_storage(inner_outputs, carry, step, input_taps):
        """Create the new carry array from the inner output

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to update this array
        with the outputs of the inner function

        TODO: Copied from the scan builder

        NOTE(review): one updated array is produced per tap, so a storage
        array with several taps yields several arrays in the returned list.
        Confirm that this matches the carry structure expected by
        `jax.lax.while_loop` for multiple-tap inputs.

        """
        new_carry_element = []
        for taps, carry_element, output in zip(input_taps, carry, inner_outputs):
            new_carry_element.append(
                [carry_element.at[step - tap].set(output) for tap in taps]
            )

        return sum(new_carry_element, [])

    def while_loop(*outer_inputs):
        """Run the inner function until its stopping condition is met."""

        outer_in = parse_outer_inputs(outer_inputs)
        init_carry = build_while_carry(outer_in)
        inner_output_idx = build_inner_outputs_map(outer_in)

        def inner_inputs_from_carry(carry):
            """Get inner-inputs from the arguments passed to the `jax.lax.while_loop` body function.

            Inner-inputs are ordered as follows:
            - sequences
            - mit-mot inputs
            - mit-sot inputs
            - sit-sot inputs
            - shared-inputs
            - non-sequences

            Only the mit-sot, sit-sot, shared and non-sequence inputs are
            gathered here.

            """
            current_step = carry["step"]

            inner_in_mit_sot = from_carry_storage(
                carry["mit_sot_storage"], current_step, input_taps["mit_sot"]
            )
            inner_in_sit_sot = from_carry_storage(
                carry["sit_sot_storage"], current_step, input_taps["sit_sot"]
            )
            inner_in_shared = carry.get("shared", [])
            inner_in_non_sequences = carry.get("non_sequences", [])

            return sum(
                [
                    inner_in_mit_sot,
                    inner_in_sit_sot,
                    inner_in_shared,
                    inner_in_non_sequences,
                ],
                [],
            )

        def carry_from_inner_outputs(carry, inner_outputs):
            """Build the next carry from the current carry and the inner-outputs."""
            step = carry["step"]
            new_carry = {
                "mit_sot": [],
                "sit_sot": [],
                "sit_sot_storage": [],
                "nit_sot": [],
                "mit_sot_storage": [],
                "shared": [],
                "step": step + 1,
                "sequences": carry["sequences"],
                "non_sequences": carry["non_sequences"],
                # The stopping condition is the last inner-output.
                "do_stop": inner_outputs[-1],
            }

            if "shared" in inner_output_idx:
                shared_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["shared"]
                ]
                new_carry["shared"] = shared_inner_outputs

            if "mit_sot" in inner_output_idx:
                mit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
                ]
                new_carry["mit_sot"] = mit_sot_inner_outputs
                new_carry["mit_sot_storage"] = to_carry_storage(
                    mit_sot_inner_outputs,
                    carry["mit_sot_storage"],
                    step,
                    input_taps["mit_sot"],
                )

            if "sit_sot" in inner_output_idx:
                sit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["sit_sot"]
                ]
                new_carry["sit_sot"] = sit_sot_inner_outputs
                new_carry["sit_sot_storage"] = to_carry_storage(
                    sit_sot_inner_outputs,
                    carry["sit_sot_storage"],
                    step,
                    input_taps["sit_sot"],
                )

            if "nit_sot" in inner_output_idx:
                nit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["nit_sot"]
                ]
                new_carry["nit_sot"] = nit_sot_inner_outputs

            return new_carry

        def cond_fn(carry):
            # The inner-function of `Scan` returns a boolean as the last
            # value. This needs to be included in `carry`.
            # TODO: Will it return `False` if the number of steps is exceeded?
            return ~carry["do_stop"]

        def body_fn(carry):
            inner_inputs = inner_inputs_from_carry(carry)
            inner_outputs = scan_inner_fn(*inner_inputs)
            new_carry = carry_from_inner_outputs(carry, inner_outputs)
            return new_carry

        # The `Scan` implementation in the C backend will execute the
        # function once before checking the termination condition, while
        # `jax.lax.while_loop` checks the condition first. We thus need to call
        # `body_fn` once before calling `jax.lax.while_loop`. This allows us,
        # along with `n_steps`, to build the storage array for the `nit-sot`s
        # since there is no way to know their shape and dtype before executing
        # the function.
        inner_inputs = inner_inputs_from_carry(init_carry)
        inner_outputs = scan_inner_fn(*inner_inputs)
        carry = carry_from_inner_outputs(init_carry, inner_outputs)
        carry = jax.lax.while_loop(cond_fn, body_fn, carry)

        # Post-process the storage arrays
        # We make sure that the outputs are not scalars in case an array
        # is expected downstream since `Scan` is supposed to always return arrays
        for name in ("sit_sot", "mit_sot", "nit_sot"):
            carry[name] = [jnp.atleast_1d(element) for element in carry[name]]

        outer_outputs = ["mit_sot", "sit_sot", "nit_sot", "shared"]
        results = sum([carry[output] for output in outer_outputs], [])
        if len(results) == 1:
            return results[0]
        else:
            return results

    return while_loop


def make_jax_scan_fn(
Expand All @@ -58,7 +305,8 @@ def make_jax_scan_fn(
stacked to the previous outputs. We use this to our advantage to build
`Scan` outputs without having to post-process the storage arrays.
The JAX scan function needs to perform the following operations:
The JAX `scan` function needs to perform the following operations:
1. Extract the inner-inputs;
2. Build the initial `carry` and `sequence` values;
3. Inside the loop:
Expand Down Expand Up @@ -265,11 +513,11 @@ def body_fn(carry, x):
)

shared_output = tuple(last_carry["shared"])
results = results + shared_output
outer_outputs = results + shared_output

if len(results) == 1:
return results[0]
if len(outer_outputs) == 1:
return outer_outputs[0]

return results
return outer_outputs

return scan
2 changes: 1 addition & 1 deletion aesara/link/jax/dispatch/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def shape(x):


@jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op, **kwargs):
def jax_funcify_Shape_i(op, node, **kwargs):
i = op.i

def shape_i(x):
Expand Down
35 changes: 35 additions & 0 deletions aesara/link/jax/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.graph.basic import Constant
from aesara.graph.rewriting.basic import WalkingGraphRewriter, node_rewriter
from aesara.link.basic import JITLinker


Expand All @@ -12,7 +13,10 @@ class JAXLinker(JITLinker):

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from aesara.link.jax.dispatch import jax_funcify
from aesara.scan.op import Scan
from aesara.tensor.random.type import RandomType
from aesara.tensor.shape import Shape_i
from aesara.tensor.subtensor import Subtensor

shared_rng_inputs = [
inp
Expand Down Expand Up @@ -49,6 +53,37 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
)

@node_rewriter([Scan])
def check_while_returns_last_output(fgraph, node):
    """Reject `while`-type `Scan` nodes whose outputs are used in unsupported ways.

    This rewriter never modifies the graph: it always returns ``False``.
    For a `Scan` node that acts as a `while` loop, it inspects the clients
    of the node's outputs and raises when one of them is neither a
    `Subtensor` indexed with a non-slice first index nor a `Shape_i`.

    Raises
    ------
    NotImplementedError
        If an output of the `while` loop has a client that is a string, a
        `Subtensor` whose first index is a slice, or any `Op` other than
        `Subtensor` and `Shape_i`.

    """
    op = node.op
    if not op.info.as_while:
        return False

    def unsupported(out, client):
        # Build an informative error so the offending use can be located.
        return NotImplementedError(
            "The JAX backend does not support this use of the output "
            f"{out} of a `Scan` while loop (client: {client})."
        )

    # Count the number of outputs of the outer function. We ignore
    # `shared` variables since they are not accumulated and not
    # returned to the user.
    num_outer_outputs = (
        op.info.n_mit_mot
        + op.info.n_mit_sot
        + op.info.n_sit_sot
        + op.info.n_nit_sot
    )
    for out in node.outputs[:num_outer_outputs]:
        for client, _ in fgraph.clients[out]:
            if isinstance(client, str):
                raise unsupported(out, client)
            elif isinstance(client.op, Subtensor):
                idx_list = client.op.idx_list
                # A slice as first index selects several steps of the loop.
                if isinstance(idx_list[0], slice):
                    raise unsupported(out, client)
            elif not isinstance(client.op, Shape_i):
                raise unsupported(out, client)

    return False

jax_opt = WalkingGraphRewriter(check_while_returns_last_output)
jax_opt.rewrite(fgraph)

return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
)
Expand Down
Loading

0 comments on commit 2232d76

Please sign in to comment.