Skip to content

Commit d0f42fe

Browse files
Add support for default updates in OpFromGraph
1 parent 48409dd commit d0f42fe

File tree

4 files changed

+291
-57
lines changed

aesara/compile/builders.py

+155-44
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from collections import OrderedDict
33
from copy import copy
44
from functools import partial
5-
from typing import Dict, List, Optional, Sequence, Tuple, cast
5+
from typing import List, Optional, Sequence, Tuple, cast
66

77
import aesara.tensor as at
8-
from aesara import function
98
from aesara.compile.function.pfunc import rebuild_collect_shared
9+
from aesara.compile.io import In, Out
1010
from aesara.compile.mode import optdb
11+
from aesara.compile.ops import update_placeholder
1112
from aesara.compile.sharedvalue import SharedVariable
1213
from aesara.configdefaults import config
1314
from aesara.gradient import DisconnectedType, Rop, grad
@@ -83,13 +84,26 @@ def local_traverse(out):
8384

8485
def construct_nominal_fgraph(
8586
inputs: Sequence[Variable], outputs: Sequence[Variable]
86-
) -> Tuple[
87-
FunctionGraph,
88-
Sequence[Variable],
89-
Dict[Variable, Variable],
90-
Dict[Variable, Variable],
91-
]:
92-
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
87+
) -> Tuple[FunctionGraph, Sequence[Variable],]:
88+
r"""Construct an inner-`FunctionGraph` with ordered nominal inputs.
89+
90+
.. note::
91+
92+
Updates (e.g. from `SharedVariable.default_update`) are appended to the resulting
93+
`FunctionGraph`'s outputs.
94+
95+
Parameters
96+
==========
97+
inputs
98+
A list of inputs.
99+
outputs
100+
A list of outputs.
101+
102+
Returns
103+
=======
104+
The `FunctionGraph` and a list of shared inputs.
105+
106+
"""
93107
dummy_inputs = []
94108
for n, inp in enumerate(inputs):
95109
if (
@@ -105,6 +119,7 @@ def construct_nominal_fgraph(
105119

106120
dummy_shared_inputs = []
107121
shared_inputs = []
122+
default_updates = {}
108123
for var in graph_inputs(outputs, inputs):
109124
if isinstance(var, SharedVariable):
110125
# To correctly support shared variables the inner-graph should
@@ -113,14 +128,18 @@ def construct_nominal_fgraph(
113128
# That's why we collect the shared variables and replace them
114129
# with dummies.
115130
shared_inputs.append(var)
116-
dummy_shared_inputs.append(var.type())
131+
dummy_var = var.type()
132+
dummy_shared_inputs.append(dummy_var)
133+
134+
if var.default_update:
135+
default_updates[dummy_var] = var.default_update
117136
elif var not in inputs and not isinstance(var, Constant):
118137
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
119138

120139
replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
121140

122141
new = rebuild_collect_shared(
123-
cast(Sequence[Variable], outputs),
142+
outputs=cast(Sequence[Variable], outputs + list(default_updates.values())),
124143
inputs=inputs + shared_inputs,
125144
replace=replacements,
126145
copy_inputs_over=False,
@@ -131,13 +150,23 @@ def construct_nominal_fgraph(
131150
(clone_d, update_d, update_expr, new_shared_inputs),
132151
) = new
133152

153+
local_default_updates = local_outputs[len(outputs) :]
154+
update_d.update(
155+
{clone_d[k]: v for k, v in zip(default_updates.keys(), local_default_updates)}
156+
)
157+
update_expr.extend(local_default_updates)
158+
134159
assert len(local_inputs) == len(inputs) + len(shared_inputs)
135-
assert len(local_outputs) == len(outputs)
136-
assert not update_d
137-
assert not update_expr
160+
assert len(local_outputs) == len(outputs) + len(default_updates)
138161
assert not new_shared_inputs
139162

140-
fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
163+
update_mapping = {
164+
local_outputs.index(v): local_inputs.index(k) for k, v in update_d.items()
165+
}
166+
167+
fgraph = FunctionGraph(
168+
local_inputs, local_outputs, clone=False, update_mapping=update_mapping
169+
)
141170

142171
# The inputs need to be `NominalVariable`s so that we can merge
143172
# inner-graphs
@@ -153,7 +182,7 @@ def construct_nominal_fgraph(
153182
fgraph.clients.pop(inp, None)
154183
fgraph.add_input(nom_inp)
155184

156-
return fgraph, shared_inputs, update_d, update_expr
185+
return fgraph, shared_inputs
157186

158187

159188
class OpFromGraph(Op, HasInnerGraph):
@@ -316,37 +345,44 @@ def __init__(
316345
name: Optional[str] = None,
317346
**kwargs,
318347
):
319-
"""
348+
r"""Construct an `OpFromGraph` instance.
349+
350+
.. note::
351+
352+
`SharedVariable`\s in `outputs` will have their `SharedVariable.default_update` values
353+
altered in order to support in-lining in the presence of updates.
354+
320355
Parameters
321356
----------
322357
inputs
323358
The inputs to the graph.
324359
outputs
325360
The outputs to the graph.
326361
inline
327-
Defaults to ``False``
328-
329362
``True`` : Cause the :class:`Op`'s original graph being used during
330363
compilation, the :class:`Op` will not be visible in the compiled
331364
graph but rather its internal graph.
332365
333366
``False`` : will use a pre-compiled function inside.
367+
368+
Defaults to ``False``.
369+
334370
grad_overrides
335-
Defaults to ``'default'``.
336371
This argument is mutually exclusive with ``lop_overrides``.
337372
338-
``'default'`` : Do not override, use default grad() result
373+
``'default'`` : Do not override, use default :meth:`Op.grad` result
339374
340375
`OpFromGraph`: Override with another `OpFromGraph`, should
341376
accept inputs as the same order and types of ``inputs`` and ``output_grads``
342-
arguments as one would specify in :meth:`Op.grad`() method.
377+
arguments as one would specify in :meth:`Op.grad` method.
343378
344379
`callable`: Should take two args: ``inputs`` and ``output_grads``.
345380
Each argument is expected to be a list of :class:`Variable `.
346381
Must return list of :class:`Variable `.
347-
lop_overrides
382+
348383
Defaults to ``'default'``.
349384
385+
lop_overrides
350386
This argument is mutually exclusive with ``grad_overrides``.
351387
352388
These options are similar to the ``grad_overrides`` above, but for
@@ -355,7 +391,7 @@ def __init__(
355391
``'default'``: Do not override, use the default :meth:`Op.L_op` result
356392
357393
`OpFromGraph`: Override with another `OpFromGraph`, should
358-
accept inputs as the same order and types of ``inputs``,
394+
accept inputs in the same order and types as `inputs`,
359395
``outputs`` and ``output_grads`` arguments as one would specify in
360396
:meth:`Op.grad` method.
361397
@@ -371,11 +407,11 @@ def __init__(
371407
:class:`Variable`. Each list element corresponds to gradient of
372408
a specific input, length of list must be equal to number of inputs.
373409
410+
Defaults to ``'default'``.
411+
374412
rop_overrides
375413
One of ``{'default', OpFromGraph, callable, Variable}``.
376414
377-
Defaults to ``'default'``.
378-
379415
``'default'``: Do not override, use the default :meth:`Op.R_op` result
380416
381417
`OpFromGraph`: Override with another `OpFromGraph`, should
@@ -397,11 +433,13 @@ def __init__(
397433
must be equal to number of outputs. connection_pattern If not
398434
``None``, this will be used as the connection_pattern for this
399435
:class:`Op`.
436+
437+
Defaults to ``'default'``.
438+
400439
name
401440
A name for debugging purposes.
402441
kwargs
403-
Check :func:`aesara.function` for more arguments, only works when not
404-
inline.
442+
See :func:`aesara.function`.
405443
"""
406444

407445
if not (isinstance(inputs, list) and isinstance(outputs, list)):
@@ -418,9 +456,24 @@ def __init__(
418456

419457
self.is_inline = inline
420458

421-
self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
422-
inputs, outputs
423-
)
459+
# These `shared_inputs` are the original variables in `outputs`
460+
# (i.e. not clones).
461+
self.fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs)
462+
463+
# We need to hold on to the original variables so that gradients can be
464+
# taken wrt. them. Ideally, we wouldn't hold on to specific `Variable`
465+
# references like this outside of graph, but we're maintaining support
466+
# for old functionality right now.
467+
self.shared_inputs = []
468+
for v in shared_inputs:
469+
# This is needed so that `aesara.function` will create an update
470+
# output placeholder in the `FunctionGraph` it compiles. We need
471+
# placeholders like this in order to properly inline `OpFromGraph`s
472+
# containing updates.
473+
# FYI: When the corresponding updates aren't used, they should be
474+
# removed at the `aesara.function` level.
475+
v.default_update = update_placeholder(v)
476+
self.shared_inputs.append(v)
424477

425478
self.kwargs = kwargs
426479
self.input_types = [inp.type for inp in inputs]
@@ -933,7 +986,36 @@ def fn(self):
933986
if getattr(self, "_fn", None) is not None:
934987
return self._fn
935988

936-
self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
989+
from aesara.compile.function.pfunc import pfunc
990+
991+
# We don't want calls/evaluations of this `Op` to change
992+
# the inner-graph, so we need to clone it
993+
fgraph, _ = self.fgraph.clone_get_equiv(copy_inputs=False, copy_orphans=False)
994+
995+
wrapped_inputs = [In(x, borrow=False) for x in fgraph.inputs]
996+
wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs]
997+
998+
n_inputs = len(fgraph.inputs)
999+
1000+
for out_idx, in_idx in fgraph.update_mapping.items():
1001+
shared_input = self.shared_inputs[in_idx - n_inputs]
1002+
in_var = fgraph.inputs[in_idx]
1003+
updated_wrapped_input = In(
1004+
variable=in_var,
1005+
value=shared_input.container,
1006+
update=fgraph.outputs[out_idx],
1007+
implicit=True,
1008+
shared=True,
1009+
)
1010+
wrapped_inputs[in_idx] = updated_wrapped_input
1011+
1012+
self._fn = pfunc(
1013+
wrapped_inputs,
1014+
wrapped_outputs,
1015+
fgraph=fgraph,
1016+
no_default_updates=True,
1017+
**self.kwargs,
1018+
)
9371019
self._fn.trust_input = True
9381020

9391021
return self._fn
@@ -944,6 +1026,11 @@ def inner_inputs(self):
9441026

9451027
@property
9461028
def inner_outputs(self):
1029+
"""Return all the outputs except those used for updates."""
1030+
n_updates = len(self.fgraph.update_mapping)
1031+
if n_updates > 0:
1032+
return self.fgraph.outputs[:-n_updates]
1033+
9471034
return self.fgraph.outputs
9481035

9491036
def clone(self):
@@ -952,28 +1039,52 @@ def clone(self):
9521039
return res
9531040

9541041
def perform(self, node, inputs, outputs):
955-
variables = self.fn(*inputs)
956-
assert len(variables) == len(outputs)
957-
for output, variable in zip(outputs, variables):
958-
output[0] = variable
1042+
results = self.fn(*inputs)
1043+
for output, res in zip(outputs, results):
1044+
output[0] = res
9591045

9601046

9611047
@node_rewriter([OpFromGraph])
9621048
def inline_ofg_expansion(fgraph, node):
963-
"""
964-
This optimization expands internal graph of OpFromGraph.
965-
Only performed if node.op.is_inline == True
966-
Doing so can improve optimization at the cost of compilation speed.
1049+
"""Expand the internal graph of an `OpFromGraph`.
1050+
1051+
Only performed if ``node.op.is_inline == True``.
1052+
9671053
"""
9681054
op = node.op
969-
if not isinstance(op, OpFromGraph):
970-
return False
1055+
9711056
if not op.is_inline:
9721057
return False
973-
return clone_replace(
974-
op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
1058+
1059+
outputs = clone_replace(
1060+
op.fgraph.outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
9751061
)
9761062

1063+
replacements = {
1064+
old_var: new_var
1065+
for old_var, new_var in zip(node.outputs, outputs[: len(op.inner_outputs)])
1066+
}
1067+
1068+
# Add the updates from `OpFromGraph` into the outer-graph
1069+
for out_idx, in_idx in op.fgraph.update_mapping.items():
1070+
shared_input = node.inputs[in_idx]
1071+
assert isinstance(shared_input, SharedVariable)
1072+
1073+
outer_in_idx = fgraph.inputs.index(shared_input)
1074+
1075+
# There should be a placeholder output in `fgraph.outputs` that we can
1076+
# use. If there isn't, then someone forgot/removed the
1077+
# `SharedVariable.default_update`s on the inputs to the `OpFromGraph`
1078+
# (i.e. at the user-level/graph construction-time).
1079+
outer_out_idx = fgraph.inv_update_mapping[outer_in_idx]
1080+
update_var = fgraph.outputs[outer_out_idx]
1081+
1082+
assert update_var is not shared_input
1083+
1084+
replacements[update_var] = outputs[out_idx]
1085+
1086+
return replacements
1087+
9771088

9781089
# We want to run this before the first merge optimizer
9791090
# and before the first scan optimizer.

aesara/scan/op.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,9 @@ def __init__(
750750
If ``True``, all the shared variables used in the inner-graph must be provided.
751751
752752
"""
753-
self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs)
753+
self.fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs)
754+
755+
assert not self.fgraph.update_mapping
754756

755757
# The shared variables should have been removed, so, if there are
756758
# any, it's because the user didn't specify an input.

0 commit comments

Comments (0)