 from collections import OrderedDict
 from copy import copy
 from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
 
 import aesara.tensor as at
-from aesara import function
 from aesara.compile.function.pfunc import rebuild_collect_shared
+from aesara.compile.io import In, Out
 from aesara.compile.mode import optdb
+from aesara.compile.ops import update_placeholder
 from aesara.compile.sharedvalue import SharedVariable
 from aesara.configdefaults import config
 from aesara.gradient import DisconnectedType, Rop, grad
@@ -83,13 +84,26 @@ def local_traverse(out):
 
 def construct_nominal_fgraph(
     inputs: Sequence[Variable], outputs: Sequence[Variable]
-) -> Tuple[
-    FunctionGraph,
-    Sequence[Variable],
-    Dict[Variable, Variable],
-    Dict[Variable, Variable],
-]:
-    """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
+) -> Tuple[FunctionGraph, Sequence[Variable]]:
+    r"""Construct an inner-`FunctionGraph` with ordered nominal inputs.
+
+    .. note::
+
+        Updates (e.g. from `SharedVariable.default_update`) are appended to
+        the resulting `FunctionGraph`'s outputs.
+
+    Parameters
+    ==========
+    inputs
+        A list of inputs.
+    outputs
+        A list of outputs.
+
+    Returns
+    =======
+    The `FunctionGraph` and a list of shared inputs.
+
+    """
     dummy_inputs = []
     for n, inp in enumerate(inputs):
         if (
@@ -105,6 +119,7 @@ def construct_nominal_fgraph(
 
     dummy_shared_inputs = []
     shared_inputs = []
+    default_updates = {}
     for var in graph_inputs(outputs, inputs):
         if isinstance(var, SharedVariable):
             # To correctly support shared variables the inner-graph should
@@ -113,14 +128,18 @@ def construct_nominal_fgraph(
             # That's why we collect the shared variables and replace them
             # with dummies.
             shared_inputs.append(var)
-            dummy_shared_inputs.append(var.type())
+            dummy_var = var.type()
+            dummy_shared_inputs.append(dummy_var)
+
+            if var.default_update:
+                default_updates[dummy_var] = var.default_update
         elif var not in inputs and not isinstance(var, Constant):
             raise MissingInputError(f"OpFromGraph is missing an input: {var}")
 
     replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
 
     new = rebuild_collect_shared(
-        cast(Sequence[Variable], outputs),
+        outputs=cast(Sequence[Variable], outputs + list(default_updates.values())),
         inputs=inputs + shared_inputs,
         replace=replacements,
         copy_inputs_over=False,
@@ -131,13 +150,23 @@ def construct_nominal_fgraph(
         (clone_d, update_d, update_expr, new_shared_inputs),
     ) = new
 
+    local_default_updates = local_outputs[len(outputs) :]
+    update_d.update(
+        {clone_d[k]: v for k, v in zip(default_updates.keys(), local_default_updates)}
+    )
+    update_expr.extend(local_default_updates)
+
     assert len(local_inputs) == len(inputs) + len(shared_inputs)
-    assert len(local_outputs) == len(outputs)
-    assert not update_d
-    assert not update_expr
+    assert len(local_outputs) == len(outputs) + len(default_updates)
     assert not new_shared_inputs
 
-    fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
+    update_mapping = {
+        local_outputs.index(v): local_inputs.index(k) for k, v in update_d.items()
+    }
+
+    fgraph = FunctionGraph(
+        local_inputs, local_outputs, clone=False, update_mapping=update_mapping
+    )
 
     # The inputs need to be `NominalVariable`s so that we can merge
     # inner-graphs
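The `update_mapping` comprehension added in this hunk records, for each update, the position of the update expression in `local_outputs` against the position of the input it updates in `local_inputs`. A toy illustration with hypothetical stand-in values:

```python
# Hypothetical stand-ins: one explicit input plus one shared-variable
# dummy, and one real output plus one appended update output.
local_inputs = ["x", "s_dummy"]
local_outputs = ["y", "s_next"]
update_d = {"s_dummy": "s_next"}

# Output index -> index of the input it updates.
update_mapping = {
    local_outputs.index(v): local_inputs.index(k) for k, v in update_d.items()
}
assert update_mapping == {1: 1}
```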
@@ -153,7 +182,7 @@ def construct_nominal_fgraph(
             fgraph.clients.pop(inp, None)
         fgraph.add_input(nom_inp)
 
-    return fgraph, shared_inputs, update_d, update_expr
+    return fgraph, shared_inputs
 
 
 class OpFromGraph(Op, HasInnerGraph):
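Under the new contract, updates are no longer returned separately from `construct_nominal_fgraph`; they are carried by the `FunctionGraph` itself as extra outputs plus an `update_mapping`. A minimal sketch of the intended post-change behavior (assumed from the changes above, not taken from this commit's tests):

```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.compile.builders import construct_nominal_fgraph

x = at.scalar("x")
s = aesara.shared(np.float64(0.0), name="s")
s.default_update = s + 1.0  # carried into the inner graph

fgraph, shared_inputs = construct_nominal_fgraph([x], [x + s])

# The update expression is appended as a second output, and
# `update_mapping` ties it back to the shared input's position.
assert len(fgraph.outputs) == 2
assert fgraph.update_mapping == {1: 1}
assert shared_inputs == [s]
```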
@@ -316,37 +345,44 @@ def __init__(
         name: Optional[str] = None,
         **kwargs,
     ):
-        """
+        r"""Construct an `OpFromGraph` instance.
+
+        .. note::
+
+            `SharedVariable`\s in `outputs` will have their `SharedVariable.default_update`
+            values altered in order to support in-lining in the presence of updates.
+
         Parameters
         ----------
         inputs
             The inputs to the graph.
         outputs
             The outputs to the graph.
         inline
-            Defaults to ``False``
-
             ``True``: Causes the :class:`Op`'s original graph to be used during
             compilation; the :class:`Op` will not be visible in the compiled
             graph, but rather its internal graph.
 
             ``False``: will use a pre-compiled function inside.
+
+            Defaults to ``False``.
+
         grad_overrides
-            Defaults to ``'default'``.
             This argument is mutually exclusive with ``lop_overrides``.
 
-            ``'default'``: Do not override, use default grad() result
+            ``'default'``: Do not override; use the default :meth:`Op.grad` result.
 
             `OpFromGraph`: Override with another `OpFromGraph`; should
             accept inputs in the same order and types as the ``inputs`` and ``output_grads``
-            arguments as one would specify in :meth:`Op.grad`() method.
+            arguments one would specify in the :meth:`Op.grad` method.
 
             `callable`: Should take two args: ``inputs`` and ``output_grads``.
             Each argument is expected to be a list of :class:`Variable`.
             Must return a list of :class:`Variable`.
-        lop_overrides
+
             Defaults to ``'default'``.
 
+        lop_overrides
             This argument is mutually exclusive with ``grad_overrides``.
 
             These options are similar to the ``grad_overrides`` above, but for
@@ -355,7 +391,7 @@ def __init__(
             ``'default'``: Do not override, use the default :meth:`Op.L_op` result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
-            accept inputs as the same order and types of ``inputs``,
+            accept inputs in the same order and types as the ``inputs``,
             ``outputs`` and ``output_grads`` arguments as one would specify in
             the :meth:`Op.grad` method.
 
@@ -371,11 +407,11 @@ def __init__(
             :class:`Variable`. Each list element corresponds to the gradient of
             a specific input; the length of the list must equal the number of inputs.
 
+            Defaults to ``'default'``.
+
         rop_overrides
             One of ``{'default', OpFromGraph, callable, Variable}``.
 
-            Defaults to ``'default'``.
-
             ``'default'``: Do not override, use the default :meth:`Op.R_op` result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
@@ -397,11 +433,13 @@ def __init__(
             must be equal to number of outputs. connection_pattern If not
             ``None``, this will be used as the connection_pattern for this
             :class:`Op`.
+
+            Defaults to ``'default'``.
+
         name
             A name for debugging purposes.
         kwargs
-            Check :func:`aesara.function` for more arguments, only works when not
-            inline.
+            See :func:`aesara.function`.
         """
 
         if not (isinstance(inputs, list) and isinstance(outputs, list)):
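The ``grad_overrides`` contract documented above can be exercised with a small callable; a hypothetical sketch in which the override simply reproduces the gradient of ``x * x``:

```python
import aesara
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.scalar("x")

def custom_grad(inputs, output_grads):
    # One gradient per input, in the same order as `inputs`.
    (inp,) = inputs
    (g_out,) = output_grads
    return [2 * inp * g_out]

ofg = OpFromGraph([x], [x * x], grad_overrides=custom_grad)
y = ofg(x)
gy = aesara.grad(y, x)  # built from `custom_grad`, not the default L_op
```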
@@ -418,9 +456,24 @@ def __init__(
 
         self.is_inline = inline
 
-        self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
-            inputs, outputs
-        )
+        # These `shared_inputs` are the original variables in `outputs`
+        # (i.e. not clones).
+        self.fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs)
+
+        # We need to hold on to the original variables so that gradients can
+        # be taken wrt. them.  Ideally, we wouldn't hold on to specific
+        # `Variable` references like this outside of the graph, but we're
+        # maintaining support for old functionality right now.
+        self.shared_inputs = []
+        for v in shared_inputs:
+            # This is needed so that `aesara.function` will create an update
+            # output placeholder in the `FunctionGraph` it compiles.  We need
+            # placeholders like this in order to properly inline
+            # `OpFromGraph`s containing updates.
+            # FYI: When the corresponding updates aren't used, they should be
+            # removed at the `aesara.function` level.
+            v.default_update = update_placeholder(v)
+            self.shared_inputs.append(v)
 
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
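A minimal sketch of the user-level behavior this wiring is meant to enable, assuming default updates propagate through :func:`aesara.function` as the comments describe:

```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.scalar("x")
count = aesara.shared(np.float64(0.0), name="count")
count.default_update = count + 1.0  # inner update carried by the Op

ofg = OpFromGraph([x], [x + count])
fn = aesara.function([x], ofg(x))

fn(0.0)
fn(0.0)
print(count.get_value())  # expected: 2.0, one increment per call
```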
@@ -933,7 +986,36 @@ def fn(self):
         if getattr(self, "_fn", None) is not None:
             return self._fn
 
-        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
+        from aesara.compile.function.pfunc import pfunc
+
+        # We don't want calls/evaluations of this `Op` to change the
+        # inner-graph, so we need to clone it
+        fgraph, _ = self.fgraph.clone_get_equiv(copy_inputs=False, copy_orphans=False)
+
+        wrapped_inputs = [In(x, borrow=False) for x in fgraph.inputs]
+        wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs]
+
+        n_inputs = len(fgraph.inputs)
+
+        for out_idx, in_idx in fgraph.update_mapping.items():
+            shared_input = self.shared_inputs[in_idx - n_inputs]
+            in_var = fgraph.inputs[in_idx]
+            updated_wrapped_input = In(
+                variable=in_var,
+                value=shared_input.container,
+                update=fgraph.outputs[out_idx],
+                implicit=True,
+                shared=True,
+            )
+            wrapped_inputs[in_idx] = updated_wrapped_input
+
+        self._fn = pfunc(
+            wrapped_inputs,
+            wrapped_outputs,
+            fgraph=fgraph,
+            no_default_updates=True,
+            **self.kwargs,
+        )
         self._fn.trust_input = True
 
         return self._fn
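The ``in_idx - n_inputs`` expression above relies on the shared-variable dummies occupying the tail of ``fgraph.inputs``: subtracting the full input count yields a negative index that is tail-aligned with ``self.shared_inputs``. A toy check with hypothetical names:

```python
# Three explicit inputs followed by two shared-variable dummies.
fgraph_inputs = ["x0", "x1", "x2", "s0_dummy", "s1_dummy"]
shared_inputs = ["s0", "s1"]

n_inputs = len(fgraph_inputs)  # 5
in_idx = fgraph_inputs.index("s0_dummy")  # 3

# 3 - 5 == -2: the tail-aligned position of "s0" in `shared_inputs`.
assert shared_inputs[in_idx - n_inputs] == "s0"
```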
@@ -944,6 +1026,11 @@ def inner_inputs(self):
 
     @property
     def inner_outputs(self):
+        """Return all the outputs except those used for updates."""
+        n_updates = len(self.fgraph.update_mapping)
+        if n_updates > 0:
+            return self.fgraph.outputs[:-n_updates]
+
         return self.fgraph.outputs
 
     def clone(self):
@@ -952,28 +1039,52 @@ def clone(self):
         return res
 
     def perform(self, node, inputs, outputs):
-        variables = self.fn(*inputs)
-        assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables):
-            output[0] = variable
+        results = self.fn(*inputs)
+        for output, res in zip(outputs, results):
+            output[0] = res
 
 
 @node_rewriter([OpFromGraph])
 def inline_ofg_expansion(fgraph, node):
-    """
-    This optimization expands internal graph of OpFromGraph.
-    Only performed if node.op.is_inline == True
-    Doing so can improve optimization at the cost of compilation speed.
+    """Expand the internal graph of an `OpFromGraph`.
+
+    Only performed if ``node.op.is_inline == True``.
+
     """
     op = node.op
-    if not isinstance(op, OpFromGraph):
-        return False
+
     if not op.is_inline:
         return False
-    return clone_replace(
-        op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
+
+    outputs = clone_replace(
+        op.fgraph.outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
     )
 
+    replacements = {
+        old_var: new_var
+        for old_var, new_var in zip(node.outputs, outputs[: len(op.inner_outputs)])
+    }
+
+    # Add the updates from the `OpFromGraph` into the outer-graph
+    for out_idx, in_idx in op.fgraph.update_mapping.items():
+        shared_input = node.inputs[in_idx]
+        assert isinstance(shared_input, SharedVariable)
+
+        outer_in_idx = fgraph.inputs.index(shared_input)
+
+        # There should be a placeholder output in `fgraph.outputs` that we
+        # can use.  If there isn't, then someone forgot/removed the
+        # `SharedVariable.default_update`s on the inputs to the `OpFromGraph`
+        # (i.e. at the user-level/graph construction-time).
+        outer_out_idx = fgraph.inv_update_mapping[outer_in_idx]
+        update_var = fgraph.outputs[outer_out_idx]
+
+        assert update_var is not shared_input
+
+        replacements[update_var] = outputs[out_idx]
+
+    return replacements
+
 
 # We want to run this before the first merge optimizer
 # and before the first scan optimizer.
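A sketch of what this rewrite enables at the user level, assuming it is registered early as the preceding comment describes (the names and final assertion are illustrative, not from this commit's tests):

```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.scalar("x")
s = aesara.shared(np.float64(0.0), name="s")
s.default_update = s + 1.0

ofg = OpFromGraph([x], [x + s], inline=True)
fn = aesara.function([x], ofg(x))

# After in-lining, no `OpFromGraph` node remains in the compiled graph,
# and the update on `s` still fires on each call.
assert not any(
    isinstance(node.op, OpFromGraph) for node in fn.maker.fgraph.apply_nodes
)
fn(0.0)
print(s.get_value())  # expected: 1.0
```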