From ffd9ff7a73fa499bcfedb9199abfd7c2e86ddd1f Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Thu, 6 Oct 2022 00:33:25 -0500
Subject: [PATCH] Add support for shared inputs in numba_funcify_Scan

---
 aesara/link/numba/dispatch/scan.py | 313 +++++++++++++++++------------
 tests/link/numba/test_scan.py      | 177 +++++++++++++---
 2 files changed, 330 insertions(+), 160 deletions(-)

diff --git a/aesara/link/numba/dispatch/scan.py b/aesara/link/numba/dispatch/scan.py
index 564615b1e3..6180995849 100644
--- a/aesara/link/numba/dispatch/scan.py
+++ b/aesara/link/numba/dispatch/scan.py
@@ -1,4 +1,3 @@
-from itertools import groupby
 from textwrap import dedent, indent
 from typing import Dict, List, Optional, Tuple
 
@@ -14,6 +13,7 @@
 )
 from aesara.link.utils import compile_function_src
 from aesara.scan.op import Scan
+from aesara.tensor.type import TensorType
 
 
 def idx_to_str(
@@ -49,8 +49,6 @@ def range_arr(x):
 def numba_funcify_Scan(op, node, **kwargs):
     scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
 
-    n_seqs = op.info.n_seqs
-
     outer_in_names_to_vars = {
         (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
     }
@@ -60,22 +58,63 @@ def numba_funcify_Scan(op, node, **kwargs):
     outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
     outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
     outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
+    outer_in_shared_names = op.outer_shared(outer_in_names)
+    outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
+
+    # These are all the outer-input names that produce outputs/have output
+    # taps (i.e. they have inner-outputs and corresponding outer-outputs).
+    # Outer-outputs are ordered as follows:
+    # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs
     outer_in_outtap_names = (
         outer_in_mit_mot_names
        + outer_in_mit_sot_names
        + outer_in_sit_sot_names
        + outer_in_nit_sot_names
+        + outer_in_shared_names
     )
 
-    outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
-
-    inner_in_to_index_offset: List[Tuple[str, Optional[int], Optional[int]]] = []
-    allocate_taps_storage: List[str] = []
+    # We create distinct variables for/references to the storage arrays of
+    # each output.
+    outer_in_to_storage_name: Dict[str, str] = {}
+    for outer_in_name in outer_in_mit_mot_names:
+        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitmot_storage"
+
+    for outer_in_name in outer_in_mit_sot_names:
+        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitsot_storage"
+
+    for outer_in_name in outer_in_sit_sot_names:
+        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_sitsot_storage"
+
+    for outer_in_name in outer_in_nit_sot_names:
+        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"
+
+    for outer_in_name in outer_in_shared_names:
+        outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage"
+
+    outer_output_names = list(outer_in_to_storage_name.values())
+    assert len(outer_output_names) == len(node.outputs)
+
+    # Construct the inner-input expressions (e.g. indexed storage expressions)
+    # Inner-inputs are ordered as follows:
+    # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
+    # shared-inputs + non-sequences.
+    inner_in_exprs: List[str] = []
+
+    def add_inner_in_expr(
+        outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str]
+    ):
+        """Construct an inner-input expression."""
+        storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
+        indexed_inner_in_str = (
+            storage_name
+            if tap_offset is None
+            else idx_to_str(storage_name, tap_offset, size=storage_size_var)
+        )
+        inner_in_exprs.append(indexed_inner_in_str)
 
     for outer_in_name in outer_in_seqs_names:
-        # A sequence with multiple taps is provided as multiple modified input
-        # sequences--all sliced so as to keep following the logic of a normal
-        # sequence.
-        inner_in_to_index_offset.append((outer_in_name, 0, None))
+        # These outer-inputs are indexed without offsets or storage wrap-around
+        add_inner_in_expr(outer_in_name, 0, None)
 
     inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict(
         zip(
@@ -89,60 +128,120 @@ def numba_funcify_Scan(op, node, **kwargs):
         zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices)
     )
 
+    # Inner-outputs consist of:
+    # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
+    # shared-outputs [+ while-condition]
     inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
 
-    # Maps storage array names to their tap values (i.e. maximum absolute tap
-    # value) and storage sizes
-    inner_out_name_to_taps_storage: List[Tuple[str, int, Optional[str]]] = []
-    outer_in_to_storage_name: Dict[str, str] = {}
-    outer_in_sot_names = set(
-        outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names
-    )
+    # inner_out_shared_names = op.inner_shared_outs(inner_output_names)
+
+    # The assignment statements that copy inner-outputs into the outer-outputs
+    # storage
+    inner_out_to_outer_in_stmts: List[str] = []
+
+    # Special statements that perform storage truncation for `while`-loops and
+    # rotation for initially truncated storage.
+    output_storage_post_proc_stmts: List[str] = []
+
+    # In truncated storage situations (e.g. created by `save_mem_new_scan`),
+    # the taps and output storage overlap, instead of the standard situation in
+    # which the output storage is large enough to contain both the initial tap
+    # values and all the computed outputs. In this truncated case, we use the
+    # storage array like a circular buffer, and that's why we need to track the
+    # storage size along with the tap length/indexing offset.
+    def add_output_storage_post_proc_stmt(
+        outer_in_name: str, tap_sizes: Tuple[int, ...], storage_size: Optional[str]
+    ):
+
+        if storage_size is None:
+            return
+
+        tap_size = max(tap_sizes)
+
+        if op.info.as_while:
+            # While loops need to truncate the output storage to a length given
+            # by the number of iterations performed.
+            output_storage_post_proc_stmts.append(
+                dedent(
+                    f"""
+                if i + {tap_size} < {storage_size}:
+                    {storage_size} = i + {tap_size}
+                    {outer_in_name} = {outer_in_name}[:{storage_size}]
+                """
+                ).strip()
+            )
+
+        # Rotate the storage so that the last computed value is at the end of
+        # the storage array.
+        # This is needed when the output storage array does not have a length
+        # equal to the number of taps plus `n_steps`.
+        output_storage_post_proc_stmts.append(
+            dedent(
+                f"""
+            {outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
+            if {outer_in_name}_shift > 0:
+                {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
+                {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
+                {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
+            """
+            ).strip()
+        )
+
+    # Special in-loop statements that create (nit-sot) storage arrays after a
+    # single iteration is performed. This is necessary because we don't know
+    # the exact shapes of the storage arrays that need to be allocated until
+    # after the first iteration has been performed.
     inner_out_post_processing_stmts: List[str] = []
+
+    # Storage allocation statements
+    # For output storage allocated/provided by the inputs, these statements
+    # will either construct aliases between the input names and the entries in
+    # `outer_in_to_storage_name` or assign the latter to expressions that
+    # create copies of those storage inputs.
+    # In the nit-sot case, empty dummy arrays are assigned to the storage
+    # variables and updated later by the statements in
+    # `inner_out_post_processing_stmts`.
+    storage_alloc_stmts: List[str] = []
+
     for outer_in_name in outer_in_outtap_names:
         outer_in_var = outer_in_names_to_vars[outer_in_name]
 
-        if outer_in_name in outer_in_sot_names:
-            if outer_in_name in outer_in_mit_mot_names:
-                storage_name = f"{outer_in_name}_mitmot_storage"
-            elif outer_in_name in outer_in_mit_sot_names:
-                storage_name = f"{outer_in_name}_mitsot_storage"
-            else:
-                # Note that the outputs with single, non-`-1` taps are (e.g. `taps
-                # = [-2]`) are classified as mit-sot, so the code for handling
-                # sit-sots remains constant as follows
-                storage_name = f"{outer_in_name}_sitsot_storage"
+        if outer_in_name not in outer_in_nit_sot_names:
 
-            output_idx = len(outer_in_to_storage_name)
-            outer_in_to_storage_name[outer_in_name] = storage_name
+            storage_name = outer_in_to_storage_name[outer_in_name]
 
-            input_taps = inner_in_names_to_input_taps[outer_in_name]
-            tap_storage_size = -min(input_taps)
-            assert tap_storage_size >= 0
+            is_tensor_type = isinstance(outer_in_var.type, TensorType)
+            if is_tensor_type:
+                storage_size_name = f"{outer_in_name}_len"
+                storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
+                input_taps = inner_in_names_to_input_taps[outer_in_name]
+                tap_storage_size = -min(input_taps)
+                assert tap_storage_size >= 0
 
-            storage_size_name = f"{outer_in_name}_len"
+                for in_tap in input_taps:
+                    tap_offset = in_tap + tap_storage_size
+                    assert tap_offset >= 0
+                    add_inner_in_expr(outer_in_name, tap_offset, storage_size_name)
 
-            for in_tap in input_taps:
-                tap_offset = in_tap + tap_storage_size
-                assert tap_offset >= 0
-                # In truncated storage situations (i.e. created by
-                # `save_mem_new_scan`), the taps and output storage overlap,
-                # instead of the standard situation in which the output storage
-                # is large enough to contain both the initial taps values and
-                # the output storage.
- inner_in_to_index_offset.append( - (outer_in_name, tap_offset, storage_size_name) + output_taps = inner_in_names_to_output_taps.get( + outer_in_name, [tap_storage_size] ) + for out_tap in output_taps: + inner_out_to_outer_in_stmts.append( + idx_to_str(storage_name, out_tap, size=storage_size_name) + ) - output_taps = inner_in_names_to_output_taps.get( - outer_in_name, [tap_storage_size] - ) - for out_tap in output_taps: - inner_out_name_to_taps_storage.append( - (storage_name, out_tap, storage_size_name) + add_output_storage_post_proc_stmt( + storage_name, output_taps, storage_size_name ) - if output_idx in node.op.destroy_map: + else: + storage_size_stmt = "" + add_inner_in_expr(outer_in_name, None, None) + inner_out_to_outer_in_stmts.append(storage_name) + + output_idx = outer_output_names.index(storage_name) + if output_idx in node.op.destroy_map or not is_tensor_type: storage_alloc_stmt = f"{storage_name} = {outer_in_name}" else: storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})" @@ -150,14 +249,16 @@ def numba_funcify_Scan(op, node, **kwargs): storage_alloc_stmt = dedent( f""" # {outer_in_var.type} - {storage_size_name} = {outer_in_name}.shape[0] + {storage_size_stmt} {storage_alloc_stmt} """ ).strip() - allocate_taps_storage.append(storage_alloc_stmt) + storage_alloc_stmts.append(storage_alloc_stmt) + + else: + assert outer_in_name in outer_in_nit_sot_names - elif outer_in_name in outer_in_nit_sot_names: # This is a special case in which there are no outer-inputs used # for outer-output storage, so we need to create our own storage # from scratch. @@ -166,21 +267,25 @@ def numba_funcify_Scan(op, node, **kwargs): outer_in_to_storage_name[outer_in_name] = storage_name storage_size_name = f"{outer_in_name}_len" - inner_out_name_to_taps_storage.append((storage_name, 0, storage_size_name)) + + inner_out_to_outer_in_stmts.append( + idx_to_str(storage_name, 0, size=storage_size_name) + ) + add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name) # In case of nit-sots we are provided the length of the array in # the iteration dimension instead of actual arrays, hence we # allocate space for the results accordingly. - curr_nit_sot_position = outer_in_names[1:].index(outer_in_name) - n_seqs - curr_nit_sot = op.inner_outputs[curr_nit_sot_position] - needs_alloc = curr_nit_sot.ndim > 0 + curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name) + curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position] + needs_alloc = curr_nit_sot.type.ndim > 0 storage_shape = create_tuple_string( [storage_size_name] + ["0"] * curr_nit_sot.ndim ) storage_dtype = curr_nit_sot.type.numpy_dtype.name - allocate_taps_storage.append( + storage_alloc_stmts.append( dedent( f""" # {curr_nit_sot.type} @@ -191,99 +296,43 @@ def numba_funcify_Scan(op, node, **kwargs): ) if needs_alloc: - allocate_taps_storage.append(f"{outer_in_name}_ready = False") + storage_alloc_stmts.append(f"{outer_in_name}_ready = False") # In this case, we don't know the shape of the output storage # array until we get some output from the inner-function. 
# With the following we add delayed output storage initialization: - inner_out_name = inner_output_names[curr_nit_sot_position] + inner_out_name = op.inner_nitsot_outs(inner_output_names)[ + curr_nit_sot_position + ] inner_out_post_processing_stmts.append( dedent( f""" if not {outer_in_name}_ready: - {storage_name} = np.empty(({storage_size_name},) + {inner_out_name}.shape, dtype=np.{storage_dtype}) + {storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype}) {outer_in_name}_ready = True """ ).strip() ) - # The non_seqs are passed to the inner function as-is for name in outer_in_non_seqs_names: - inner_in_to_index_offset.append((name, None, None)) - - inner_out_storage_indexed = [ - name if taps is None else idx_to_str(name, taps, size=size) - for (name, taps, size) in inner_out_name_to_taps_storage - ] - - output_storage_post_processing_stmts: List[str] = [] - - for outer_in_name, grp_vals in groupby( - inner_out_name_to_taps_storage, lambda x: x[0] - ): - - _, tap_sizes, storage_sizes = zip(*grp_vals) - - tap_size = max(tap_sizes) - storage_size = storage_sizes[0] - - if op.info.as_while: - # While loops need to truncate the output storage to a length given - # by the number of iterations performed. - output_storage_post_processing_stmts.append( - dedent( - f""" - if i + {tap_size} < {storage_size}: - {storage_size} = i + {tap_size} - {outer_in_name} = {outer_in_name}[:{storage_size}] - """ - ).strip() - ) - - # Rotate the storage so that the last computed value is at the end of - # the storage array. - # This is needed when the output storage array does not have a length - # equal to the number of taps plus `n_steps`. - output_storage_post_processing_stmts.append( - dedent( - f""" - {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) - if {outer_in_name}_shift > 0: - {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] - {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] - {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left)) - """ - ).strip() - ) + add_inner_in_expr(name, None, None) if op.info.as_while: # The inner function will return a boolean as the last value - inner_out_storage_indexed.append("cond") + inner_out_to_outer_in_stmts.append("cond") - output_names = [outer_in_to_storage_name[n] for n in outer_in_outtap_names] + assert len(inner_in_exprs) == len(op.fgraph.inputs) - # Construct the inner-input expressions - inner_inputs: List[str] = [] - for outer_in_name, tap_offset, size in inner_in_to_index_offset: - storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name) - indexed_inner_in_str = ( - idx_to_str(storage_name, tap_offset, size=size) - if tap_offset is not None - else storage_name - ) - # if outer_in_names_to_vars[outer_in_name].type.ndim - 1 <= 0: - # # Convert scalar inner-inputs to Numba scalars - # indexed_inner_in_str = f"to_numba_scalar({indexed_inner_in_str})" - inner_inputs.append(indexed_inner_in_str) - - inner_inputs = create_arg_string(inner_inputs) + inner_in_args = create_arg_string(inner_in_exprs) inner_outputs = create_tuple_string(inner_output_names) - input_storage_block = "\n".join(allocate_taps_storage) - output_storage_post_processing_block = "\n".join( - output_storage_post_processing_stmts - ) + input_storage_block = "\n".join(storage_alloc_stmts) + output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) + 
+    inner_out_to_outer_out_stmts = "\n".join(
+        [f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)]
+    )
+
     scan_op_src = f"""
 def scan({", ".join(outer_in_names)}):
 
 {indent(input_storage_block, " " * 4)}
 
     i = 0
     cond = False
     while i < n_steps and not cond:
-        {inner_outputs} = scan_inner_func({inner_inputs})
+        {inner_outputs} = scan_inner_func({inner_in_args})
 {indent(inner_out_post_processing_block, " " * 8)}
-        {create_tuple_string(inner_out_storage_indexed)} = {inner_outputs}
+{indent(inner_out_to_outer_out_stmts, " " * 8)}
         i += 1
 {indent(output_storage_post_processing_block, " " * 4)}
 
-    return {create_arg_string(output_names)}
+    return {create_arg_string(outer_output_names)}
 """
 
     global_env = {
diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py
index 99444ca787..2095587bfc 100644
--- a/tests/link/numba/test_scan.py
+++ b/tests/link/numba/test_scan.py
@@ -2,15 +2,160 @@
 import pytest
 
 import aesara.tensor as at
-from aesara import config, grad
+from aesara import config, function, grad
 from aesara.compile.mode import Mode, get_mode
 from aesara.graph.fg import FunctionGraph
 from aesara.scan.basic import scan
+from aesara.scan.op import Scan
 from aesara.scan.utils import until
+from aesara.tensor.random.utils import RandomStream
 from tests import unittest_tools as utt
 from tests.link.numba.test_basic import compare_numba_and_py
 
 
+@pytest.mark.parametrize(
+    "fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
+    [
+        # sequences
+        (
+            lambda a_t: 2 * a_t,
+            [at.dvector("a")],
+            [{}],
+            [],
+            None,
+            [np.arange(10)],
+            None,
+            lambda op: op.info.n_seqs > 0,
+        ),
+        # nit-sot
+        (
+            lambda: at.as_tensor(2.0),
+            [],
+            [{}],
+            [],
+            3,
+            [],
+            None,
+            lambda op: op.info.n_nit_sot > 0,
+        ),
+        # nit-sot, non_seq
+        (
+            lambda c: at.as_tensor(2.0) * c,
+            [],
+            [{}],
+            [at.dscalar("c")],
+            3,
+            [1.0],
+            None,
+            lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
+        ),
+        # sit-sot
+        (
+            lambda a_tm1: 2 * a_tm1,
+            [],
+            [{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}],
+            [],
+            3,
+            [],
+            None,
+            lambda op: op.info.n_sit_sot > 0,
+        ),
+        # sit-sot, while
+        (
+            lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
+            [],
+            [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
+            [],
+            3,
+            [],
+            None,
+            lambda op: op.info.n_sit_sot > 0,
+        ),
+        # nit-sot, shared input/output
+        (
+            lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
+                0, 1, name="a"
+            ),
+            [],
+            [{}],
+            [],
+            3,
+            [],
+            [np.array([-1.63408257, 0.18046406, 2.43265803])],
+            lambda op: op.info.n_shared_outs > 0,
+        ),
+        # mit-sot (a single non-`-1` tap, e.g. `taps = [-2]`, is classified as
+        # mit-sot even though it behaves like a sit-sot)
+        (
+            lambda a_tm1: 2 * a_tm1,
+            [],
+            [{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}],
+            [],
+            6,
+            [],
+            None,
+            lambda op: op.info.n_mit_sot > 0,
+        ),
+        # mit-sot
+        (
+            lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
+            [],
+            [
+                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
+                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
+            ],
+            [],
+            10,
+            [],
+            None,
+            lambda op: op.info.n_mit_sot > 0,
+        ),
+    ],
+)
+def test_xit_xot_types(
+    fn,
+    sequences,
+    outputs_info,
+    non_sequences,
+    n_steps,
+    input_vals,
+    output_vals,
+    op_check,
+):
+    """Test basic xit-xot configurations."""
+    res, updates = scan(
+        fn,
+        sequences=sequences,
+        outputs_info=outputs_info,
+        non_sequences=non_sequences,
+        n_steps=n_steps,
+        strict=True,
+        mode=Mode(linker="py", optimizer=None),
+    )
+
+    if not isinstance(res, list):
+        res = [res]
+
+    # Get rid of any `Subtensor` indexing on the `Scan` outputs
+    res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]
+
+    scan_op = res[0].owner.op
+    assert isinstance(scan_op, Scan)
+
+    _ = op_check(scan_op)
+
+    if output_vals is None:
+        compare_numba_and_py(
+            (sequences + non_sequences, res), input_vals, updates=updates
+        )
+    else:
+        numba_mode = get_mode("NUMBA")
+        numba_fn = function(
+            sequences + non_sequences, res, mode=numba_mode, updates=updates
+        )
+        res_val = numba_fn(*input_vals)
+        assert np.allclose(res_val, output_vals)
+
+
 def test_scan_multiple_output():
     """Test a scan implementation of a SEIR model.
 
@@ -202,34 +347,10 @@ def power_step(prior_result, x):
     compare_numba_and_py(out_fg, test_input_vals)
 
 
-def test_scan_save_mem_basic():
+@pytest.mark.parametrize("n_steps_val", [1, 5])
+def test_scan_save_mem_basic(n_steps_val):
     """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""
 
-    k = at.iscalar("k")
-    A = at.dvector("A")
-
-    result, _ = scan(
-        fn=lambda prior_result, A: prior_result * A,
-        outputs_info=at.ones_like(A),
-        non_sequences=A,
-        n_steps=k,
-    )
-
-    numba_mode = get_mode("NUMBA")  # .including("scan_save_mem")
-    py_mode = Mode("py").including("scan_save_mem")
-
-    out_fg = FunctionGraph([A, k], [result])
-
-    test_input_vals = (np.arange(10, dtype=np.int32), 2)
-    compare_numba_and_py(
-        out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
-    )
-    test_input_vals = (np.arange(10, dtype=np.int32), 4)
-    compare_numba_and_py(
-        out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
-    )
-
-@pytest.mark.parametrize("n_steps_val", [1, 5])
-def test_scan_save_mem_2(n_steps_val):
     def f_pow2(x_tm2, x_tm1):
         return 2 * x_tm1 + x_tm2
 
@@ -245,7 +366,7 @@ def f_pow2(x_tm2, x_tm1):
 
     state_val = np.array([1.0, 2.0])
 
-    numba_mode = get_mode("NUMBA")  # .including("scan_save_mem")
+    numba_mode = get_mode("NUMBA").including("scan_save_mem")
     py_mode = Mode("py").including("scan_save_mem")
 
     out_fg = FunctionGraph([init_x, n_steps], [output])
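
Two notes on the storage handling in the scan.py changes above, each with a small standalone sketch. First, the rotation emitted by `add_output_storage_post_proc_stmt`: when the output storage is shorter than `n_steps` plus the number of taps (e.g. after the `save_mem_new_scan` rewrite), the generated loop writes through the buffer modulo its length, and the post-processing rotates the buffer so the most recent value ends up last. The following is a hand-written approximation of the emitted statements, not the generated code itself; all names are illustrative:

import numpy as np

def rotate_storage(storage, i, tap_size):
    # Mirrors the emitted post-processing: after `i` iterations, move the
    # most recently written entry to the end of the wrapped-around buffer.
    storage_size = storage.shape[0]
    shift = (i + tap_size) % storage_size
    if shift > 0:
        left = storage[:shift]
        right = storage[shift:]
        storage = np.concatenate((right, left))
    return storage

# A sit-sot (x_t = 2 * x_{t-1}) with one tap and a truncated, two-element
# buffer, run for four steps:
buf = np.empty(2)
buf[0] = 1.0  # initial tap value
n_steps = 4
for i in range(n_steps):
    x_tm1 = buf[i % 2]            # read the tap, wrapping around
    buf[(i + 1) % 2] = 2 * x_tm1  # write the new value, wrapping around
print(rotate_storage(buf, n_steps, 1))  # [ 8. 16.]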
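Second, the delayed nit-sot allocation handled by `inner_out_post_processing_stmts`: storage is first allocated as a dummy array with zero-length trailing dimensions, then replaced once the first inner-function output reveals the true shape. A sketch of that pattern, again with illustrative names (the generated code inlines this logic rather than calling a helper):

import numpy as np

n_steps = 3
storage_len = n_steps
# Dummy allocation: the trailing dimensions are unknown, so use length 0.
storage = np.empty((storage_len,) + (0,) * 1, dtype=np.float64)
storage_ready = False

for i in range(n_steps):
    inner_out = np.full(4, float(i))  # stand-in for the inner function's output
    if not storage_ready:
        # The first iteration reveals the true shape; reallocate.
        storage = np.empty((storage_len,) + np.shape(inner_out), dtype=np.float64)
        storage_ready = True
    storage[i] = inner_out

print(storage.shape)  # (3, 4)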
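Finally, an end-to-end illustration of what this patch enables, adapted from the new `RandomStream` test case above: a `Scan` whose inner function updates a shared RNG state can now be compiled by the Numba backend. The exact draws depend on the seed and RNG constructor; for this configuration the test above expects approximately [-1.63408257, 0.18046406, 2.43265803].

import numpy as np

import aesara
from aesara.compile.mode import get_mode
from aesara.scan.basic import scan
from aesara.tensor.random.utils import RandomStream

srng = RandomStream(seed=1930, rng_ctor=np.random.RandomState)

# The RNG state enters the inner graph as a shared input and leaves it as a
# shared output (an update), which `numba_funcify_Scan` can now handle.
res, updates = scan(lambda: srng.normal(0, 1), n_steps=3)

fn = aesara.function([], res, updates=updates, mode=get_mode("NUMBA"))
print(fn())  # three standard-normal draws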