Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dummy input variables during Scan rewrites #1145

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aesara/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ def make_node(self, *inputs):
):
outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
if not outer_nonseq.type.in_same_class(inner_nonseq.type):
if not inner_nonseq.type.is_super(outer_nonseq.type):
raise ValueError(
(
f"Argument {outer_nonseq} given to the scan node is not"
Expand Down
63 changes: 41 additions & 22 deletions aesara/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
graph_inputs,
io_toposort,
is_in_ancestors,
replace_nominals_with_dummies,
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
Expand Down Expand Up @@ -82,6 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
"""
if not isinstance(node.op, Scan):
return False

op = node.op
op_info = op.info
# We only need to take care of sequences and other arguments
Expand All @@ -92,8 +94,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
st += op_info.n_sit_sot
st += op_info.n_shared_outs

op_ins = op.inner_inputs
op_outs = op.inner_outputs
op_ins, op_outs = replace_nominals_with_dummies(op.inner_inputs, op.inner_outputs)

# Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end.
Expand Down Expand Up @@ -189,6 +190,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
allow_gc=op.allow_gc,
)
nw_outs = nwScan(*nw_outer, return_list=True)

return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else:
return False
Expand All @@ -207,7 +209,9 @@ def push_out_non_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False

node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
node_inputs, node_outputs = replace_nominals_with_dummies(
node.op.inner_inputs, node.op.inner_outputs
)

local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
Expand Down Expand Up @@ -417,7 +421,9 @@ def push_out_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False

node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
node_inputs, node_outputs = replace_nominals_with_dummies(
node.op.inner_inputs, node.op.inner_outputs
)

local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
Expand Down Expand Up @@ -832,9 +838,10 @@ def push_out_add_scan(fgraph, node):

# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info
inner_inputs, inner_outputs = replace_nominals_with_dummies(
op.inner_inputs, op.inner_outputs
)
args = ScanArgs(node.inputs, node.outputs, inner_inputs, inner_outputs, op.info)

clients = {}
local_fgraph_topo = io_toposort(
Expand Down Expand Up @@ -1694,6 +1701,8 @@ def merge(self, nodes):
inner_outs = [[] for nd in nodes]
outer_outs = []

# inner_inputs, inner_outputs = replace_nominals_with_dummies(nd.op.inner_inputs, nd.op.inner_outputs)

def rename(ls, suffix):
for k in ls:
if k.name:
Expand Down Expand Up @@ -1967,11 +1976,16 @@ def scan_merge_inouts(fgraph, node):
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.

inner_inputs, inner_outputs = replace_nominals_with_dummies(
node.op.inner_inputs, node.op.inner_outputs
)

a = ScanArgs(
node.inputs,
node.outputs,
node.op.inner_inputs,
node.op.inner_outputs,
inner_inputs,
inner_outputs,
node.op.info,
)

Expand Down Expand Up @@ -2173,10 +2187,15 @@ def push_out_dot1_scan(fgraph, node):
# Note that this only works when you need X[-1] in the end
# and assumes dimshuffles are applied to vectors before calling dot
op = node.op
sitsot_ins = op.inner_sitsot(op.inner_inputs)
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)

inner_inputs, inner_outputs = replace_nominals_with_dummies(
op.inner_inputs, op.inner_outputs
)

sitsot_ins = op.inner_sitsot(inner_inputs)
sitsot_outs = op.inner_sitsot_outs(inner_outputs)
outer_sitsot = op.outer_sitsot_outs(node.outputs)
seqs = op.inner_seqs(op.inner_inputs)
seqs = op.inner_seqs(inner_inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):

if (
Expand Down Expand Up @@ -2218,23 +2237,23 @@ def push_out_dot1_scan(fgraph, node):
# First let us split all arguments according to their
# corresponding categories

inner_seqs = op.inner_seqs(op.inner_inputs)
inner_seqs = op.inner_seqs(inner_inputs)
outer_seqs = op.outer_seqs(node.inputs)
inner_mitmot = op.inner_mitmot(op.inner_inputs)
inner_mitmot = op.inner_mitmot(inner_inputs)
outer_mitmot = op.outer_mitmot(node.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs)
inner_mitsot = op.inner_mitsot(op.inner_inputs)
inner_mitmot_outs = op.inner_mitmot_outs(inner_outputs)
inner_mitsot = op.inner_mitsot(inner_inputs)
outer_mitsot = op.outer_mitsot(node.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs)
inner_sitsot = op.inner_sitsot(op.inner_inputs)
inner_mitsot_outs = op.inner_mitsot_outs(inner_outputs)
inner_sitsot = op.inner_sitsot(inner_inputs)
outer_sitsot = op.outer_sitsot(node.inputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
inner_sitsot_outs = op.inner_sitsot_outs(inner_outputs)
outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
inner_shared = op.inner_shared(op.inner_inputs)
inner_nitsot_outs = op.inner_nitsot_outs(inner_outputs)
inner_shared = op.inner_shared(inner_inputs)
outer_shared = op.outer_shared(node.inputs)
inner_shared_outs = op.inner_shared_outs(op.inner_outputs)
inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
inner_shared_outs = op.inner_shared_outs(inner_outputs)
inner_non_seqs = op.inner_non_seqs(inner_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)

new_info = dataclasses.replace(
Expand Down