Generalize the inner-FunctionGraph construction process
brandonwillard committed Nov 20, 2022
1 parent 5044067 commit 14c394d
Showing 3 changed files with 88 additions and 99 deletions.
142 changes: 78 additions & 64 deletions aesara/compile/builders.py
@@ -2,7 +2,7 @@
from collections import OrderedDict
from copy import copy
from functools import partial
from typing import List, Optional, Sequence, cast
from typing import Dict, List, Optional, Sequence, Tuple, cast

import aesara.tensor as at
from aesara import function
@@ -81,6 +81,81 @@ def local_traverse(out):
return ret


def construct_nominal_fgraph(
inputs: Sequence[Variable], outputs: Sequence[Variable]
) -> Tuple[
FunctionGraph,
Sequence[Variable],
Dict[Variable, Variable],
Dict[Variable, Variable],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)

dummy_inputs.append(inp.type())

dummy_shared_inputs = []
shared_inputs = []
for var in graph_inputs(outputs, inputs):
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
(
local_inputs,
local_outputs,
(clone_d, update_d, update_expr, new_shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not new_shared_inputs

fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
nominal_local_inputs = tuple(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)

fgraph.replace_all(zip(local_inputs, nominal_local_inputs))

for i, inp in enumerate(fgraph.inputs):
nom_inp = nominal_local_inputs[i]
fgraph.inputs[i] = nom_inp
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)

return fgraph, shared_inputs, update_d, update_expr


class OpFromGraph(Op, HasInnerGraph):
r"""
This creates an `Op` from inputs and outputs lists of variables.
@@ -338,76 +413,15 @@ def __init__(
f"Inputs and outputs must be Variable instances; got {out}"
)

dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)

dummy_inputs.append(inp.type())

if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not supported")

self.is_inline = inline

dummy_shared_inputs = []
self.shared_inputs = []
for var in graph_inputs(outputs, inputs):
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
self.shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

replacements = dict(
zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inputs, outputs
)

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + self.shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
(
local_inputs,
local_outputs,
(clone_d, update_d, update_expr, shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not shared_inputs

self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
nominal_local_inputs = tuple(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)

self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))

for i, inp in enumerate(self.fgraph.inputs):
nom_inp = nominal_local_inputs[i]
self.fgraph.inputs[i] = nom_inp
self.fgraph.clients.pop(inp, None)
self.fgraph.add_input(nom_inp)

self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
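For reference, a minimal sketch (not part of this commit; the variables are hypothetical) of how the new `construct_nominal_fgraph` helper can be called and what it returns:

import aesara.tensor as at
from aesara.compile.builders import construct_nominal_fgraph

x = at.vector("x")
y = at.vector("y")

# The helper returns the inner FunctionGraph plus the collected shared inputs
# and the (empty) update mappings produced by `rebuild_collect_shared`.
fgraph, shared_inputs, update_d, update_expr = construct_nominal_fgraph([x, y], [x + y])

assert len(fgraph.inputs) == 2  # both inputs are now NominalVariables
assert not shared_inputs  # no SharedVariables appear in this graph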
41 changes: 10 additions & 31 deletions aesara/scan/op.py
@@ -55,8 +55,7 @@

import aesara
from aesara import tensor as at
from aesara.compile import SharedVariable
from aesara.compile.builders import infer_shape
from aesara.compile.builders import construct_nominal_fgraph, infer_shape
from aesara.compile.function.pfunc import pfunc
from aesara.compile.io import In, Out
from aesara.compile.mode import Mode, get_default_mode, get_mode
@@ -65,17 +64,13 @@
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
from aesara.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
clone_replace,
equal_computations,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import InconsistencyError, MissingInputError
from aesara.link.c.basic import CLinker
@@ -755,22 +750,12 @@ def __init__(
If ``True``, all the shared variables used in the inner-graph must be provided.
"""
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs)

input_replacements = []
for n, v in enumerate(inputs):
if not isinstance(v, (SharedVariable, Constant)):
input_replacements.append((v, NominalVariable(n, v.type)))

assert not isinstance(v, NominalVariable)

outputs = clone_replace(outputs, replace=input_replacements)

if input_replacements:
_, inputs_ = zip(*input_replacements)
inputs = list(inputs_)
else:
inputs = []
# The shared variables should have been removed, so, if there are
# any, it's because the user didn't specify an input.
if shared_inputs:
raise MissingInputError(f"Scan is missing inputs: {shared_inputs}")

self.info = info
self.truncate_gradient = truncate_gradient
@@ -782,7 +767,7 @@ def __init__(
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = self.name + " sub profile"
message = f"{self.name} sub profile"
else:
message = "Scan sub profile"

@@ -805,7 +790,7 @@ def tensorConstructor(shape, dtype):
while idx < info.n_mit_mot_outs:
# Note that for mit_mot there are several output slices per
# output sequence
o = outputs[idx]
o = self.fgraph.outputs[idx]
self.output_types.append(
# TODO: What can we actually say about the shape of this
# added dimension?
@@ -818,15 +803,15 @@ def tensorConstructor(shape, dtype):
# mit_sot / sit_sot / nit_sot
end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot

for o in outputs[idx:end]:
for o in self.fgraph.outputs[idx:end]:
self.output_types.append(
# TODO: What can we actually say about the shape of this
# added dimension?
typeConstructor((None,) + o.type.shape, o.type.dtype)
)

# shared outputs + possibly the ending condition
for o in outputs[end:]:
for o in self.fgraph.outputs[end:]:
self.output_types.append(o.type)

if info.as_while:
@@ -862,19 +847,13 @@ def tensorConstructor(shape, dtype):
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs

self.fgraph = FunctionGraph(inputs, outputs, clone=False)

_ = self.prepare_fgraph(self.fgraph)

if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError(
"Inner-graphs must not contain in-place operations."
)

# Do the missing inputs check here to have the error early.
for var in graph_inputs(self.inner_outputs, self.inner_inputs):
if var not in self.inner_inputs and not isinstance(var, Constant):
raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
self._cmodule_key = CLinker().cmodule_key_variables(
self.inner_inputs, self.inner_outputs, []
)
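The reason the inner graph is built on `NominalVariable` inputs is that structurally identical inner graphs become directly comparable and mergeable. A small sketch of that property (hypothetical names, not part of the commit):

import aesara.tensor as at
from aesara.compile.builders import construct_nominal_fgraph
from aesara.graph.basic import equal_computations

a = at.vector("a")
b = at.vector("b")

# Two independently built inner graphs of the same computation end up using
# the same NominalVariable inputs, so their outputs compare as equal.
fg1, *_ = construct_nominal_fgraph([a], [2 * a])
fg2, *_ = construct_nominal_fgraph([b], [2 * b])

assert equal_computations(fg1.outputs, fg2.outputs)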
4 changes: 0 additions & 4 deletions tests/scan/test_basic.py
@@ -586,10 +586,6 @@ def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W):
assert np.allclose(aesara_values, v_out)

def test_oinp_iinp_iout_oout_mappings(self):
"""
Test the mapping produces by
ScanOp.get_oinp_iinp_iout_oout_mappings()
"""

rng = RandomStream(123)
