Skip to content

Commit 320ad0a

Browse files
Make _VariableEquivalenceTracker a proper Feature and less stateful
1 parent 81d6d36 commit 320ad0a

File tree

1 file changed

+69
-92
lines changed

1 file changed

+69
-92
lines changed

aesara/compile/debugmode.py

+69-92
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from aesara.configdefaults import config
3131
from aesara.graph.basic import Variable, graph_inputs, io_toposort
3232
from aesara.graph.destroyhandler import DestroyHandler
33-
from aesara.graph.features import BadOptimization
33+
from aesara.graph.features import AlreadyThere, BadOptimization, Feature
3434
from aesara.graph.fg import InconsistencyError
3535
from aesara.graph.op import COp, HasInnerGraph, Op
3636
from aesara.graph.utils import MethodNotDefined
@@ -433,7 +433,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
433433
equivalence_tracker = _VariableEquivalenceTracker()
434434
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace)
435435
fgraph.attach_feature(equivalence_tracker)
436-
return fgraph, updates, equivalence_tracker
436+
return fgraph, updates
437437

438438

439439
class DataDestroyed:
@@ -1181,96 +1181,84 @@ def __ne__(self, other):
11811181
return not (self == other)
11821182

11831183

1184-
class _VariableEquivalenceTracker:
1184+
class _VariableEquivalenceTracker(Feature):
11851185
"""
11861186
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
11871187
tries to detect problems.
11881188
11891189
"""
11901190

1191-
fgraph = None
1192-
"""WRITEME"""
1193-
1194-
equiv = None
1195-
"""WRITEME"""
1196-
1197-
active_nodes = None
1198-
"""WRITEME"""
1199-
1200-
inactive_nodes = None
1201-
"""WRITEME"""
1202-
1203-
all_variables_ever = None
1204-
"""WRITEME"""
1205-
1206-
reasons = None
1207-
"""WRITEME"""
1208-
1209-
replaced_by = None
1210-
"""WRITEME"""
1191+
def on_attach(self, fgraph):
12111192

1212-
event_list = None
1213-
"""WRITEME"""
1193+
if hasattr(fgraph, "_eq_tracker_equiv"):
1194+
raise AlreadyThere()
12141195

1215-
def __init__(self):
1216-
self.fgraph = None
1196+
fgraph._eq_tracker_equiv = {}
1197+
fgraph._eq_tracker_active_nodes = set()
1198+
fgraph._eq_tracker_inactive_nodes = set()
1199+
fgraph._eq_tracker_fgraph = fgraph
1200+
fgraph._eq_tracker_all_variables_ever = []
1201+
fgraph._eq_tracker_reasons = {}
1202+
fgraph._eq_tracker_replaced_by = {}
1203+
fgraph._eq_tracker_event_list = []
12171204

1218-
def on_attach(self, fgraph):
1219-
assert self.fgraph is None
1220-
self.equiv = {}
1221-
self.active_nodes = set()
1222-
self.inactive_nodes = set()
1223-
self.fgraph = fgraph
1224-
self.all_variables_ever = []
1225-
self.reasons = {}
1226-
self.replaced_by = {}
1227-
self.event_list = []
12281205
for node in fgraph.toposort():
1229-
self.on_import(fgraph, node, "on_attach")
1206+
self.on_import(fgraph, node, "var_equiv_on_attach")
12301207

12311208
def on_detach(self, fgraph):
1232-
assert fgraph is self.fgraph
12331209
self.fgraph = None
1210+
del fgraph._eq_tracker_equiv
1211+
del fgraph._eq_tracker_active_nodes
1212+
del fgraph._eq_tracker_inactive_nodes
1213+
del fgraph._eq_tracker_fgraph
1214+
del fgraph._eq_tracker_all_variables_ever
1215+
del fgraph._eq_tracker_reasons
1216+
del fgraph._eq_tracker_replaced_by
1217+
del fgraph._eq_tracker_event_list
12341218

12351219
def on_prune(self, fgraph, node, reason):
1236-
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
1237-
assert node in self.active_nodes
1238-
assert node not in self.inactive_nodes
1239-
self.active_nodes.remove(node)
1240-
self.inactive_nodes.add(node)
1220+
fgraph._eq_tracker_event_list.append(
1221+
_FunctionGraphEvent("prune", node, reason=str(reason))
1222+
)
1223+
assert node in fgraph._eq_tracker_active_nodes
1224+
assert node not in fgraph._eq_tracker_inactive_nodes
1225+
fgraph._eq_tracker_active_nodes.remove(node)
1226+
fgraph._eq_tracker_inactive_nodes.add(node)
12411227

12421228
def on_import(self, fgraph, node, reason):
1243-
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
1229+
fgraph._eq_tracker_event_list.append(
1230+
_FunctionGraphEvent("import", node, reason=str(reason))
1231+
)
12441232

1245-
assert node not in self.active_nodes
1246-
self.active_nodes.add(node)
1233+
assert node not in fgraph._eq_tracker_active_nodes
1234+
fgraph._eq_tracker_active_nodes.add(node)
12471235

1248-
if node in self.inactive_nodes:
1249-
self.inactive_nodes.remove(node)
1236+
if node in fgraph._eq_tracker_inactive_nodes:
1237+
fgraph._eq_tracker_inactive_nodes.remove(node)
12501238
for r in node.outputs:
1251-
assert r in self.equiv
1239+
assert r in fgraph._eq_tracker_equiv
12521240
else:
12531241
for r in node.outputs:
1254-
assert r not in self.equiv
1255-
self.equiv[r] = {r}
1256-
self.all_variables_ever.append(r)
1257-
self.reasons.setdefault(r, [])
1258-
self.replaced_by.setdefault(r, [])
1242+
assert r not in fgraph._eq_tracker_equiv
1243+
fgraph._eq_tracker_equiv[r] = {r}
1244+
fgraph._eq_tracker_all_variables_ever.append(r)
1245+
fgraph._eq_tracker_reasons.setdefault(r, [])
1246+
fgraph._eq_tracker_replaced_by.setdefault(r, [])
12591247
for r in node.inputs:
1260-
self.reasons.setdefault(r, [])
1261-
self.replaced_by.setdefault(r, [])
1248+
fgraph._eq_tracker_reasons.setdefault(r, [])
1249+
fgraph._eq_tracker_replaced_by.setdefault(r, [])
12621250

12631251
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12641252
reason = str(reason)
1265-
self.event_list.append(
1253+
fgraph._eq_tracker_event_list.append(
12661254
_FunctionGraphEvent("change", node, reason=reason, idx=i)
12671255
)
12681256

1269-
self.reasons.setdefault(new_r, [])
1270-
self.replaced_by.setdefault(new_r, [])
1257+
fgraph._eq_tracker_reasons.setdefault(new_r, [])
1258+
fgraph._eq_tracker_replaced_by.setdefault(new_r, [])
12711259

12721260
append_reason = True
1273-
for tup in self.reasons[new_r]:
1261+
for tup in fgraph._eq_tracker_reasons[new_r]:
12741262
if tup[0] == reason and tup[1] is r:
12751263
append_reason = False
12761264

@@ -1279,7 +1267,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12791267
# optimizations will change the graph
12801268
done = dict()
12811269
used_ids = dict()
1282-
self.reasons[new_r].append(
1270+
fgraph._eq_tracker_reasons[new_r].append(
12831271
(
12841272
reason,
12851273
r,
@@ -1303,19 +1291,19 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13031291
).getvalue(),
13041292
)
13051293
)
1306-
self.replaced_by[r].append((reason, new_r))
1294+
fgraph._eq_tracker_replaced_by[r].append((reason, new_r))
13071295

1308-
if r in self.equiv:
1309-
r_set = self.equiv[r]
1296+
if r in fgraph._eq_tracker_equiv:
1297+
r_set = fgraph._eq_tracker_equiv[r]
13101298
else:
1311-
r_set = self.equiv.setdefault(r, {r})
1312-
self.all_variables_ever.append(r)
1299+
r_set = fgraph._eq_tracker_equiv.setdefault(r, {r})
1300+
fgraph._eq_tracker_all_variables_ever.append(r)
13131301

1314-
if new_r in self.equiv:
1315-
new_r_set = self.equiv[new_r]
1302+
if new_r in fgraph._eq_tracker_equiv:
1303+
new_r_set = fgraph._eq_tracker_equiv[new_r]
13161304
else:
1317-
new_r_set = self.equiv.setdefault(new_r, {new_r})
1318-
self.all_variables_ever.append(new_r)
1305+
new_r_set = fgraph._eq_tracker_equiv.setdefault(new_r, {new_r})
1306+
fgraph._eq_tracker_all_variables_ever.append(new_r)
13191307

13201308
assert new_r in new_r_set
13211309
assert r in r_set
@@ -1324,17 +1312,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13241312
# transfer all the elements of the old one to the new one
13251313
r_set.update(new_r_set)
13261314
for like_new_r in new_r_set:
1327-
self.equiv[like_new_r] = r_set
1315+
fgraph._eq_tracker_equiv[like_new_r] = r_set
13281316
assert like_new_r in r_set
13291317

1330-
assert self.equiv[r] is r_set
1331-
assert self.equiv[new_r] is r_set
1332-
1333-
def printstuff(self):
1334-
for key in self.equiv:
1335-
print(key)
1336-
for e in self.equiv[key]:
1337-
print(" ", e)
1318+
assert fgraph._eq_tracker_equiv[r] is r_set
1319+
assert fgraph._eq_tracker_equiv[new_r] is r_set
13381320

13391321

13401322
# List of default version of make thunk.
@@ -1390,9 +1372,7 @@ def make_all(
13901372
# Compute a topological ordering that IGNORES the destroy_map
13911373
# of destructive Ops. This will be OK, because every thunk is
13921374
# evaluated on a copy of its input.
1393-
fgraph_equiv = fgraph.equivalence_tracker
1394-
order_outputs = copy.copy(fgraph_equiv.all_variables_ever)
1395-
del fgraph_equiv
1375+
order_outputs = copy.copy(fgraph._eq_tracker_all_variables_ever)
13961376
order_outputs.reverse()
13971377
order = io_toposort(fgraph.inputs, order_outputs)
13981378

@@ -1625,7 +1605,7 @@ def f():
16251605
# insert a given apply node. If that is not True,
16261606
# we would need to loop over all node outputs,
16271607
# But this make the output uglier.
1628-
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
1608+
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
16291609
if not reason:
16301610
raise
16311611
opt = str(reason[0][0])
@@ -1738,7 +1718,7 @@ def f():
17381718
# insert a given apply node. If that is not True,
17391719
# we would need to loop over all node outputs,
17401720
# But this make the output uglier.
1741-
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
1721+
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
17421722
if not reason:
17431723
raise
17441724
opt = str(reason[0][0])
@@ -1865,9 +1845,7 @@ def thunk():
18651845
# But it is very slow and it is not sure it will help.
18661846
gc.collect()
18671847

1868-
_find_bad_optimizations(
1869-
order, fgraph.equivalence_tracker.reasons, r_vals
1870-
)
1848+
_find_bad_optimizations(order, fgraph._eq_tracker_reasons, r_vals)
18711849

18721850
#####
18731851
# Postcondition: the input and output variables are
@@ -2058,10 +2036,9 @@ def __init__(
20582036

20592037
# make the fgraph
20602038
for i in range(mode.stability_patience):
2061-
fgraph, additional_outputs, equivalence_tracker = _optcheck_fgraph(
2039+
fgraph, additional_outputs = _optcheck_fgraph(
20622040
inputs, outputs, accept_inplace
20632041
)
2064-
fgraph.equivalence_tracker = equivalence_tracker
20652042

20662043
with config.change_flags(compute_test_value=config.compute_test_value_opt):
20672044
optimizer(fgraph)
@@ -2073,8 +2050,8 @@ def __init__(
20732050
if i == 0:
20742051
fgraph0 = fgraph
20752052
else:
2076-
li = fgraph.equivalence_tracker.event_list
2077-
l0 = fgraph0.equivalence_tracker.event_list
2053+
li = fgraph._eq_tracker_event_list
2054+
l0 = fgraph0._eq_tracker_event_list
20782055
if li != l0:
20792056
infolog = StringIO()
20802057
print("Optimization process is unstable...", file=infolog)

0 commit comments

Comments
 (0)