From 7bcd42cde3934cb864958b3232f739fdfb8231e2 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 12 Jan 2023 17:42:51 +0100 Subject: [PATCH] Keep outer graph visible to the scan user function, including sequences Sequences are now demoted to being just another constant in the Scan Op. The user facing function creates the right indexing graph for iterating over sequences automatically. Some extra logic is added in the `scan_to_loop` rewrite to avoid creating duplicated indexes, while being on guard for Scans created elsewhere. --- pytensor/loop/basic.py | 63 +++++++++--- pytensor/loop/op.py | 216 ++++++++++++++++++++--------------------- tests/loop/basic.py | 32 +++++- tests/loop/test_op.py | 33 +++---- 4 files changed, 197 insertions(+), 147 deletions(-) diff --git a/pytensor/loop/basic.py b/pytensor/loop/basic.py index 92ae0bfe2e..bdb025fec2 100644 --- a/pytensor/loop/basic.py +++ b/pytensor/loop/basic.py @@ -1,12 +1,14 @@ +import functools from typing import List, Tuple import numpy as np -from pytensor import Variable, as_symbolic +from pytensor import Variable, as_symbolic, clone_replace from pytensor.graph import FunctionGraph +from pytensor.graph.basic import Constant, truncated_graph_inputs from pytensor.loop.op import Scan from pytensor.scan.utils import until -from pytensor.tensor import as_tensor, empty_like +from pytensor.tensor import as_tensor, constant, empty_like, minimum def scan( @@ -20,6 +22,8 @@ def scan( if sequences is None and n_steps is None: raise ValueError("Must provide n_steps when scanning without sequences") + # TODO: init_states should be made opaque to the inner function, + # since any relationship to the outer graph no longer holds if init_states is None: init_states = [] else: @@ -34,6 +38,14 @@ def scan( sequences = [sequences] sequences = [as_tensor(s) for s in sequences] + if sequences: + leading_dims = [seq.shape[0] for seq in sequences] + shortest_dim = functools.reduce(minimum, leading_dims) + if n_steps is None: + n_steps = shortest_dim + else: + n_steps = minimum(n_steps, shortest_dim) + if non_sequences is None: non_sequences = [] else: @@ -41,13 +53,16 @@ def scan( non_sequences = [non_sequences] non_sequences = [as_symbolic(n) for n in non_sequences] + # Create subsequence inputs for the inner function + idx = constant(0, dtype="int64", name="idx") + symbolic_idx = idx.type(name="idx") + subsequences = [s[symbolic_idx] for s in sequences] # Note: Old scan order is sequences + init + non_sequences - inner_sequences = [s[0] for s in sequences] - inner_inputs = [i.type() for i in init_states + inner_sequences + non_sequences] - inner_outputs = fn(*inner_inputs) - if not isinstance(inner_outputs, (tuple, list)): - inner_outputs = [inner_outputs] - next_states = [out for out in inner_outputs if not isinstance(out, until)] + fn_inputs = init_states + subsequences + non_sequences + fn_outputs = fn(*fn_inputs) + if not isinstance(fn_outputs, (tuple, list)): + fn_outputs = [fn_outputs] + next_states = [out for out in fn_outputs if not isinstance(out, until)] if len(next_states) > len(init_states): if not init_states: @@ -61,27 +76,43 @@ def scan( prev_states = [] for i, (init_state, next_state) in enumerate(zip(init_states, next_states)): if init_state is None: + # next_state may reference idx, let's replace that by the initial value + [next_state] = clone_replace( + output=[next_state], replace={symbolic_idx: idx} + ) init_state = empty_like(next_state) init_state.name = "empty_init_state" - inner_inputs.insert(i, init_state.type()) prev_states.append(init_state) - until_condition = [out.condition for out in inner_outputs if isinstance(out, until)] + until_condition = [out.condition for out in fn_outputs if isinstance(out, until)] if not until_condition: until_condition = [as_tensor(np.array(True))] if len(until_condition) > 1: raise ValueError("Only one until condition can be returned") - update_fg = FunctionGraph( - inputs=inner_inputs, outputs=until_condition + next_states + fgraph_inputs = [symbolic_idx] + prev_states + sequences + non_sequences + fgraph_outputs = until_condition + [symbolic_idx + 1] + next_states + + all_fgraph_inputs = truncated_graph_inputs( + fgraph_outputs, ancestors_to_include=fgraph_inputs + ) + extra_fgraph_inputs = [ + inp + for inp in all_fgraph_inputs + if (not isinstance(inp, Constant) and inp not in fgraph_inputs) + ] + fgraph_inputs = fgraph_inputs + extra_fgraph_inputs + update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs) + + scan_op = Scan(update_fg=update_fg) + scan_outs = scan_op( + n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs ) - scan_op = Scan(update_fg=update_fg, n_sequences=len(sequences)) - scan_outs = scan_op(n_steps, *prev_states, *sequences, *non_sequences) assert isinstance(scan_outs, list) last_states = scan_outs[: scan_op.n_states] traces = scan_outs[scan_op.n_states :] - - return last_states, traces + # Don't return the inner index state + return last_states[1:], traces[1:] def map( diff --git a/pytensor/loop/op.py b/pytensor/loop/op.py index 2632a1d381..c858809eb1 100644 --- a/pytensor/loop/op.py +++ b/pytensor/loop/op.py @@ -1,14 +1,20 @@ -import functools from typing import Optional import numpy as np -from pytensor import In, Out, get_scalar_constant_value +from pytensor import In, Out from pytensor.compile import optdb, pfunc from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter from pytensor.graph.rewriting.basic import in2out from pytensor.scalar import constant -from pytensor.tensor import NoneConst, and_, empty, minimum, set_subtensor +from pytensor.tensor import ( + NoneConst, + add, + and_, + empty, + get_scalar_constant_value, + set_subtensor, +) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i from pytensor.tensor.type import DenseTensorType, TensorType @@ -17,8 +23,14 @@ def validate_loop_update_types(update): assert update.outputs[0].type.dtype == "bool" - for input_state, output_state in zip(update.inputs, update.outputs[1:]): - assert input_state.type == output_state.type + for i, (input_state, output_state) in enumerate( + zip(update.inputs, update.outputs[1:]) + ): + if input_state.type != output_state.type: + raise TypeError( + f"The {i}-th input and output states of the inner loop function have different types: " + f"{input_state.type} vs {output_state.type}." + ) class Loop(Op): @@ -128,11 +140,11 @@ class Scan(Op): Roughly equivalent to ``` - def scan(fn, initial_states, sequences, constants, max_iters): + def scan(fn, initial_states, constants, max_iters): traces = [[]*len(initial_states)] states = initial_states - for (idx, *subsequences) in zip(*(range(max_iters), *sequences)): - resume, states = fn(*states, *subsequences, *constants) + for i in range(max_iters): + resume, states = fn(*states, *constants) for trace, state in zip(traces, states): trace.append(state) if not resume: @@ -142,15 +154,12 @@ def scan(fn, initial_states, sequences, constants, max_iters): Not all types of states can be collected, for instance RandomGenerator. For these `None` is returned in place of the respective traces - The number of iterations is bounded by max_iters or the shortest of sequences. - This Op must always be converted to a Loop during compilation. """ def __init__( self, update_fg: FunctionGraph, # (*state, *consts) -> (bool, *state) - n_sequences: int, reverse_fg: Optional[FunctionGraph] = None, ): validate_loop_update_types(update_fg) @@ -170,23 +179,8 @@ def __init__( # We can't concatenate all types of states, such as RandomTypes self.trace_types.append(NoneConst.type) - self.n_sequences = n_sequences - self.sequence_types = [] - for inner_seq in update_fg.inputs[ - self.n_states : self.n_states + self.n_sequences - ]: - # TODO: Accomodate other sequence types - assert isinstance(inner_seq.type, DenseTensorType) - self.sequence_types.append( - DenseTensorType( - shape=(None, *inner_seq.type.shape), dtype=inner_seq.type.dtype - ) - ) - - self.non_sequence_types = [ - inp.type for inp in update_fg.inputs[self.n_states + self.n_sequences :] - ] - self.n_non_sequences = len(self.non_sequence_types) + self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]] + self.n_constants = len(self.constant_types) self.update_fg = update_fg.clone(check_integrity=False) self.reverse_fg = ( @@ -194,13 +188,9 @@ def __init__( ) def make_node(self, max_iters, *inputs): - assert len(inputs) == self.n_states + self.n_sequences + self.n_non_sequences - - if self.n_sequences == 0 and max_iters is None: - raise ValueError("Must provide max_iters in Scans without sequences") + assert len(inputs) == self.n_states + self.n_constants - if max_iters is not None: - max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters) + max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters) states = inputs[: self.n_states] states = [ @@ -208,23 +198,10 @@ def make_node(self, max_iters, *inputs): for inp_type, inp in zip(self.state_types, states) ] - sequences = inputs[self.n_states : self.n_states + self.n_sequences] - sequences = [ + constants = inputs[self.n_states :] + constants = [ inp_type.filter_variable(inp) - for inp_type, inp in zip(self.sequence_types, sequences) - ] - if sequences: - leading_dims = [seq.shape[0] for seq in sequences] - shortest_dim = functools.reduce(minimum, leading_dims) - if max_iters is None: - max_iters = shortest_dim - else: - max_iters = minimum(max_iters, shortest_dim) - - non_sequences = inputs[self.n_states + self.n_sequences :] - non_sequences = [ - inp_type.filter_variable(inp) - for inp_type, inp in zip(self.non_sequence_types, non_sequences) + for inp_type, inp in zip(self.constant_types, constants) ] # If there is no loop condition, `max_iters` exclusively defines the number of iterations @@ -249,7 +226,7 @@ def make_node(self, max_iters, *inputs): return Apply( self, - [max_iters, *states, *sequences, *non_sequences], + [max_iters, *states, *constants], [output_type() for output_type in self.state_types + trace_types], ) @@ -299,20 +276,16 @@ def scan_to_loop(fgraph, node): It roughly creates the following computational graph ``` - def scan(fn, initial_states, sequences, constants, max_iters): - - def update_fn(idx, states, traces, sequences, constants, max_iters) - subsequences = [seq[idx] for seq in subsequences] - resume, states = inner_fn(states, subsequences, constants) - for trace, state in zip(traces, states): - trace[idx] = state - return (resume and (idx < max_iters)), idx + 1, states, traces - + def scan(fn, idx, initial_states, constants, max_iters): idx = 0 + states = initial_states traces = [empty(max_iters, *initial_state.shape) for initial_state in initial_states] while True: - resume, idx, states, traces = update_fn(idx, *states, *traces, *sequences, *constants, max_iters) - if not resume: + resume, states, fn(*states, *traces, *constants) + for trace, state in zip(traces, states): + trace[idx] = state + idx += 1 + if not resume or idx >= max_iters: break traces = [trace[: idx] for trace in traces] return states, traces @@ -339,7 +312,6 @@ def update_fn(idx, states, traces, sequences, constants, max_iters) # Inputs to the new Loop max_iters = node.inputs[0] - init_idx = constant(np.array(0, dtype="int64"), name="idx") init_states = node.inputs[1 : 1 + op.n_states] init_traces = [ empty( @@ -348,49 +320,68 @@ def update_fn(idx, states, traces, sequences, constants, max_iters) ) for trace_idx in used_traces_idxs ] - sequences = node.inputs[1 + op.n_states : 1 + op.n_states + op.n_sequences] - non_sequences = node.inputs[1 + op.n_states + op.n_sequences :] + constants = node.inputs[1 + op.n_states :] - new_fg = op.update_fg.clone(check_integrity=False) + update_fg = op.update_fg.clone(check_integrity=False) - # Inner index - inner_prev_idx = init_idx.type() - inner_prev_idx.name = "prev_idx" + # Check if inner_fg computes and index already, otherwise create a new one + has_idx = False + if len(node.inputs) > 1: + try: + outer_inp = node.inputs[1] + outer_is_zero = get_scalar_constant_value(outer_inp) == 0 + except NotScalarConstantError: + pass + else: + if ( + outer_is_zero + and len(update_fg.inputs) > 0 + and len(update_fg.outputs) > 1 + ): + inner_out = update_fg.outputs[1] + if ( + inner_out.owner is not None + and inner_out.owner.op == add + and len(inner_out.owner.inputs) == 2 + ): + left, right = inner_out.owner.inputs + if left is update_fg.inputs[0]: + try: + has_idx = ( + get_scalar_constant_value( + right, only_process_constants=True + ) + == 1 + ) + except NotScalarConstantError: + pass + + if has_idx: + init_idx = outer_inp + inner_idx = inner_out.owner.inputs[0] + inner_next_idx = inner_out + if not has_idx: + init_idx = constant(np.array(0, dtype="int64"), name="idx") + inner_idx = init_idx.type() + inner_idx.name = "idx" + inner_next_idx = inner_idx + 1 + inner_next_idx.name = "next_idx" # Inner traces - inner_prev_states = new_fg.inputs[: op.n_states] - inner_prev_traces = [init_trace.type() for init_trace in init_traces] - for s, t in zip(inner_prev_states, inner_prev_traces): - t.name = "prev_trace" + inner_states = update_fg.inputs[: op.n_states] + inner_traces = [init_trace.type() for init_trace in init_traces] + for s, t in zip(inner_states, inner_traces): + t.name = "trace" if s.name: t.name = "_".join((t.name, s.name)) - inner_non_sequences = new_fg.inputs[op.n_states + op.n_sequences :] - - # Replace inner sub-sequences by sequence[idx] - inner_seqs_news = [] - if op.n_sequences: - inner_subseqs_old = new_fg.inputs[op.n_states : op.n_states + op.n_sequences] - inner_subseqs_new = [] - for sequence in sequences: - inner_seq_new = sequence.type() - inner_seq_new.name = sequence.name or "sequence" - inner_seqs_news.append(inner_seq_new) - inner_subseq_new = inner_seq_new[inner_prev_idx] - inner_subseq_new.name = inner_seq_new.name + "[prev_idx]" - inner_subseqs_new.append(inner_subseq_new) - - # Replace inner_sequence input by sequence[idx] - replacements = tuple(zip(inner_subseqs_old, inner_subseqs_new)) - new_fg.replace_all(replacements, import_missing=True) - - # Inner continue condition and index - inner_continue_cond, *inner_next_states = new_fg.outputs - inner_next_idx = inner_prev_idx + 1 - inner_next_idx.name = "next_idx" + inner_constants = update_fg.inputs[op.n_states :] + + # Inner continue condition + inner_continue_cond, *inner_next_states = update_fg.outputs inner_next_traces = [ - set_subtensor(prev_trace[inner_prev_idx], inner_next_states[trace_idx]) - for trace_idx, prev_trace in zip(used_traces_idxs, inner_prev_traces) + set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx]) + for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces) ] for t in inner_next_traces: t.name = "next_trace" @@ -398,29 +389,34 @@ def update_fn(idx, states, traces, sequences, constants, max_iters) inner_continue_cond = and_(inner_continue_cond, inner_next_idx < inner_max_iters) inner_continue_cond.name = "continue(?)" - new_fg = FunctionGraph( + if not has_idx: + init_states = [init_idx] + init_states + inner_states = [inner_idx] + inner_states + inner_next_states = [inner_next_idx] + inner_next_states + + new_update_fg = FunctionGraph( inputs=[ - inner_prev_idx, - *inner_prev_states, - *inner_prev_traces, - *inner_seqs_news, - *inner_non_sequences, + *inner_states, + *inner_traces, + *inner_constants, inner_max_iters, ], outputs=[ inner_continue_cond, - inner_next_idx, *inner_next_states, *inner_next_traces, ], ) # TODO: Implement Reverse? - loop_op = Loop(update_fg=new_fg) - - final_idx, *new_outs = loop_op( - init_idx, *init_states, *init_traces, *sequences, *non_sequences, max_iters - ) + loop_op = Loop(update_fg=new_update_fg) + + new_outs = loop_op(*init_states, *init_traces, *constants, max_iters) + if has_idx: + # idx was part of the original scan, and therefore has a corresponding trace + final_idx = new_outs[0] + else: + final_idx, *new_outs = new_outs new_states = new_outs[: op.n_states] new_traces = new_outs[op.n_states :] diff --git a/tests/loop/basic.py b/tests/loop/basic.py index 700e389908..9af3cb0aaa 100644 --- a/tests/loop/basic.py +++ b/tests/loop/basic.py @@ -1,8 +1,9 @@ import numpy as np import pytensor +from pytensor import grad from pytensor.loop.basic import filter, map, reduce, scan -from pytensor.tensor import eq, vector, zeros +from pytensor.tensor import arange, eq, vector, zeros def test_scan(): @@ -19,6 +20,35 @@ def test_scan(): ) +def test_scan_taking_grads_non_sequence(): + xs = vector("xs") + ys = xs**2 + + _, [J] = scan( + lambda i, y, xs: grad(y[i], wrt=xs), + sequences=arange(ys.shape[0]), + non_sequences=[ys, xs], + ) + + f = pytensor.function([xs], J) + np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]]) + + +def test_scan_taking_grads_sequence(): + # This is not possible with the old Scan + xs = vector("xs") + ys = xs**2 + + _, [J] = scan( + lambda y, xs: grad(y, wrt=xs), + sequences=[ys], + non_sequences=[xs], + ) + + f = pytensor.function([xs], J) + np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]]) + + def test_map(): xs = vector("xs") ys = map( diff --git a/tests/loop/test_op.py b/tests/loop/test_op.py index 8445baa08b..28393d0d6a 100644 --- a/tests/loop/test_op.py +++ b/tests/loop/test_op.py @@ -41,7 +41,7 @@ def test_fori_scan(): update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2]) n_iters = 10 - y, ys = Scan(n_sequences=0, update_fg=update_fg)(n_iters, x) + y, ys = Scan(update_fg=update_fg)(n_iters, x) fn = function([x], [y, ys]) @@ -69,7 +69,7 @@ def test_fori_scan_shape(): update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2]) n_iters = 10 - _, ys = Scan(n_sequences=0, update_fg=update_fg)(n_iters, x) + _, ys = Scan(update_fg=update_fg)(n_iters, x) fn = function([x], ys.shape, on_unused_input="ignore") nodes = tuple(fn.maker.fgraph.apply_nodes) @@ -84,9 +84,7 @@ def test_while_scan(): update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) max_iters = 1000 - _, y, _, ys = Scan(n_sequences=0, update_fg=update_fg)( - max_iters, np.array(0, dtype="int64"), x - ) + _, y, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x) fn = function([x], [y, ys]) @@ -99,11 +97,10 @@ def test_while_scan(): ) assert len(loop_nodes) == 1 (loop_node,) = loop_nodes - assert len(loop_node.outputs) == 4 + assert len(loop_node.outputs) == 3 assert loop_node.outputs[0].type.shape == () assert loop_node.outputs[1].type.shape == () - assert loop_node.outputs[2].type.shape == () - assert loop_node.outputs[3].type.shape == (1000,) + assert loop_node.outputs[2].type.shape == (1000,) y_eval, ys_eval = fn(0) np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2)) @@ -116,9 +113,7 @@ def test_while_scan_shape(): update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2]) max_iters = 1000 - _, _, _, ys = Scan(n_sequences=0, update_fg=update_fg)( - max_iters, np.array(0, dtype="int64"), x - ) + _, _, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x) fn = function([x], ys.shape) loop_nodes = tuple( @@ -129,18 +124,18 @@ def test_while_scan_shape(): def test_foreach_scan(): - dummy_init = empty(()) - x = scalar("x") + idx = scalar("idx", dtype="int64") + dummy_x0 = empty(()) + xs = vector("xs") const = scalar("const") update_fg = FunctionGraph( - [dummy_init, x, const], [constant(np.array(True)), x * const] + [idx, dummy_x0, xs, const], [constant(np.array(True)), idx + 1, xs[idx] * const] ) - xs = vector("xs") - _, ys = Scan(n_sequences=1, update_fg=update_fg)(None, dummy_init, xs, const) + n_steps = xs.shape[0] + _, _, _, ys = Scan(update_fg=update_fg)(n_steps, 0, dummy_x0, xs, const) fn = pytensor.function([xs, const], ys) - pytensor.dprint(fn, print_type=True) np.testing.assert_almost_equal(fn(np.arange(10), 100), np.arange(10) * 100) @@ -157,9 +152,7 @@ def test_fori_random_scan(): [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]], ) - _, new_rng, ys, rngs = Scan(n_sequences=0, update_fg=update_fg)( - n_iters, dummy_init, rng_shared - ) + _, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared) assert isinstance(rngs.type, NoneTypeT) fn = function([], ys, updates={rng_shared: new_rng})