30
30
from aesara .configdefaults import config
31
31
from aesara .graph .basic import Variable , io_toposort
32
32
from aesara .graph .destroyhandler import DestroyHandler
33
- from aesara .graph .features import BadOptimization
33
+ from aesara .graph .features import AlreadyThere , BadOptimization , Feature
34
34
from aesara .graph .op import COp , HasInnerGraph , Op
35
35
from aesara .graph .utils import InconsistencyError , MethodNotDefined
36
36
from aesara .link .basic import Container , LocalLinker
@@ -432,7 +432,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
432
432
equivalence_tracker = _VariableEquivalenceTracker ()
433
433
fgraph , updates = std_fgraph (input_specs , output_specs , accept_inplace )
434
434
fgraph .attach_feature (equivalence_tracker )
435
- return fgraph , updates , equivalence_tracker
435
+ return fgraph , updates
436
436
437
437
438
438
class DataDestroyed :
@@ -1178,96 +1178,84 @@ def __ne__(self, other):
1178
1178
return not (self == other )
1179
1179
1180
1180
1181
- class _VariableEquivalenceTracker :
1181
+ class _VariableEquivalenceTracker ( Feature ) :
1182
1182
"""
1183
1183
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
1184
1184
tries to detect problems.
1185
1185
1186
1186
"""
1187
1187
1188
- fgraph = None
1189
- """WRITEME"""
1190
-
1191
- equiv = None
1192
- """WRITEME"""
1193
-
1194
- active_nodes = None
1195
- """WRITEME"""
1196
-
1197
- inactive_nodes = None
1198
- """WRITEME"""
1199
-
1200
- all_variables_ever = None
1201
- """WRITEME"""
1202
-
1203
- reasons = None
1204
- """WRITEME"""
1205
-
1206
- replaced_by = None
1207
- """WRITEME"""
1188
+ def on_attach (self , fgraph ):
1208
1189
1209
- event_list = None
1210
- """WRITEME"""
1190
+ if hasattr ( fgraph , "_eq_tracker_equiv" ):
1191
+ raise AlreadyThere ()
1211
1192
1212
- def __init__ (self ):
1213
- self .fgraph = None
1193
+ fgraph ._eq_tracker_equiv = {}
1194
+ fgraph ._eq_tracker_active_nodes = set ()
1195
+ fgraph ._eq_tracker_inactive_nodes = set ()
1196
+ fgraph ._eq_tracker_fgraph = fgraph
1197
+ fgraph ._eq_tracker_all_variables_ever = []
1198
+ fgraph ._eq_tracker_reasons = {}
1199
+ fgraph ._eq_tracker_replaced_by = {}
1200
+ fgraph ._eq_tracker_event_list = []
1214
1201
1215
- def on_attach (self , fgraph ):
1216
- assert self .fgraph is None
1217
- self .equiv = {}
1218
- self .active_nodes = set ()
1219
- self .inactive_nodes = set ()
1220
- self .fgraph = fgraph
1221
- self .all_variables_ever = []
1222
- self .reasons = {}
1223
- self .replaced_by = {}
1224
- self .event_list = []
1225
1202
for node in fgraph .toposort ():
1226
- self .on_import (fgraph , node , "on_attach " )
1203
+ self .on_import (fgraph , node , "var_equiv_on_attach " )
1227
1204
1228
1205
def on_detach (self , fgraph ):
1229
- assert fgraph is self .fgraph
1230
1206
self .fgraph = None
1207
+ del fgraph ._eq_tracker_equiv
1208
+ del fgraph ._eq_tracker_active_nodes
1209
+ del fgraph ._eq_tracker_inactive_nodes
1210
+ del fgraph ._eq_tracker_fgraph
1211
+ del fgraph ._eq_tracker_all_variables_ever
1212
+ del fgraph ._eq_tracker_reasons
1213
+ del fgraph ._eq_tracker_replaced_by
1214
+ del fgraph ._eq_tracker_event_list
1231
1215
1232
1216
def on_prune (self , fgraph , node , reason ):
1233
- self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = str (reason )))
1234
- assert node in self .active_nodes
1235
- assert node not in self .inactive_nodes
1236
- self .active_nodes .remove (node )
1237
- self .inactive_nodes .add (node )
1217
+ fgraph ._eq_tracker_event_list .append (
1218
+ _FunctionGraphEvent ("prune" , node , reason = str (reason ))
1219
+ )
1220
+ assert node in fgraph ._eq_tracker_active_nodes
1221
+ assert node not in fgraph ._eq_tracker_inactive_nodes
1222
+ fgraph ._eq_tracker_active_nodes .remove (node )
1223
+ fgraph ._eq_tracker_inactive_nodes .add (node )
1238
1224
1239
1225
def on_import (self , fgraph , node , reason ):
1240
- self .event_list .append (_FunctionGraphEvent ("import" , node , reason = str (reason )))
1226
+ fgraph ._eq_tracker_event_list .append (
1227
+ _FunctionGraphEvent ("import" , node , reason = str (reason ))
1228
+ )
1241
1229
1242
- assert node not in self . active_nodes
1243
- self . active_nodes .add (node )
1230
+ assert node not in fgraph . _eq_tracker_active_nodes
1231
+ fgraph . _eq_tracker_active_nodes .add (node )
1244
1232
1245
- if node in self . inactive_nodes :
1246
- self . inactive_nodes .remove (node )
1233
+ if node in fgraph . _eq_tracker_inactive_nodes :
1234
+ fgraph . _eq_tracker_inactive_nodes .remove (node )
1247
1235
for r in node .outputs :
1248
- assert r in self . equiv
1236
+ assert r in fgraph . _eq_tracker_equiv
1249
1237
else :
1250
1238
for r in node .outputs :
1251
- assert r not in self . equiv
1252
- self . equiv [r ] = {r }
1253
- self . all_variables_ever .append (r )
1254
- self . reasons .setdefault (r , [])
1255
- self . replaced_by .setdefault (r , [])
1239
+ assert r not in fgraph . _eq_tracker_equiv
1240
+ fgraph . _eq_tracker_equiv [r ] = {r }
1241
+ fgraph . _eq_tracker_all_variables_ever .append (r )
1242
+ fgraph . _eq_tracker_reasons .setdefault (r , [])
1243
+ fgraph . _eq_tracker_replaced_by .setdefault (r , [])
1256
1244
for r in node .inputs :
1257
- self . reasons .setdefault (r , [])
1258
- self . replaced_by .setdefault (r , [])
1245
+ fgraph . _eq_tracker_reasons .setdefault (r , [])
1246
+ fgraph . _eq_tracker_replaced_by .setdefault (r , [])
1259
1247
1260
1248
def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
1261
1249
reason = str (reason )
1262
- self . event_list .append (
1250
+ fgraph . _eq_tracker_event_list .append (
1263
1251
_FunctionGraphEvent ("change" , node , reason = reason , idx = i )
1264
1252
)
1265
1253
1266
- self . reasons .setdefault (new_r , [])
1267
- self . replaced_by .setdefault (new_r , [])
1254
+ fgraph . _eq_tracker_reasons .setdefault (new_r , [])
1255
+ fgraph . _eq_tracker_replaced_by .setdefault (new_r , [])
1268
1256
1269
1257
append_reason = True
1270
- for tup in self . reasons [new_r ]:
1258
+ for tup in fgraph . _eq_tracker_reasons [new_r ]:
1271
1259
if tup [0 ] == reason and tup [1 ] is r :
1272
1260
append_reason = False
1273
1261
@@ -1276,7 +1264,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1276
1264
# optimizations will change the graph
1277
1265
done = dict ()
1278
1266
used_ids = dict ()
1279
- self . reasons [new_r ].append (
1267
+ fgraph . _eq_tracker_reasons [new_r ].append (
1280
1268
(
1281
1269
reason ,
1282
1270
r ,
@@ -1300,19 +1288,19 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1300
1288
).getvalue (),
1301
1289
)
1302
1290
)
1303
- self . replaced_by [r ].append ((reason , new_r ))
1291
+ fgraph . _eq_tracker_replaced_by [r ].append ((reason , new_r ))
1304
1292
1305
- if r in self . equiv :
1306
- r_set = self . equiv [r ]
1293
+ if r in fgraph . _eq_tracker_equiv :
1294
+ r_set = fgraph . _eq_tracker_equiv [r ]
1307
1295
else :
1308
- r_set = self . equiv .setdefault (r , {r })
1309
- self . all_variables_ever .append (r )
1296
+ r_set = fgraph . _eq_tracker_equiv .setdefault (r , {r })
1297
+ fgraph . _eq_tracker_all_variables_ever .append (r )
1310
1298
1311
- if new_r in self . equiv :
1312
- new_r_set = self . equiv [new_r ]
1299
+ if new_r in fgraph . _eq_tracker_equiv :
1300
+ new_r_set = fgraph . _eq_tracker_equiv [new_r ]
1313
1301
else :
1314
- new_r_set = self . equiv .setdefault (new_r , {new_r })
1315
- self . all_variables_ever .append (new_r )
1302
+ new_r_set = fgraph . _eq_tracker_equiv .setdefault (new_r , {new_r })
1303
+ fgraph . _eq_tracker_all_variables_ever .append (new_r )
1316
1304
1317
1305
assert new_r in new_r_set
1318
1306
assert r in r_set
@@ -1321,17 +1309,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1321
1309
# transfer all the elements of the old one to the new one
1322
1310
r_set .update (new_r_set )
1323
1311
for like_new_r in new_r_set :
1324
- self . equiv [like_new_r ] = r_set
1312
+ fgraph . _eq_tracker_equiv [like_new_r ] = r_set
1325
1313
assert like_new_r in r_set
1326
1314
1327
- assert self .equiv [r ] is r_set
1328
- assert self .equiv [new_r ] is r_set
1329
-
1330
- def printstuff (self ):
1331
- for key in self .equiv :
1332
- print (key )
1333
- for e in self .equiv [key ]:
1334
- print (" " , e )
1315
+ assert fgraph ._eq_tracker_equiv [r ] is r_set
1316
+ assert fgraph ._eq_tracker_equiv [new_r ] is r_set
1335
1317
1336
1318
1337
1319
# List of default version of make thunk.
@@ -1387,9 +1369,7 @@ def make_all(
1387
1369
# Compute a topological ordering that IGNORES the destroy_map
1388
1370
# of destructive Ops. This will be OK, because every thunk is
1389
1371
# evaluated on a copy of its input.
1390
- fgraph_equiv = fgraph .equivalence_tracker
1391
- order_outputs = copy .copy (fgraph_equiv .all_variables_ever )
1392
- del fgraph_equiv
1372
+ order_outputs = copy .copy (fgraph ._eq_tracker_all_variables_ever )
1393
1373
order_outputs .reverse ()
1394
1374
order = io_toposort (fgraph .inputs , order_outputs )
1395
1375
@@ -1622,7 +1602,7 @@ def f():
1622
1602
# insert a given apply node. If that is not True,
1623
1603
# we would need to loop over all node outputs,
1624
1604
# But this make the output uglier.
1625
- reason = fgraph .equivalence_tracker . reasons [node .outputs [0 ]]
1605
+ reason = fgraph ._eq_tracker_reasons [node .outputs [0 ]]
1626
1606
if not reason :
1627
1607
raise
1628
1608
opt = str (reason [0 ][0 ])
@@ -1735,7 +1715,7 @@ def f():
1735
1715
# insert a given apply node. If that is not True,
1736
1716
# we would need to loop over all node outputs,
1737
1717
# But this make the output uglier.
1738
- reason = fgraph .equivalence_tracker . reasons [node .outputs [0 ]]
1718
+ reason = fgraph ._eq_tracker_reasons [node .outputs [0 ]]
1739
1719
if not reason :
1740
1720
raise
1741
1721
opt = str (reason [0 ][0 ])
@@ -1862,9 +1842,7 @@ def thunk():
1862
1842
# But it is very slow and it is not sure it will help.
1863
1843
gc .collect ()
1864
1844
1865
- _find_bad_optimizations (
1866
- order , fgraph .equivalence_tracker .reasons , r_vals
1867
- )
1845
+ _find_bad_optimizations (order , fgraph ._eq_tracker_reasons , r_vals )
1868
1846
1869
1847
#####
1870
1848
# Postcondition: the input and output variables are
@@ -2045,10 +2023,9 @@ def __init__(
2045
2023
2046
2024
# make the fgraph
2047
2025
for i in range (mode .stability_patience ):
2048
- fgraph , additional_outputs , equivalence_tracker = _optcheck_fgraph (
2026
+ fgraph , additional_outputs = _optcheck_fgraph (
2049
2027
inputs , outputs , accept_inplace
2050
2028
)
2051
- fgraph .equivalence_tracker = equivalence_tracker
2052
2029
2053
2030
with config .change_flags (compute_test_value = config .compute_test_value_opt ):
2054
2031
optimizer (fgraph )
@@ -2060,8 +2037,8 @@ def __init__(
2060
2037
if i == 0 :
2061
2038
fgraph0 = fgraph
2062
2039
else :
2063
- li = fgraph .equivalence_tracker . event_list
2064
- l0 = fgraph0 .equivalence_tracker . event_list
2040
+ li = fgraph ._eq_tracker_event_list
2041
+ l0 = fgraph0 ._eq_tracker_event_list
2065
2042
if li != l0 :
2066
2043
infolog = StringIO ()
2067
2044
print ("Optimization process is unstable..." , file = infolog )
0 commit comments