Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Features less stateful #832

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aesara/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def infer_shape(outs, inputs, input_shapes):
assert len(inp_shp) == inp.type.ndim

shape_feature = ShapeFeature()
shape_feature.on_attach(FunctionGraph([], []))
dummy_fgraph = FunctionGraph([], [])
shape_feature.on_attach(dummy_fgraph)

# Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes):
Expand All @@ -72,7 +73,6 @@ def local_traverse(out):

# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")

ret = []
Expand Down
163 changes: 69 additions & 94 deletions aesara/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from aesara.configdefaults import config
from aesara.graph.basic import Variable, io_toposort
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import AlreadyThere, BadOptimization
from aesara.graph.features import AlreadyThere, BadOptimization, Feature
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import InconsistencyError, MethodNotDefined
from aesara.link.basic import Container, LocalLinker
Expand Down Expand Up @@ -437,7 +437,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
input_specs, output_specs, accept_inplace, force_clone=True
)
fgraph.attach_feature(equivalence_tracker)
return fgraph, updates, equivalence_tracker
return fgraph, updates


class DataDestroyed:
Expand Down Expand Up @@ -1172,98 +1172,84 @@ def __ne__(self, other):
return not (self == other)


class _VariableEquivalenceTracker:
class _VariableEquivalenceTracker(Feature):
"""
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
tries to detect problems.

"""

fgraph = None
"""WRITEME"""

equiv = None
"""WRITEME"""

active_nodes = None
"""WRITEME"""

inactive_nodes = None
"""WRITEME"""

all_variables_ever = None
"""WRITEME"""

reasons = None
"""WRITEME"""

replaced_by = None
"""WRITEME"""

event_list = None
"""WRITEME"""

def __init__(self):
self.fgraph = None

def on_attach(self, fgraph):
if self.fgraph is not None:

if hasattr(fgraph, "_eq_tracker_equiv"):
raise AlreadyThere()

self.equiv = {}
self.active_nodes = set()
self.inactive_nodes = set()
self.fgraph = fgraph
self.all_variables_ever = []
self.reasons = {}
self.replaced_by = {}
self.event_list = []
fgraph._eq_tracker_equiv = {}
fgraph._eq_tracker_active_nodes = set()
fgraph._eq_tracker_inactive_nodes = set()
fgraph._eq_tracker_fgraph = fgraph
fgraph._eq_tracker_all_variables_ever = []
fgraph._eq_tracker_reasons = {}
fgraph._eq_tracker_replaced_by = {}
fgraph._eq_tracker_event_list = []

for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
self.on_import(fgraph, node, "var_equiv_on_attach")

def on_detach(self, fgraph):
assert fgraph is self.fgraph
self.fgraph = None
del fgraph._eq_tracker_equiv
del fgraph._eq_tracker_active_nodes
del fgraph._eq_tracker_inactive_nodes
del fgraph._eq_tracker_fgraph
del fgraph._eq_tracker_all_variables_ever
del fgraph._eq_tracker_reasons
del fgraph._eq_tracker_replaced_by
del fgraph._eq_tracker_event_list

def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
assert node in self.active_nodes
assert node not in self.inactive_nodes
self.active_nodes.remove(node)
self.inactive_nodes.add(node)
fgraph._eq_tracker_event_list.append(
_FunctionGraphEvent("prune", node, reason=str(reason))
)
assert node in fgraph._eq_tracker_active_nodes
assert node not in fgraph._eq_tracker_inactive_nodes
fgraph._eq_tracker_active_nodes.remove(node)
fgraph._eq_tracker_inactive_nodes.add(node)

def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
fgraph._eq_tracker_event_list.append(
_FunctionGraphEvent("import", node, reason=str(reason))
)

assert node not in self.active_nodes
self.active_nodes.add(node)
assert node not in fgraph._eq_tracker_active_nodes
fgraph._eq_tracker_active_nodes.add(node)

if node in self.inactive_nodes:
self.inactive_nodes.remove(node)
if node in fgraph._eq_tracker_inactive_nodes:
fgraph._eq_tracker_inactive_nodes.remove(node)
for r in node.outputs:
assert r in self.equiv
assert r in fgraph._eq_tracker_equiv
else:
for r in node.outputs:
assert r not in self.equiv
self.equiv[r] = {r}
self.all_variables_ever.append(r)
self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
assert r not in fgraph._eq_tracker_equiv
fgraph._eq_tracker_equiv[r] = {r}
fgraph._eq_tracker_all_variables_ever.append(r)
fgraph._eq_tracker_reasons.setdefault(r, [])
fgraph._eq_tracker_replaced_by.setdefault(r, [])
for r in node.inputs:
self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
fgraph._eq_tracker_reasons.setdefault(r, [])
fgraph._eq_tracker_replaced_by.setdefault(r, [])

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
reason = str(reason)
self.event_list.append(
fgraph._eq_tracker_event_list.append(
_FunctionGraphEvent("change", node, reason=reason, idx=i)
)

self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, [])
fgraph._eq_tracker_reasons.setdefault(new_r, [])
fgraph._eq_tracker_replaced_by.setdefault(new_r, [])

append_reason = True
for tup in self.reasons[new_r]:
for tup in fgraph._eq_tracker_reasons[new_r]:
if tup[0] == reason and tup[1] is r:
append_reason = False

Expand All @@ -1272,7 +1258,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
# optimizations will change the graph
done = dict()
used_ids = dict()
self.reasons[new_r].append(
fgraph._eq_tracker_reasons[new_r].append(
(
reason,
r,
Expand All @@ -1296,19 +1282,19 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
).getvalue(),
)
)
self.replaced_by[r].append((reason, new_r))
fgraph._eq_tracker_replaced_by[r].append((reason, new_r))

if r in self.equiv:
r_set = self.equiv[r]
if r in fgraph._eq_tracker_equiv:
r_set = fgraph._eq_tracker_equiv[r]
else:
r_set = self.equiv.setdefault(r, {r})
self.all_variables_ever.append(r)
r_set = fgraph._eq_tracker_equiv.setdefault(r, {r})
fgraph._eq_tracker_all_variables_ever.append(r)

if new_r in self.equiv:
new_r_set = self.equiv[new_r]
if new_r in fgraph._eq_tracker_equiv:
new_r_set = fgraph._eq_tracker_equiv[new_r]
else:
new_r_set = self.equiv.setdefault(new_r, {new_r})
self.all_variables_ever.append(new_r)
new_r_set = fgraph._eq_tracker_equiv.setdefault(new_r, {new_r})
fgraph._eq_tracker_all_variables_ever.append(new_r)

assert new_r in new_r_set
assert r in r_set
Expand All @@ -1317,17 +1303,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
# transfer all the elements of the old one to the new one
r_set.update(new_r_set)
for like_new_r in new_r_set:
self.equiv[like_new_r] = r_set
fgraph._eq_tracker_equiv[like_new_r] = r_set
assert like_new_r in r_set

assert self.equiv[r] is r_set
assert self.equiv[new_r] is r_set

def printstuff(self):
for key in self.equiv:
print(key)
for e in self.equiv[key]:
print(" ", e)
assert fgraph._eq_tracker_equiv[r] is r_set
assert fgraph._eq_tracker_equiv[new_r] is r_set


# List of default version of make thunk.
Expand Down Expand Up @@ -1382,9 +1362,7 @@ def make_all(
# Compute a topological ordering that IGNORES the destroy_map
# of destructive Ops. This will be OK, because every thunk is
# evaluated on a copy of its input.
fgraph_equiv = fgraph.equivalence_tracker
order_outputs = copy.copy(fgraph_equiv.all_variables_ever)
del fgraph_equiv
order_outputs = copy.copy(fgraph._eq_tracker_all_variables_ever)
order_outputs.reverse()
order = io_toposort(fgraph.inputs, order_outputs)

Expand Down Expand Up @@ -1618,7 +1596,7 @@ def f():
# insert a given apply node. If that is not True,
# we would need to loop over all node outputs,
# But this make the output uglier.
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
if not reason:
raise
opt = str(reason[0][0])
Expand Down Expand Up @@ -1731,7 +1709,7 @@ def f():
# insert a given apply node. If that is not True,
# we would need to loop over all node outputs,
# But this make the output uglier.
reason = fgraph.equivalence_tracker.reasons[node.outputs[0]]
reason = fgraph._eq_tracker_reasons[node.outputs[0]]
if not reason:
raise
opt = str(reason[0][0])
Expand Down Expand Up @@ -1858,9 +1836,7 @@ def thunk():
# But it is very slow and it is not sure it will help.
gc.collect()

_find_bad_optimizations(
order, fgraph.equivalence_tracker.reasons, r_vals
)
_find_bad_optimizations(order, fgraph._eq_tracker_reasons, r_vals)

#####
# Postcondition: the input and output variables are
Expand Down Expand Up @@ -2042,10 +2018,9 @@ def __init__(

# make the fgraph
for i in range(mode.stability_patience):
fgraph, additional_outputs, equivalence_tracker = _optcheck_fgraph(
fgraph, additional_outputs = _optcheck_fgraph(
inputs, outputs, accept_inplace
)
fgraph.equivalence_tracker = equivalence_tracker

with config.change_flags(compute_test_value=config.compute_test_value_opt):
optimizer(fgraph)
Expand All @@ -2057,8 +2032,8 @@ def __init__(
if i == 0:
fgraph0 = fgraph
else:
li = fgraph.equivalence_tracker.event_list
l0 = fgraph0.equivalence_tracker.event_list
li = fgraph._eq_tracker_event_list
l0 = fgraph0._eq_tracker_event_list
if li != l0:
infolog = StringIO()
print("Optimization process is unstable...", file=infolog)
Expand Down
29 changes: 14 additions & 15 deletions aesara/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import warnings
from itertools import chain
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union

import numpy as np
from typing_extensions import Literal
Expand Down Expand Up @@ -142,31 +142,30 @@ class Supervisor(Feature):

"""

def __init__(self, protected):
self.fgraph = None
self.protected = list(protected)

def clone(self):
return type(self)(self.protected)
def __init__(self, protected: Iterable[Variable]):
self.initial_protected = set(protected)

def on_attach(self, fgraph):
if hasattr(fgraph, "_supervisor"):
raise AlreadyThere(f"A Supervisor is already attached to {fgraph}.")
if hasattr(fgraph, "_supervisor_protected"):
# Add the protected variables from this `Supervisor` instance, in
# case something is trying to update them by adding another
# `Supervisor`
fgraph._supervisor_protected.update(self.initial_protected)
raise AlreadyThere("Supervisor feature is already present")

if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("This Feature is already associated with a FunctionGraph")
fgraph._supervisor_protected = set(self.initial_protected)

fgraph._supervisor = self
self.fgraph = fgraph
def clone(self):
return type(self)(self.initial_protected)

def validate(self, fgraph):
if config.cycle_detection == "fast" and hasattr(fgraph, "has_destroyers"):
if fgraph.has_destroyers(self.protected):
if fgraph.has_destroyers(fgraph._supervisor_protected):
raise InconsistencyError("Trying to destroy protected variables.")
return True
if not hasattr(fgraph, "destroyers"):
return True
for r in self.protected + list(fgraph.outputs):
for r in chain(fgraph._supervisor_protected, fgraph.outputs):
if fgraph.destroyers(r):
raise InconsistencyError(f"Trying to destroy a protected variable: {r}")

Expand Down
Loading