Skip to content

Commit c53e48e

Browse files
Make _VariableEquivalenceTracker a proper Feature and less stateful
1 parent ea8b04e commit c53e48e

File tree

1 file changed

+69
-92
lines changed

1 file changed

+69
-92
lines changed

aesara/compile/debugmode.py

+69-92
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from aesara.configdefaults import config
3131
from aesara.graph.basic import Variable, io_toposort
3232
from aesara.graph.destroyhandler import DestroyHandler
33-
from aesara.graph.features import BadOptimization
33+
from aesara.graph.features import AlreadyThere, BadOptimization, Feature
3434
from aesara.graph.op import COp, HasInnerGraph, Op
3535
from aesara.graph.utils import InconsistencyError, MethodNotDefined
3636
from aesara.link.basic import Container, LocalLinker
@@ -432,7 +432,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
432432
equivalence_tracker = _VariableEquivalenceTracker()
433433
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace)
434434
fgraph.attach_feature(equivalence_tracker)
435-
return fgraph, updates, equivalence_tracker
435+
return fgraph, updates
436436

437437

438438
class DataDestroyed:
@@ -1178,96 +1178,84 @@ def __ne__(self, other):
11781178
return not (self == other)
11791179

11801180

1181-
class _VariableEquivalenceTracker:
1181+
class _VariableEquivalenceTracker(Feature):
11821182
"""
11831183
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
11841184
tries to detect problems.
11851185
11861186
"""
11871187

1188-
fgraph = None
1189-
"""WRITEME"""
1190-
1191-
equiv = None
1192-
"""WRITEME"""
1193-
1194-
active_nodes = None
1195-
"""WRITEME"""
1196-
1197-
inactive_nodes = None
1198-
"""WRITEME"""
1199-
1200-
all_variables_ever = None
1201-
"""WRITEME"""
1202-
1203-
reasons = None
1204-
"""WRITEME"""
1205-
1206-
replaced_by = None
1207-
"""WRITEME"""
1188+
def on_attach(self, fgraph):
12081189

1209-
event_list = None
1210-
"""WRITEME"""
1190+
if hasattr(fgraph, "_eq_tracker_equiv"):
1191+
raise AlreadyThere()
12111192

1212-
def __init__(self):
1213-
self.fgraph = None
1193+
fgraph._eq_tracker_equiv = {}
1194+
fgraph._eq_tracker_active_nodes = set()
1195+
fgraph._eq_tracker_inactive_nodes = set()
1196+
fgraph._eq_tracker_fgraph = fgraph
1197+
fgraph._eq_tracker_all_variables_ever = []
1198+
fgraph._eq_tracker_reasons = {}
1199+
fgraph._eq_tracker_replaced_by = {}
1200+
fgraph._eq_tracker_event_list = []
12141201

1215-
def on_attach(self, fgraph):
1216-
assert self.fgraph is None
1217-
self.equiv = {}
1218-
self.active_nodes = set()
1219-
self.inactive_nodes = set()
1220-
self.fgraph = fgraph
1221-
self.all_variables_ever = []
1222-
self.reasons = {}
1223-
self.replaced_by = {}
1224-
self.event_list = []
12251202
for node in fgraph.toposort():
1226-
self.on_import(fgraph, node, "on_attach")
1203+
self.on_import(fgraph, node, "var_equiv_on_attach")
12271204

12281205
def on_detach(self, fgraph):
1229-
assert fgraph is self.fgraph
12301206
self.fgraph = None
1207+
del fgraph._eq_tracker_equiv
1208+
del fgraph._eq_tracker_active_nodes
1209+
del fgraph._eq_tracker_inactive_nodes
1210+
del fgraph._eq_tracker_fgraph
1211+
del fgraph._eq_tracker_all_variables_ever
1212+
del fgraph._eq_tracker_reasons
1213+
del fgraph._eq_tracker_replaced_by
1214+
del fgraph._eq_tracker_event_list
12311215

12321216
def on_prune(self, fgraph, node, reason):
1233-
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
1234-
assert node in self.active_nodes
1235-
assert node not in self.inactive_nodes
1236-
self.active_nodes.remove(node)
1237-
self.inactive_nodes.add(node)
1217+
fgraph._eq_tracker_event_list.append(
1218+
_FunctionGraphEvent("prune", node, reason=str(reason))
1219+
)
1220+
assert node in fgraph._eq_tracker_active_nodes
1221+
assert node not in fgraph._eq_tracker_inactive_nodes
1222+
fgraph._eq_tracker_active_nodes.remove(node)
1223+
fgraph._eq_tracker_inactive_nodes.add(node)
12381224

12391225
def on_import(self, fgraph, node, reason):
1240-
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
1226+
fgraph._eq_tracker_event_list.append(
1227+
_FunctionGraphEvent("import", node, reason=str(reason))
1228+
)
12411229

1242-
assert node not in self.active_nodes
1243-
self.active_nodes.add(node)
1230+
assert node not in fgraph._eq_tracker_active_nodes
1231+
fgraph._eq_tracker_active_nodes.add(node)
12441232

1245-
if node in self.inactive_nodes:
1246-
self.inactive_nodes.remove(node)
1233+
if node in fgraph._eq_tracker_inactive_nodes:
1234+
fgraph._eq_tracker_inactive_nodes.remove(node)
12471235
for r in node.outputs:
1248-
assert r in self.equiv
1236+
assert r in fgraph._eq_tracker_equiv
12491237
else:
12501238
for r in node.outputs:
1251-
assert r not in self.equiv
1252-
self.equiv[r] = {r}
1253-
self.all_variables_ever.append(r)
1254-
self.reasons.setdefault(r, [])
1255-
self.replaced_by.setdefault(r, [])
1239+
assert r not in fgraph._eq_tracker_equiv
1240+
fgraph._eq_tracker_equiv[r] = {r}
1241+
fgraph._eq_tracker_all_variables_ever.append(r)
1242+
fgraph._eq_tracker_reasons.setdefault(r, [])
1243+
fgraph._eq_tracker_replaced_by.setdefault(r, [])
12561244
for r in node.inputs:
1257-
self.reasons.setdefault(r, [])
1258-
self.replaced_by.setdefault(r, [])
1245+
fgraph._eq_tracker_reasons.setdefault(r, [])
1246+
fgraph._eq_tracker_replaced_by.setdefault(r, [])
12591247

12601248
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12611249
reason = str(reason)
1262-
self.event_list.append(
1250+
fgraph._eq_tracker_event_list.append(
12631251
_FunctionGraphEvent("change", node, reason=reason, idx=i)
12641252
)
12651253

1266-
self.reasons.setdefault(new_r, [])
1267-
self.replaced_by.setdefault(new_r, [])
1254+
fgraph._eq_tracker_reasons.setdefault(new_r, [])
1255+
fgraph._eq_tracker_replaced_by.setdefault(new_r, [])
12681256

12691257
append_reason = True
1270-
for tup in self.reasons[new_r]:
1258+
for tup in fgraph._eq_tracker_reasons[new_r]:
12711259
if tup[0] == reason and tup[1] is r:
12721260
append_reason = False
12731261

@@ -1276,7 +1264,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12761264
# optimizations will change the graph
12771265
done = dict()
12781266
used_ids = dict()
1279-
self.reasons[new_r].append(
1267+
fgraph._eq_tracker_reasons[new_r].append(
12801268
(
12811269
reason,
12821270
r,
@@ -1300,19 +1288,19 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13001288
).getvalue(),
13011289
)
13021290
)
1303-
self.replaced_by[r].append((reason, new_r))
1291+
fgraph._eq_tracker_replaced_by[r].append((reason, new_r))
13041292

1305-
if r in self.equiv:
1306-
r_set = self.equiv[r]
1293+
if r in fgraph._eq_tracker_equiv:
1294+
r_set = fgraph._eq_tracker_equiv[r]
13071295
else:
1308-
r_set = self.equiv.setdefault(r, {r})
1309-
self.all_variables_ever.append(r)
1296+
r_set = fgraph._eq_tracker_equiv.setdefault(r, {r})
1297+
fgraph._eq_tracker_all_variables_ever.append(r)
13101298

1311-
if new_r in self.equiv:
1312-
new_r_set = self.equiv[new_r]
1299+
if new_r in fgraph._eq_tracker_equiv:
1300+
new_r_set = fgraph._eq_tracker_equiv[new_r]
13131301
else:
1314-
new_r_set = self.equiv.setdefault(new_r, {new_r})
1315-
self.all_variables_ever.append(new_r)
1302+
new_r_set = fgraph._eq_tracker_equiv.setdefault(new_r, {new_r})
1303+
fgraph._eq_tracker_all_variables_ever.append(new_r)
13161304

13171305
assert new_r in new_r_set
13181306
assert r in r_set
@@ -1321,17 +1309,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13211309
# transfer all the elements of the old one to the new one
13221310
r_set.update(new_r_set)
13231311
for like_new_r in new_r_set:
1324-
self.equiv[like_new_r] = r_set
1312+
fgraph._eq_tracker_equiv[like_new_r] = r_set
13251313
assert like_new_r in r_set
13261314

1327-
assert self.equiv[r] is r_set
1328-
assert self.equiv[new_r] is r_set
1329-
1330-
def printstuff(self):
1331-
for key in self.equiv:
1332-
print(key)
1333-
for e in self.equiv[key]:
1334-
print(" ", e)
1315+
assert fgraph._eq_tracker_equiv[r] is r_set
1316+
assert fgraph._eq_tracker_equiv[new_r] is r_set
13351317

13361318

13371319
# List of default version of make thunk.
@@ -1387,9 +1369,7 @@ def make_all(
13871369
# Compute a topological ordering that IGNORES the destroy_map
13881370
# of destructive Ops. This will be OK, because every thunk is
13891371
# evaluated on a copy of its input.
1390-
fgraph_equiv = fgraph.equivalence_tracker
1391-
order_outputs = copy.copy(fgraph_equiv.all_variables_ever)
1392-
del fgraph_equiv
1372+
order_outputs = copy.copy(fgraph._eq_tracker_all_variables_ever)
13931373
order_outputs.reverse()
13941374
order = io_toposort(fgraph.inputs, order_outputs)
13951375

@@ -1622,7 +1602,7 @@ def f():
16221602
# insert a given apply node. If that is not True,
16231603
# we would need to loop over all node outputs,
16241604
# But this make the output uglier.
1625-
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
1605+
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
16261606
if not reason:
16271607
raise
16281608
opt = str(reason[0][0])
@@ -1735,7 +1715,7 @@ def f():
17351715
# insert a given apply node. If that is not True,
17361716
# we would need to loop over all node outputs,
17371717
# But this make the output uglier.
1738-
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
1718+
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
17391719
if not reason:
17401720
raise
17411721
opt = str(reason[0][0])
@@ -1862,9 +1842,7 @@ def thunk():
18621842
# But it is very slow and it is not sure it will help.
18631843
gc.collect()
18641844

1865-
_find_bad_optimizations(
1866-
order, fgraph.equivalence_tracker.reasons, r_vals
1867-
)
1845+
_find_bad_optimizations(order, fgraph._eq_tracker_reasons, r_vals)
18681846

18691847
#####
18701848
# Postcondition: the input and output variables are
@@ -2045,10 +2023,9 @@ def __init__(
20452023

20462024
# make the fgraph
20472025
for i in range(mode.stability_patience):
2048-
fgraph, additional_outputs, equivalence_tracker = _optcheck_fgraph(
2026+
fgraph, additional_outputs = _optcheck_fgraph(
20492027
inputs, outputs, accept_inplace
20502028
)
2051-
fgraph.equivalence_tracker = equivalence_tracker
20522029

20532030
with config.change_flags(compute_test_value=config.compute_test_value_opt):
20542031
optimizer(fgraph)
@@ -2060,8 +2037,8 @@ def __init__(
20602037
if i == 0:
20612038
fgraph0 = fgraph
20622039
else:
2063-
li = fgraph.equivalence_tracker.event_list
2064-
l0 = fgraph0.equivalence_tracker.event_list
2040+
li = fgraph._eq_tracker_event_list
2041+
l0 = fgraph0._eq_tracker_event_list
20652042
if li != l0:
20662043
infolog = StringIO()
20672044
print("Optimization process is unstable...", file=infolog)

0 commit comments

Comments
 (0)