Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement new Loop and Scan operators #191

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):

JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
RewriteDatabaseQuery(
include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt", "not_jax"]
),
)
NUMBA = Mode(
NumbaLinker(),
Expand Down
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.loop

# isort: on
61 changes: 61 additions & 0 deletions pytensor/link/jax/dispatch/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import jax
from jax.tree_util import tree_flatten, tree_unflatten

from pytensor.compile.mode import get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.loop.op import Scan
from pytensor.typed_list import TypedListType


@jax_funcify.register(Scan)
def jax_funcify_Scan(op, node, global_fgraph, **kwargs):
# TODO: Rewrite as a while loop if only last states are used
if op.has_while_condition:
raise NotImplementedError(
"Scan ops with while condition cannot be transpiled JAX"
)

# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan?
update_fg = op.update_fg.clone()
rewriter = get_mode("JAX").optimizer
rewriter(update_fg)
Copy link
Member Author

@ricardoV94 ricardoV94 Jan 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gives an annoying Supervisor Feature missing warning... gotta clean that up


jaxified_scan_inner_fn = jax_funcify(update_fg, **kwargs)

# Only include the intermediate states that are used elsewhere
used_traces_idxs = [
i
for i, trace in enumerate(node.outputs[op.n_states :])
if global_fgraph.clients[trace]
]

def scan(max_iters, *outer_inputs):
states = outer_inputs[: op.n_states]
constants = outer_inputs[op.n_states :]

def scan_fn(carry, _):
resume, *carry = jaxified_scan_inner_fn(*carry, *constants)
assert resume
carry = list(carry)
# Return states as both carry and output to be appended
return carry, [c for i, c in enumerate(carry) if i in used_traces_idxs]

states, traces = jax.lax.scan(
scan_fn, init=list(states), xs=None, length=max_iters
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Todo: Check we are not missing performance by not having explicit sequences.

Todo: When there are multiple sequences PyTensor defines n_steps as the shortest sequence. JAX should be able to handle this, but if not we could consider not allowing sequences/n_steps with different lengths in the Pytensor scan.

Then we could pass a single shape as n_steps after asserting they are the same?

)
final_traces = [None] * len(states)
for idx, trace in zip(used_traces_idxs, traces):
if isinstance(op.trace_types[idx], TypedListType):
flattened_trace, treedef = tree_flatten(trace)
transposed_trace = [
tree_unflatten(treedef, l) for l in zip(*flattened_trace)
]
final_traces[idx] = transposed_trace
else:
final_traces[idx] = trace

return *states, *final_traces

return scan
1 change: 1 addition & 0 deletions pytensor/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ def fgraph_to_python(
global_env = {}

body_assigns = []
kwargs.setdefault("global_fgraph", fgraph)
for node in order:
compiled_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
Expand Down
Empty file added pytensor/loop/__init__.py
Empty file.
199 changes: 199 additions & 0 deletions pytensor/loop/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import functools
from typing import List, Union

import numpy as np

from pytensor import Variable, as_symbolic, clone_replace
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, truncated_graph_inputs
from pytensor.loop.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import as_tensor, constant, empty_like, minimum


def scan(
fn,
init_states=None,
sequences=None,
non_sequences=None,
n_steps=None,
go_backwards=False,
) -> Union[Variable, List[Variable]]:
if sequences is None and n_steps is None:
raise ValueError("Must provide n_steps when scanning without sequences")

if init_states is None:
init_states = []
else:
if not isinstance(init_states, (tuple, list)):
init_states = [init_states]
init_states = [as_symbolic(i) if i is not None else None for i in init_states]

if sequences is None:
sequences = []
else:
if not isinstance(sequences, (tuple, list)):
sequences = [sequences]
sequences = [as_tensor(s) for s in sequences]

if sequences:
leading_dims = [seq.shape[0] for seq in sequences]
shortest_dim = functools.reduce(minimum, leading_dims)
if n_steps is None:
n_steps = shortest_dim
else:
n_steps = minimum(n_steps, shortest_dim)

if non_sequences is None:
non_sequences = []
else:
if not isinstance(non_sequences, (tuple, list)):
non_sequences = [non_sequences]
non_sequences = [as_symbolic(n) for n in non_sequences]

# Create dummy inputs for the init state. The user function should not
# draw any relationship with the outer initial states, since these are only
# valid in the first iteration
inner_states = [i.type() if i is not None else None for i in init_states]

# Create subsequence inputs for the inner function
idx = constant(0, dtype="int64", name="idx")
symbolic_idx = idx.type(name="idx")
subsequences = [s[symbolic_idx] for s in sequences]

# Call user function to retrieve inner outputs. We use the same order as the old Scan,
# although inner_states + subsequences + non_sequences seems more intuitive,
# since subsequences are just a fancy non_sequence
# We don't pass the non-carried outputs [init is None] to the inner function
fn_inputs = (
subsequences + [i for i in inner_states if i is not None] + non_sequences
)
fn_outputs = fn(*fn_inputs)
if not isinstance(fn_outputs, (tuple, list)):
fn_outputs = [fn_outputs]
next_states = [out for out in fn_outputs if not isinstance(out, until)]

if len(next_states) > len(init_states):
if not init_states:
init_states = [None] * len(next_states)
inner_states = init_states
else:
raise ValueError(
"Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map) "
)

# Replace None init by dummy empty tensors
prev_states = []
prev_inner_states = []
for i, (init_state, inner_state, next_state) in enumerate(
zip(init_states, inner_states, next_states)
):
if init_state is None:
# next_state may reference idx. We replace that by the initial value,
# so that the shape of the dummy init state does not depend on it.
[next_state] = clone_replace(
Copy link
Member

@ferrine ferrine Jan 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not graph_replace or using memo for FunctionGraph(memo={symbolic_idx: idx}) (here)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that better?

output=[next_state], replace={symbolic_idx: idx}
)
init_state = empty_like(next_state)
init_state.name = "empty_init_state"
inner_state = init_state.type(name="dummy_state")
prev_states.append(init_state)
prev_inner_states.append(inner_state)

# Flip until to while condition
while_condition = [~out.condition for out in fn_outputs if isinstance(out, until)]
if not while_condition:
while_condition = [as_tensor(np.array(True))]
if len(while_condition) > 1:
raise ValueError("Only one until condition can be returned")

fgraph_inputs = [symbolic_idx] + prev_inner_states + sequences + non_sequences
fgraph_outputs = while_condition + [symbolic_idx + 1] + next_states

all_fgraph_inputs = truncated_graph_inputs(
fgraph_outputs, ancestors_to_include=fgraph_inputs
)
extra_fgraph_inputs = [
inp
for inp in all_fgraph_inputs
if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
]
fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)

scan_op = Scan(update_fg=update_fg)
scan_outs = scan_op(
n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
)
assert isinstance(scan_outs, list)
# Don't return the last states or the trace for the inner index
traces = scan_outs[scan_op.n_states + 1 :]
if len(traces) == 1:
return traces[0]
return traces


def map(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about subclassing Scan into

  • Map(Scan)
  • Reduce(Scan)
  • Filter(Scan)

It will be easier to dispatch into optimized implementations

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do that later, not convinced we need that yet

fn,
sequences,
non_sequences=None,
go_backwards=False,
):
traces = scan(
fn=fn,
sequences=sequences,
non_sequences=non_sequences,
go_backwards=go_backwards,
)
return traces


def reduce(
fn,
init_states,
sequences,
non_sequences=None,
go_backwards=False,
):
traces = scan(
fn=fn,
init_states=init_states,
sequences=sequences,
non_sequences=non_sequences,
go_backwards=go_backwards,
)
if not isinstance(traces, list):
return traces[-1]
return [trace[-1] for trace in traces]


def filter(
fn,
sequences,
non_sequences=None,
go_backwards=False,
):
if not isinstance(sequences, (tuple, list)):
sequences = [sequences]

masks = scan(
fn=fn,
sequences=sequences,
non_sequences=non_sequences,
go_backwards=go_backwards,
)

if not isinstance(masks, list):
masks = [masks] * len(sequences)
elif len(masks) != len(sequences):
raise ValueError(
"filter fn must return one variable or len(sequences), but it returned {len(masks)}"
)
if not all(mask.dtype == "bool" for mask in masks):
raise TypeError("The output of filter fn should be a boolean variable")

filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]

if len(filtered_sequences) == 1:
return filtered_sequences[0]
return filtered_sequences
Loading