diff --git a/aesara/scan/op.py b/aesara/scan/op.py index 05e6e7dc59..e2174ace62 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -1210,7 +1210,7 @@ def make_node(self, *inputs): ): outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq) new_inputs.append(outer_nonseq) - if not outer_nonseq.type.in_same_class(inner_nonseq.type): + if not inner_nonseq.type.is_super(outer_nonseq.type): raise ValueError( ( f"Argument {outer_nonseq} given to the scan node is not" diff --git a/aesara/scan/rewriting.py b/aesara/scan/rewriting.py index f63db7b74c..276a343ba3 100644 --- a/aesara/scan/rewriting.py +++ b/aesara/scan/rewriting.py @@ -23,6 +23,7 @@ graph_inputs, io_toposort, is_in_ancestors, + replace_nominals_with_dummies, ) from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.features import ReplaceValidate @@ -82,6 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): """ if not isinstance(node.op, Scan): return False + op = node.op op_info = op.info # We only need to take care of sequences and other arguments @@ -92,8 +94,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): st += op_info.n_sit_sot st += op_info.n_shared_outs - op_ins = op.inner_inputs - op_outs = op.inner_outputs + op_ins, op_outs = replace_nominals_with_dummies(op.inner_inputs, op.inner_outputs) # Corresponds to the initial states, which should stay untouched. # We put those variables aside, and put them back at the end. @@ -189,6 +190,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): allow_gc=op.allow_gc, ) nw_outs = nwScan(*nw_outer, return_list=True) + return dict([("remove", [node])] + list(zip(node.outputs, nw_outs))) else: return False @@ -207,7 +209,9 @@ def push_out_non_seq_scan(fgraph, node): if not isinstance(node.op, Scan): return False - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + node_inputs, node_outputs = replace_nominals_with_dummies( + node.op.inner_inputs, node.op.inner_outputs + ) local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_outs_set = set(node_outputs) @@ -417,7 +421,9 @@ def push_out_seq_scan(fgraph, node): if not isinstance(node.op, Scan): return False - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + node_inputs, node_outputs = replace_nominals_with_dummies( + node.op.inner_inputs, node.op.inner_outputs + ) local_fgraph_topo = io_toposort(node_inputs, node_outputs) local_fgraph_outs_set = set(node_outputs) @@ -832,9 +838,10 @@ def push_out_add_scan(fgraph, node): # Use `ScanArgs` to parse the inputs and outputs of scan for ease of # use - args = ScanArgs( - node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info + inner_inputs, inner_outputs = replace_nominals_with_dummies( + op.inner_inputs, op.inner_outputs ) + args = ScanArgs(node.inputs, node.outputs, inner_inputs, inner_outputs, op.info) clients = {} local_fgraph_topo = io_toposort( @@ -1694,6 +1701,8 @@ def merge(self, nodes): inner_outs = [[] for nd in nodes] outer_outs = [] + # inner_inputs, inner_outputs = replace_nominals_with_dummies(nd.op.inner_inputs, nd.op.inner_outputs) + def rename(ls, suffix): for k in ls: if k.name: @@ -1967,11 +1976,16 @@ def scan_merge_inouts(fgraph, node): # Do a first pass to merge identical external inputs. # Equivalent inputs will be stored in inp_equiv, then a new # scan node created without duplicates. + + inner_inputs, inner_outputs = replace_nominals_with_dummies( + node.op.inner_inputs, node.op.inner_outputs + ) + a = ScanArgs( node.inputs, node.outputs, - node.op.inner_inputs, - node.op.inner_outputs, + inner_inputs, + inner_outputs, node.op.info, ) @@ -2173,10 +2187,15 @@ def push_out_dot1_scan(fgraph, node): # Note that this works when only you need X[-1] in the end # and assumes dimshuffle are applied to vectors before calling dot op = node.op - sitsot_ins = op.inner_sitsot(op.inner_inputs) - sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) + + inner_inputs, inner_outputs = replace_nominals_with_dummies( + op.inner_inputs, op.inner_outputs + ) + + sitsot_ins = op.inner_sitsot(inner_inputs) + sitsot_outs = op.inner_sitsot_outs(inner_outputs) outer_sitsot = op.outer_sitsot_outs(node.outputs) - seqs = op.inner_seqs(op.inner_inputs) + seqs = op.inner_seqs(inner_inputs) for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot): if ( @@ -2218,23 +2237,23 @@ def push_out_dot1_scan(fgraph, node): # First let us split all arguments according to their # corresponding categories - inner_seqs = op.inner_seqs(op.inner_inputs) + inner_seqs = op.inner_seqs(inner_inputs) outer_seqs = op.outer_seqs(node.inputs) - inner_mitmot = op.inner_mitmot(op.inner_inputs) + inner_mitmot = op.inner_mitmot(inner_inputs) outer_mitmot = op.outer_mitmot(node.inputs) - inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs) - inner_mitsot = op.inner_mitsot(op.inner_inputs) + inner_mitmot_outs = op.inner_mitmot_outs(inner_outputs) + inner_mitsot = op.inner_mitsot(inner_inputs) outer_mitsot = op.outer_mitsot(node.inputs) - inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs) - inner_sitsot = op.inner_sitsot(op.inner_inputs) + inner_mitsot_outs = op.inner_mitsot_outs(inner_outputs) + inner_sitsot = op.inner_sitsot(inner_inputs) outer_sitsot = op.outer_sitsot(node.inputs) - inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) + inner_sitsot_outs = op.inner_sitsot_outs(inner_outputs) outer_nitsot = op.outer_nitsot(node.inputs) - inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs) - inner_shared = op.inner_shared(op.inner_inputs) + inner_nitsot_outs = op.inner_nitsot_outs(inner_outputs) + inner_shared = op.inner_shared(inner_inputs) outer_shared = op.outer_shared(node.inputs) - inner_shared_outs = op.inner_shared_outs(op.inner_outputs) - inner_non_seqs = op.inner_non_seqs(op.inner_inputs) + inner_shared_outs = op.inner_shared_outs(inner_outputs) + inner_non_seqs = op.inner_non_seqs(inner_inputs) outer_non_seqs = op.outer_non_seqs(node.inputs) new_info = dataclasses.replace(