Skip to content

Commit

Permalink
Keep outer graph visible to the scan user function, including sequences
Browse files Browse the repository at this point in the history
Sequences are now demoted to being just another constant in the Scan Op.
The user-facing function automatically creates the right indexing graph for iterating over sequences.

Some extra logic is added in the `scan_to_loop` rewrite to avoid creating duplicate indexes,
while remaining on guard for Scans created elsewhere.
  • Loading branch information
ricardoV94 committed Jan 12, 2023
1 parent c46cd53 commit 7bcd42c
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 147 deletions.
63 changes: 47 additions & 16 deletions pytensor/loop/basic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import functools
from typing import List, Tuple

import numpy as np

from pytensor import Variable, as_symbolic
from pytensor import Variable, as_symbolic, clone_replace
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, truncated_graph_inputs
from pytensor.loop.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import as_tensor, empty_like
from pytensor.tensor import as_tensor, constant, empty_like, minimum


def scan(
Expand All @@ -20,6 +22,8 @@ def scan(
if sequences is None and n_steps is None:
raise ValueError("Must provide n_steps when scanning without sequences")

# TODO: init_states should be made opaque to the inner function,
# since any relationship to the outer graph no longer holds
if init_states is None:
init_states = []
else:
Expand All @@ -34,20 +38,31 @@ def scan(
sequences = [sequences]
sequences = [as_tensor(s) for s in sequences]

if sequences:
leading_dims = [seq.shape[0] for seq in sequences]
shortest_dim = functools.reduce(minimum, leading_dims)
if n_steps is None:
n_steps = shortest_dim
else:
n_steps = minimum(n_steps, shortest_dim)

if non_sequences is None:
non_sequences = []
else:
if not isinstance(non_sequences, (tuple, list)):
non_sequences = [non_sequences]
non_sequences = [as_symbolic(n) for n in non_sequences]

# Create subsequence inputs for the inner function
idx = constant(0, dtype="int64", name="idx")
symbolic_idx = idx.type(name="idx")
subsequences = [s[symbolic_idx] for s in sequences]
# Note: Old scan order is sequences + init + non_sequences
inner_sequences = [s[0] for s in sequences]
inner_inputs = [i.type() for i in init_states + inner_sequences + non_sequences]
inner_outputs = fn(*inner_inputs)
if not isinstance(inner_outputs, (tuple, list)):
inner_outputs = [inner_outputs]
next_states = [out for out in inner_outputs if not isinstance(out, until)]
fn_inputs = init_states + subsequences + non_sequences
fn_outputs = fn(*fn_inputs)
if not isinstance(fn_outputs, (tuple, list)):
fn_outputs = [fn_outputs]
next_states = [out for out in fn_outputs if not isinstance(out, until)]

if len(next_states) > len(init_states):
if not init_states:
Expand All @@ -61,27 +76,43 @@ def scan(
prev_states = []
for i, (init_state, next_state) in enumerate(zip(init_states, next_states)):
if init_state is None:
# next_state may reference idx, let's replace that by the initial value
[next_state] = clone_replace(
output=[next_state], replace={symbolic_idx: idx}
)
init_state = empty_like(next_state)
init_state.name = "empty_init_state"
inner_inputs.insert(i, init_state.type())
prev_states.append(init_state)

until_condition = [out.condition for out in inner_outputs if isinstance(out, until)]
until_condition = [out.condition for out in fn_outputs if isinstance(out, until)]
if not until_condition:
until_condition = [as_tensor(np.array(True))]
if len(until_condition) > 1:
raise ValueError("Only one until condition can be returned")

update_fg = FunctionGraph(
inputs=inner_inputs, outputs=until_condition + next_states
fgraph_inputs = [symbolic_idx] + prev_states + sequences + non_sequences
fgraph_outputs = until_condition + [symbolic_idx + 1] + next_states

all_fgraph_inputs = truncated_graph_inputs(
fgraph_outputs, ancestors_to_include=fgraph_inputs
)
extra_fgraph_inputs = [
inp
for inp in all_fgraph_inputs
if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
]
fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)

scan_op = Scan(update_fg=update_fg)
scan_outs = scan_op(
n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
)
scan_op = Scan(update_fg=update_fg, n_sequences=len(sequences))
scan_outs = scan_op(n_steps, *prev_states, *sequences, *non_sequences)
assert isinstance(scan_outs, list)
last_states = scan_outs[: scan_op.n_states]
traces = scan_outs[scan_op.n_states :]

return last_states, traces
# Don't return the inner index state
return last_states[1:], traces[1:]


def map(
Expand Down
Loading

0 comments on commit 7bcd42c

Please sign in to comment.