Implement new Loop and Scan operators #191
base: main
Changes from all commits: 3fe901d, 78bb829, 4da2f6e, e2fdf28, db7068e, 5bc7070
@@ -0,0 +1,61 @@
import jax
from jax.tree_util import tree_flatten, tree_unflatten

from pytensor.compile.mode import get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.loop.op import Scan
from pytensor.typed_list import TypedListType


@jax_funcify.register(Scan)
def jax_funcify_Scan(op, node, global_fgraph, **kwargs):
    # TODO: Rewrite as a while loop if only last states are used
    if op.has_while_condition:
        raise NotImplementedError(
            "Scan ops with a while condition cannot be transpiled to JAX"
        )

    # Apply inner rewrites
    # TODO: Not sure this is the right place to do this. Should we have a rewrite
    # that explicitly triggers the optimization of the inner graphs of Scan?
    update_fg = op.update_fg.clone()
    rewriter = get_mode("JAX").optimizer
    rewriter(update_fg)

    jaxified_scan_inner_fn = jax_funcify(update_fg, **kwargs)

    # Only include the intermediate states that are used elsewhere
    used_traces_idxs = [
        i
        for i, trace in enumerate(node.outputs[op.n_states :])
        if global_fgraph.clients[trace]
    ]

    def scan(max_iters, *outer_inputs):
        states = outer_inputs[: op.n_states]
        constants = outer_inputs[op.n_states :]

        def scan_fn(carry, _):
            resume, *carry = jaxified_scan_inner_fn(*carry, *constants)
            assert resume
            carry = list(carry)
            # Return states as both carry and output to be appended
            return carry, [c for i, c in enumerate(carry) if i in used_traces_idxs]

        states, traces = jax.lax.scan(
            scan_fn, init=list(states), xs=None, length=max_iters
        )
        final_traces = [None] * len(states)
        for idx, trace in zip(used_traces_idxs, traces):
            if isinstance(op.trace_types[idx], TypedListType):
                flattened_trace, treedef = tree_flatten(trace)
                transposed_trace = [
                    tree_unflatten(treedef, l) for l in zip(*flattened_trace)
                ]
                final_traces[idx] = transposed_trace
            else:
                final_traces[idx] = trace

        return *states, *final_traces

    return scan

Review comment (on the jax.lax.scan call):
TODO: Check we are not missing performance by not having explicit sequences.
TODO: When there are multiple sequences, PyTensor defines n_steps as the shortest sequence. JAX should be able to handle this, but if not we could consider not allowing sequences/n_steps with different lengths in the PyTensor scan. Then we could pass a single shape as n_steps after asserting they are the same?
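
For intuition, here is a minimal standalone sketch (not part of the diff; the step function and shapes are made up) of the carry-and-trace pattern the dispatcher emits, including the tree_flatten/tree_unflatten transposition used above to turn a pytree of stacked trace arrays into a per-iteration list of pytrees:

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


# Hypothetical inner step: the carry is a pytree (a pair of arrays here),
# standing in for a Scan state whose trace type is list-like
def step(carry, _):
    a, b = carry
    new_carry = (a + 1, b * 2)
    # Return the carry twice: once to keep looping, once to be stacked as a trace
    return new_carry, new_carry


init = (jnp.zeros(()), jnp.ones(()))
_, stacked = jax.lax.scan(step, init, xs=None, length=3)

# `stacked` is a pytree of stacked arrays: a pair of shape-(3,) arrays.
# The dispatcher instead wants a length-3 list of pytrees, so it transposes:
leaves, treedef = tree_flatten(stacked)
per_step = [tree_unflatten(treedef, step_leaves) for step_leaves in zip(*leaves)]
assert len(per_step) == 3  # one (a, b) pair per iteration

zip(*leaves) walks the leading (iteration) axis of every leaf in lockstep, which is exactly the transposition applied to TypedListType traces above.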
@@ -0,0 +1,199 @@
import functools
from typing import List, Union

import numpy as np

from pytensor import Variable, as_symbolic, clone_replace
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, truncated_graph_inputs
from pytensor.loop.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import as_tensor, constant, empty_like, minimum


def scan(
    fn,
    init_states=None,
    sequences=None,
    non_sequences=None,
    n_steps=None,
    go_backwards=False,
) -> Union[Variable, List[Variable]]:
    if sequences is None and n_steps is None:
        raise ValueError("Must provide n_steps when scanning without sequences")

    if init_states is None:
        init_states = []
    else:
        if not isinstance(init_states, (tuple, list)):
            init_states = [init_states]
        init_states = [as_symbolic(i) if i is not None else None for i in init_states]

    if sequences is None:
        sequences = []
    else:
        if not isinstance(sequences, (tuple, list)):
            sequences = [sequences]
        sequences = [as_tensor(s) for s in sequences]

    if sequences:
        leading_dims = [seq.shape[0] for seq in sequences]
        shortest_dim = functools.reduce(minimum, leading_dims)
        if n_steps is None:
            n_steps = shortest_dim
        else:
            n_steps = minimum(n_steps, shortest_dim)

    if non_sequences is None:
        non_sequences = []
    else:
        if not isinstance(non_sequences, (tuple, list)):
            non_sequences = [non_sequences]
        non_sequences = [as_symbolic(n) for n in non_sequences]

    # Create dummy inputs for the init state. The user function should not
    # assume any relationship with the outer initial states, since these are
    # only valid in the first iteration
    inner_states = [i.type() if i is not None else None for i in init_states]

    # Create subsequence inputs for the inner function
    idx = constant(0, dtype="int64", name="idx")
    symbolic_idx = idx.type(name="idx")
    subsequences = [s[symbolic_idx] for s in sequences]

    # Call user function to retrieve inner outputs. We use the same order as the
    # old Scan, although inner_states + subsequences + non_sequences seems more
    # intuitive, since subsequences are just a fancy non_sequence.
    # We don't pass the non-carried outputs (those whose init is None) to the
    # inner function
    fn_inputs = (
        subsequences + [i for i in inner_states if i is not None] + non_sequences
    )
    fn_outputs = fn(*fn_inputs)
    if not isinstance(fn_outputs, (tuple, list)):
        fn_outputs = [fn_outputs]
    next_states = [out for out in fn_outputs if not isinstance(out, until)]

    if len(next_states) > len(init_states):
        if not init_states:
            init_states = [None] * len(next_states)
            inner_states = init_states
        else:
            raise ValueError(
                "Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map)"
            )

    # Replace None init by dummy empty tensors
    prev_states = []
    prev_inner_states = []
    for i, (init_state, inner_state, next_state) in enumerate(
        zip(init_states, inner_states, next_states)
    ):
        if init_state is None:
            # next_state may reference idx. We replace that by the initial value,
            # so that the shape of the dummy init state does not depend on it.
            [next_state] = clone_replace(
                output=[next_state], replace={symbolic_idx: idx}
            )
            init_state = empty_like(next_state)
            init_state.name = "empty_init_state"
            inner_state = init_state.type(name="dummy_state")
        prev_states.append(init_state)
        prev_inner_states.append(inner_state)

    # Flip until to while condition
    while_condition = [~out.condition for out in fn_outputs if isinstance(out, until)]
    if not while_condition:
        while_condition = [as_tensor(np.array(True))]
    if len(while_condition) > 1:
        raise ValueError("Only one until condition can be returned")

    fgraph_inputs = [symbolic_idx] + prev_inner_states + sequences + non_sequences
    fgraph_outputs = while_condition + [symbolic_idx + 1] + next_states

    all_fgraph_inputs = truncated_graph_inputs(
        fgraph_outputs, ancestors_to_include=fgraph_inputs
    )
    extra_fgraph_inputs = [
        inp
        for inp in all_fgraph_inputs
        if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
    ]
    fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
    update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)

    scan_op = Scan(update_fg=update_fg)
    scan_outs = scan_op(
        n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
    )
    assert isinstance(scan_outs, list)
    # Don't return the last states or the trace for the inner index
    traces = scan_outs[scan_op.n_states + 1 :]
    if len(traces) == 1:
        return traces[0]
    return traces

Review comment (on the clone_replace call): Why not graph_replace, or using the memo of FunctionGraph (memo={symbolic_idx: idx}) here?
Reply: Why is that better?
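
A hedged usage sketch of the helper above, assuming it is importable as pytensor.loop.basic.scan (the module path is a guess from the imports) and that the new Scan op compiles in the default mode. It exercises both a carried state (a running sum) and a map-like output declared with a None init:

import numpy as np

import pytensor
import pytensor.tensor as pt

# Assumed module path for the helper defined above
from pytensor.loop.basic import scan

xs = pt.vector("xs")


def step(x, acc):
    # First output is carried (matches the 0.0 init); the second behaves
    # like a map output, so its init below is None
    return acc + x, 2 * x


cumsum_trace, doubled_trace = scan(
    step, init_states=[pt.constant(0.0), None], sequences=[xs]
)

fn = pytensor.function([xs], [cumsum_trace, doubled_trace])
sums, doubles = fn(np.arange(4.0))
# sums -> [0., 1., 3., 6.]; doubles -> [0., 2., 4., 6.]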


def map(
    fn,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    traces = scan(
        fn=fn,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )
    return traces

Review comment (on def map): What about subclassing Scan into specialized Map/Reduce/Filter ops? It will be easier to dispatch into optimized implementations.
Reply: We can do that later, not convinced we need that yet.
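
A minimal usage sketch for map, under the same assumed module path; it is imported under an alias so the Python builtin map is not shadowed:

import numpy as np

import pytensor
import pytensor.tensor as pt

# Assumed module path; aliased to avoid shadowing the builtin
from pytensor.loop.basic import map as loop_map

xs = pt.vector("xs")
# Each element is mapped independently; no state is carried between steps
squares = loop_map(lambda x: x**2, sequences=xs)

fn = pytensor.function([xs], squares)
fn(np.arange(3.0))  # -> [0., 1., 4.]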


def reduce(
    fn,
    init_states,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    traces = scan(
        fn=fn,
        init_states=init_states,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )
    if not isinstance(traces, list):
        return traces[-1]
    return [trace[-1] for trace in traces]
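
A usage sketch for reduce, again with the assumed module path; only the final state of the running sum is returned (the trace's last element):

import numpy as np

import pytensor
import pytensor.tensor as pt

# Assumed module path; aliased for symmetry with the other helpers
from pytensor.loop.basic import reduce as loop_reduce

xs = pt.vector("xs")
# Carries a running sum and keeps only its last value
total = loop_reduce(lambda x, acc: acc + x, init_states=pt.constant(0.0), sequences=xs)

fn = pytensor.function([xs], total)
fn(np.arange(5.0))  # -> 10.0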


def filter(
    fn,
    sequences,
    non_sequences=None,
    go_backwards=False,
):
    if not isinstance(sequences, (tuple, list)):
        sequences = [sequences]

    masks = scan(
        fn=fn,
        sequences=sequences,
        non_sequences=non_sequences,
        go_backwards=go_backwards,
    )

    if not isinstance(masks, list):
        masks = [masks] * len(sequences)
    elif len(masks) != len(sequences):
        raise ValueError(
            f"filter fn must return one variable or len(sequences) variables, but it returned {len(masks)}"
        )
    if not all(mask.dtype == "bool" for mask in masks):
        raise TypeError("The output of filter fn should be a boolean variable")

    filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]

    if len(filtered_sequences) == 1:
        return filtered_sequences[0]
    return filtered_sequences
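
And a usage sketch for filter, same assumptions; the step function must return a boolean mask per element:

import numpy as np

import pytensor
import pytensor.tensor as pt

# Assumed module path; aliased so the Python builtin filter is not shadowed
from pytensor.loop.basic import filter as loop_filter

xs = pt.vector("xs")
# Keep elements where the predicate is True
evens = loop_filter(lambda x: pt.eq(x % 2, 0), sequences=xs)

fn = pytensor.function([xs], evens)
fn(np.arange(6.0))  # -> [0., 2., 4.]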
Review comment: This gives an annoying "Supervisor Feature missing" warning... gotta clean that up.