30
30
from aesara .configdefaults import config
31
31
from aesara .graph .basic import Variable , graph_inputs , io_toposort
32
32
from aesara .graph .destroyhandler import DestroyHandler
33
- from aesara .graph .features import BadOptimization
33
+ from aesara .graph .features import AlreadyThere , BadOptimization , Feature
34
34
from aesara .graph .fg import InconsistencyError
35
35
from aesara .graph .op import COp , HasInnerGraph , Op
36
36
from aesara .graph .utils import MethodNotDefined
@@ -433,7 +433,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
433
433
equivalence_tracker = _VariableEquivalenceTracker ()
434
434
fgraph , updates = std_fgraph (input_specs , output_specs , accept_inplace )
435
435
fgraph .attach_feature (equivalence_tracker )
436
- return fgraph , updates , equivalence_tracker
436
+ return fgraph , updates
437
437
438
438
439
439
class DataDestroyed :
@@ -1181,96 +1181,84 @@ def __ne__(self, other):
1181
1181
return not (self == other )
1182
1182
1183
1183
1184
- class _VariableEquivalenceTracker :
1184
+ class _VariableEquivalenceTracker ( Feature ) :
1185
1185
"""
1186
1186
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
1187
1187
tries to detect problems.
1188
1188
1189
1189
"""
1190
1190
1191
- fgraph = None
1192
- """WRITEME"""
1193
-
1194
- equiv = None
1195
- """WRITEME"""
1196
-
1197
- active_nodes = None
1198
- """WRITEME"""
1199
-
1200
- inactive_nodes = None
1201
- """WRITEME"""
1202
-
1203
- all_variables_ever = None
1204
- """WRITEME"""
1205
-
1206
- reasons = None
1207
- """WRITEME"""
1208
-
1209
- replaced_by = None
1210
- """WRITEME"""
1191
+ def on_attach (self , fgraph ):
1211
1192
1212
- event_list = None
1213
- """WRITEME"""
1193
+ if hasattr ( fgraph , "_eq_tracker_equiv" ):
1194
+ raise AlreadyThere ()
1214
1195
1215
- def __init__ (self ):
1216
- self .fgraph = None
1196
+ fgraph ._eq_tracker_equiv = {}
1197
+ fgraph ._eq_tracker_active_nodes = set ()
1198
+ fgraph ._eq_tracker_inactive_nodes = set ()
1199
+ fgraph ._eq_tracker_fgraph = fgraph
1200
+ fgraph ._eq_tracker_all_variables_ever = []
1201
+ fgraph ._eq_tracker_reasons = {}
1202
+ fgraph ._eq_tracker_replaced_by = {}
1203
+ fgraph ._eq_tracker_event_list = []
1217
1204
1218
- def on_attach (self , fgraph ):
1219
- assert self .fgraph is None
1220
- self .equiv = {}
1221
- self .active_nodes = set ()
1222
- self .inactive_nodes = set ()
1223
- self .fgraph = fgraph
1224
- self .all_variables_ever = []
1225
- self .reasons = {}
1226
- self .replaced_by = {}
1227
- self .event_list = []
1228
1205
for node in fgraph .toposort ():
1229
- self .on_import (fgraph , node , "on_attach " )
1206
+ self .on_import (fgraph , node , "var_equiv_on_attach " )
1230
1207
1231
1208
def on_detach (self , fgraph ):
1232
- assert fgraph is self .fgraph
1233
1209
self .fgraph = None
1210
+ del fgraph ._eq_tracker_equiv
1211
+ del fgraph ._eq_tracker_active_nodes
1212
+ del fgraph ._eq_tracker_inactive_nodes
1213
+ del fgraph ._eq_tracker_fgraph
1214
+ del fgraph ._eq_tracker_all_variables_ever
1215
+ del fgraph ._eq_tracker_reasons
1216
+ del fgraph ._eq_tracker_replaced_by
1217
+ del fgraph ._eq_tracker_event_list
1234
1218
1235
1219
def on_prune (self , fgraph , node , reason ):
1236
- self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = str (reason )))
1237
- assert node in self .active_nodes
1238
- assert node not in self .inactive_nodes
1239
- self .active_nodes .remove (node )
1240
- self .inactive_nodes .add (node )
1220
+ fgraph ._eq_tracker_event_list .append (
1221
+ _FunctionGraphEvent ("prune" , node , reason = str (reason ))
1222
+ )
1223
+ assert node in fgraph ._eq_tracker_active_nodes
1224
+ assert node not in fgraph ._eq_tracker_inactive_nodes
1225
+ fgraph ._eq_tracker_active_nodes .remove (node )
1226
+ fgraph ._eq_tracker_inactive_nodes .add (node )
1241
1227
1242
1228
def on_import (self , fgraph , node , reason ):
1243
- self .event_list .append (_FunctionGraphEvent ("import" , node , reason = str (reason )))
1229
+ fgraph ._eq_tracker_event_list .append (
1230
+ _FunctionGraphEvent ("import" , node , reason = str (reason ))
1231
+ )
1244
1232
1245
- assert node not in self . active_nodes
1246
- self . active_nodes .add (node )
1233
+ assert node not in fgraph . _eq_tracker_active_nodes
1234
+ fgraph . _eq_tracker_active_nodes .add (node )
1247
1235
1248
- if node in self . inactive_nodes :
1249
- self . inactive_nodes .remove (node )
1236
+ if node in fgraph . _eq_tracker_inactive_nodes :
1237
+ fgraph . _eq_tracker_inactive_nodes .remove (node )
1250
1238
for r in node .outputs :
1251
- assert r in self . equiv
1239
+ assert r in fgraph . _eq_tracker_equiv
1252
1240
else :
1253
1241
for r in node .outputs :
1254
- assert r not in self . equiv
1255
- self . equiv [r ] = {r }
1256
- self . all_variables_ever .append (r )
1257
- self . reasons .setdefault (r , [])
1258
- self . replaced_by .setdefault (r , [])
1242
+ assert r not in fgraph . _eq_tracker_equiv
1243
+ fgraph . _eq_tracker_equiv [r ] = {r }
1244
+ fgraph . _eq_tracker_all_variables_ever .append (r )
1245
+ fgraph . _eq_tracker_reasons .setdefault (r , [])
1246
+ fgraph . _eq_tracker_replaced_by .setdefault (r , [])
1259
1247
for r in node .inputs :
1260
- self . reasons .setdefault (r , [])
1261
- self . replaced_by .setdefault (r , [])
1248
+ fgraph . _eq_tracker_reasons .setdefault (r , [])
1249
+ fgraph . _eq_tracker_replaced_by .setdefault (r , [])
1262
1250
1263
1251
def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
1264
1252
reason = str (reason )
1265
- self . event_list .append (
1253
+ fgraph . _eq_tracker_event_list .append (
1266
1254
_FunctionGraphEvent ("change" , node , reason = reason , idx = i )
1267
1255
)
1268
1256
1269
- self . reasons .setdefault (new_r , [])
1270
- self . replaced_by .setdefault (new_r , [])
1257
+ fgraph . _eq_tracker_reasons .setdefault (new_r , [])
1258
+ fgraph . _eq_tracker_replaced_by .setdefault (new_r , [])
1271
1259
1272
1260
append_reason = True
1273
- for tup in self . reasons [new_r ]:
1261
+ for tup in fgraph . _eq_tracker_reasons [new_r ]:
1274
1262
if tup [0 ] == reason and tup [1 ] is r :
1275
1263
append_reason = False
1276
1264
@@ -1279,7 +1267,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1279
1267
# optimizations will change the graph
1280
1268
done = dict ()
1281
1269
used_ids = dict ()
1282
- self . reasons [new_r ].append (
1270
+ fgraph . _eq_tracker_reasons [new_r ].append (
1283
1271
(
1284
1272
reason ,
1285
1273
r ,
@@ -1303,19 +1291,19 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1303
1291
).getvalue (),
1304
1292
)
1305
1293
)
1306
- self . replaced_by [r ].append ((reason , new_r ))
1294
+ fgraph . _eq_tracker_replaced_by [r ].append ((reason , new_r ))
1307
1295
1308
- if r in self . equiv :
1309
- r_set = self . equiv [r ]
1296
+ if r in fgraph . _eq_tracker_equiv :
1297
+ r_set = fgraph . _eq_tracker_equiv [r ]
1310
1298
else :
1311
- r_set = self . equiv .setdefault (r , {r })
1312
- self . all_variables_ever .append (r )
1299
+ r_set = fgraph . _eq_tracker_equiv .setdefault (r , {r })
1300
+ fgraph . _eq_tracker_all_variables_ever .append (r )
1313
1301
1314
- if new_r in self . equiv :
1315
- new_r_set = self . equiv [new_r ]
1302
+ if new_r in fgraph . _eq_tracker_equiv :
1303
+ new_r_set = fgraph . _eq_tracker_equiv [new_r ]
1316
1304
else :
1317
- new_r_set = self . equiv .setdefault (new_r , {new_r })
1318
- self . all_variables_ever .append (new_r )
1305
+ new_r_set = fgraph . _eq_tracker_equiv .setdefault (new_r , {new_r })
1306
+ fgraph . _eq_tracker_all_variables_ever .append (new_r )
1319
1307
1320
1308
assert new_r in new_r_set
1321
1309
assert r in r_set
@@ -1324,17 +1312,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1324
1312
# transfer all the elements of the old one to the new one
1325
1313
r_set .update (new_r_set )
1326
1314
for like_new_r in new_r_set :
1327
- self . equiv [like_new_r ] = r_set
1315
+ fgraph . _eq_tracker_equiv [like_new_r ] = r_set
1328
1316
assert like_new_r in r_set
1329
1317
1330
- assert self .equiv [r ] is r_set
1331
- assert self .equiv [new_r ] is r_set
1332
-
1333
- def printstuff (self ):
1334
- for key in self .equiv :
1335
- print (key )
1336
- for e in self .equiv [key ]:
1337
- print (" " , e )
1318
+ assert fgraph ._eq_tracker_equiv [r ] is r_set
1319
+ assert fgraph ._eq_tracker_equiv [new_r ] is r_set
1338
1320
1339
1321
1340
1322
# List of default version of make thunk.
@@ -1390,9 +1372,7 @@ def make_all(
1390
1372
# Compute a topological ordering that IGNORES the destroy_map
1391
1373
# of destructive Ops. This will be OK, because every thunk is
1392
1374
# evaluated on a copy of its input.
1393
- fgraph_equiv = fgraph .equivalence_tracker
1394
- order_outputs = copy .copy (fgraph_equiv .all_variables_ever )
1395
- del fgraph_equiv
1375
+ order_outputs = copy .copy (fgraph ._eq_tracker_all_variables_ever )
1396
1376
order_outputs .reverse ()
1397
1377
order = io_toposort (fgraph .inputs , order_outputs )
1398
1378
@@ -1625,7 +1605,7 @@ def f():
1625
1605
# insert a given apply node. If that is not True,
1626
1606
# we would need to loop over all node outputs,
1627
1607
# But this make the output uglier.
1628
- reason = fgraph .equivalence_tracker . reasons [node .outputs [0 ]]
1608
+ reason = fgraph ._eq_tracker_reasons [node .outputs [0 ]]
1629
1609
if not reason :
1630
1610
raise
1631
1611
opt = str (reason [0 ][0 ])
@@ -1738,7 +1718,7 @@ def f():
1738
1718
# insert a given apply node. If that is not True,
1739
1719
# we would need to loop over all node outputs,
1740
1720
# But this make the output uglier.
1741
- reason = fgraph .equivalence_tracker . reasons [node .outputs [0 ]]
1721
+ reason = fgraph ._eq_tracker_reasons [node .outputs [0 ]]
1742
1722
if not reason :
1743
1723
raise
1744
1724
opt = str (reason [0 ][0 ])
@@ -1865,9 +1845,7 @@ def thunk():
1865
1845
# But it is very slow and it is not sure it will help.
1866
1846
gc .collect ()
1867
1847
1868
- _find_bad_optimizations (
1869
- order , fgraph .equivalence_tracker .reasons , r_vals
1870
- )
1848
+ _find_bad_optimizations (order , fgraph ._eq_tracker_reasons , r_vals )
1871
1849
1872
1850
#####
1873
1851
# Postcondition: the input and output variables are
@@ -2058,10 +2036,9 @@ def __init__(
2058
2036
2059
2037
# make the fgraph
2060
2038
for i in range (mode .stability_patience ):
2061
- fgraph , additional_outputs , equivalence_tracker = _optcheck_fgraph (
2039
+ fgraph , additional_outputs = _optcheck_fgraph (
2062
2040
inputs , outputs , accept_inplace
2063
2041
)
2064
- fgraph .equivalence_tracker = equivalence_tracker
2065
2042
2066
2043
with config .change_flags (compute_test_value = config .compute_test_value_opt ):
2067
2044
optimizer (fgraph )
@@ -2073,8 +2050,8 @@ def __init__(
2073
2050
if i == 0 :
2074
2051
fgraph0 = fgraph
2075
2052
else :
2076
- li = fgraph .equivalence_tracker . event_list
2077
- l0 = fgraph0 .equivalence_tracker . event_list
2053
+ li = fgraph ._eq_tracker_event_list
2054
+ l0 = fgraph0 ._eq_tracker_event_list
2078
2055
if li != l0 :
2079
2056
infolog = StringIO ()
2080
2057
print ("Optimization process is unstable..." , file = infolog )
0 commit comments