diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py index 031e619456..e873bd573b 100644 --- a/aesara/compile/builders.py +++ b/aesara/compile/builders.py @@ -47,7 +47,8 @@ def infer_shape(outs, inputs, input_shapes): assert len(inp_shp) == inp.type.ndim shape_feature = ShapeFeature() - shape_feature.on_attach(FunctionGraph([], [])) + dummy_fgraph = FunctionGraph([], []) + shape_feature.on_attach(dummy_fgraph) # Initialize shape_of with the input shapes for inp, inp_shp in zip(inputs, input_shapes): @@ -72,7 +73,6 @@ def local_traverse(out): # shape_feature.on_import does not actually use an fgraph # It will call infer_shape and set_shape appropriately - dummy_fgraph = None shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy") ret = [] diff --git a/aesara/compile/debugmode.py b/aesara/compile/debugmode.py index c3c116b091..08f296514d 100644 --- a/aesara/compile/debugmode.py +++ b/aesara/compile/debugmode.py @@ -31,7 +31,7 @@ from aesara.configdefaults import config from aesara.graph.basic import Variable, io_toposort from aesara.graph.destroyhandler import DestroyHandler -from aesara.graph.features import AlreadyThere, BadOptimization +from aesara.graph.features import AlreadyThere, BadOptimization, Feature from aesara.graph.op import HasInnerGraph, Op from aesara.graph.utils import InconsistencyError, MethodNotDefined from aesara.link.basic import Container, LocalLinker @@ -437,7 +437,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): input_specs, output_specs, accept_inplace, force_clone=True ) fgraph.attach_feature(equivalence_tracker) - return fgraph, updates, equivalence_tracker + return fgraph, updates class DataDestroyed: @@ -1172,98 +1172,84 @@ def __ne__(self, other): return not (self == other) -class _VariableEquivalenceTracker: +class _VariableEquivalenceTracker(Feature): """ A FunctionGraph Feature that keeps tabs on an FunctionGraph and tries to detect problems. 
""" - fgraph = None - """WRITEME""" - - equiv = None - """WRITEME""" - - active_nodes = None - """WRITEME""" - - inactive_nodes = None - """WRITEME""" - - all_variables_ever = None - """WRITEME""" - - reasons = None - """WRITEME""" - - replaced_by = None - """WRITEME""" - - event_list = None - """WRITEME""" - - def __init__(self): - self.fgraph = None - def on_attach(self, fgraph): - if self.fgraph is not None: + + if hasattr(fgraph, "_eq_tracker_equiv"): raise AlreadyThere() - self.equiv = {} - self.active_nodes = set() - self.inactive_nodes = set() - self.fgraph = fgraph - self.all_variables_ever = [] - self.reasons = {} - self.replaced_by = {} - self.event_list = [] + fgraph._eq_tracker_equiv = {} + fgraph._eq_tracker_active_nodes = set() + fgraph._eq_tracker_inactive_nodes = set() + fgraph._eq_tracker_fgraph = fgraph + fgraph._eq_tracker_all_variables_ever = [] + fgraph._eq_tracker_reasons = {} + fgraph._eq_tracker_replaced_by = {} + fgraph._eq_tracker_event_list = [] + for node in fgraph.toposort(): - self.on_import(fgraph, node, "on_attach") + self.on_import(fgraph, node, "var_equiv_on_attach") def on_detach(self, fgraph): - assert fgraph is self.fgraph self.fgraph = None + del fgraph._eq_tracker_equiv + del fgraph._eq_tracker_active_nodes + del fgraph._eq_tracker_inactive_nodes + del fgraph._eq_tracker_fgraph + del fgraph._eq_tracker_all_variables_ever + del fgraph._eq_tracker_reasons + del fgraph._eq_tracker_replaced_by + del fgraph._eq_tracker_event_list def on_prune(self, fgraph, node, reason): - self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason))) - assert node in self.active_nodes - assert node not in self.inactive_nodes - self.active_nodes.remove(node) - self.inactive_nodes.add(node) + fgraph._eq_tracker_event_list.append( + _FunctionGraphEvent("prune", node, reason=str(reason)) + ) + assert node in fgraph._eq_tracker_active_nodes + assert node not in fgraph._eq_tracker_inactive_nodes + fgraph._eq_tracker_active_nodes.remove(node) + fgraph._eq_tracker_inactive_nodes.add(node) def on_import(self, fgraph, node, reason): - self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason))) + fgraph._eq_tracker_event_list.append( + _FunctionGraphEvent("import", node, reason=str(reason)) + ) - assert node not in self.active_nodes - self.active_nodes.add(node) + assert node not in fgraph._eq_tracker_active_nodes + fgraph._eq_tracker_active_nodes.add(node) - if node in self.inactive_nodes: - self.inactive_nodes.remove(node) + if node in fgraph._eq_tracker_inactive_nodes: + fgraph._eq_tracker_inactive_nodes.remove(node) for r in node.outputs: - assert r in self.equiv + assert r in fgraph._eq_tracker_equiv else: for r in node.outputs: - assert r not in self.equiv - self.equiv[r] = {r} - self.all_variables_ever.append(r) - self.reasons.setdefault(r, []) - self.replaced_by.setdefault(r, []) + assert r not in fgraph._eq_tracker_equiv + fgraph._eq_tracker_equiv[r] = {r} + fgraph._eq_tracker_all_variables_ever.append(r) + fgraph._eq_tracker_reasons.setdefault(r, []) + fgraph._eq_tracker_replaced_by.setdefault(r, []) for r in node.inputs: - self.reasons.setdefault(r, []) - self.replaced_by.setdefault(r, []) + fgraph._eq_tracker_reasons.setdefault(r, []) + fgraph._eq_tracker_replaced_by.setdefault(r, []) def on_change_input(self, fgraph, node, i, r, new_r, reason=None): reason = str(reason) - self.event_list.append( + fgraph._eq_tracker_event_list.append( _FunctionGraphEvent("change", node, reason=reason, idx=i) ) - self.reasons.setdefault(new_r, []) - 
self.replaced_by.setdefault(new_r, []) + fgraph._eq_tracker_reasons.setdefault(new_r, []) + fgraph._eq_tracker_replaced_by.setdefault(new_r, []) append_reason = True - for tup in self.reasons[new_r]: + for tup in fgraph._eq_tracker_reasons[new_r]: if tup[0] == reason and tup[1] is r: append_reason = False @@ -1272,7 +1258,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): # optimizations will change the graph done = dict() used_ids = dict() - self.reasons[new_r].append( + fgraph._eq_tracker_reasons[new_r].append( ( reason, r, @@ -1296,19 +1282,19 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): ).getvalue(), ) ) - self.replaced_by[r].append((reason, new_r)) + fgraph._eq_tracker_replaced_by[r].append((reason, new_r)) - if r in self.equiv: - r_set = self.equiv[r] + if r in fgraph._eq_tracker_equiv: + r_set = fgraph._eq_tracker_equiv[r] else: - r_set = self.equiv.setdefault(r, {r}) - self.all_variables_ever.append(r) + r_set = fgraph._eq_tracker_equiv.setdefault(r, {r}) + fgraph._eq_tracker_all_variables_ever.append(r) - if new_r in self.equiv: - new_r_set = self.equiv[new_r] + if new_r in fgraph._eq_tracker_equiv: + new_r_set = fgraph._eq_tracker_equiv[new_r] else: - new_r_set = self.equiv.setdefault(new_r, {new_r}) - self.all_variables_ever.append(new_r) + new_r_set = fgraph._eq_tracker_equiv.setdefault(new_r, {new_r}) + fgraph._eq_tracker_all_variables_ever.append(new_r) assert new_r in new_r_set assert r in r_set @@ -1317,17 +1303,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): # transfer all the elements of the old one to the new one r_set.update(new_r_set) for like_new_r in new_r_set: - self.equiv[like_new_r] = r_set + fgraph._eq_tracker_equiv[like_new_r] = r_set assert like_new_r in r_set - assert self.equiv[r] is r_set - assert self.equiv[new_r] is r_set - - def printstuff(self): - for key in self.equiv: - print(key) - for e in self.equiv[key]: - print(" ", e) + assert fgraph._eq_tracker_equiv[r] is r_set + assert fgraph._eq_tracker_equiv[new_r] is r_set # List of default version of make thunk. @@ -1382,9 +1362,7 @@ def make_all( # Compute a topological ordering that IGNORES the destroy_map # of destructive Ops. This will be OK, because every thunk is # evaluated on a copy of its input. - fgraph_equiv = fgraph.equivalence_tracker - order_outputs = copy.copy(fgraph_equiv.all_variables_ever) - del fgraph_equiv + order_outputs = copy.copy(fgraph._eq_tracker_all_variables_ever) order_outputs.reverse() order = io_toposort(fgraph.inputs, order_outputs) @@ -1618,7 +1596,7 @@ def f(): # insert a given apply node. If that is not True, # we would need to loop over all node outputs, # But this make the output uglier. - reason = fgraph.equivalence_tracker.reasons[node.outputs[0]] + reason = fgraph._eq_tracker_reasons[node.outputs[0]] if not reason: raise opt = str(reason[0][0]) @@ -1731,7 +1709,7 @@ def f(): # insert a given apply node. If that is not True, # we would need to loop over all node outputs, # But this make the output uglier. - reason = fgraph.equivalence_tracker.reasons[node.outputs[0]] + reason = fgraph._eq_tracker_reasons[node.outputs[0]] if not reason: raise opt = str(reason[0][0]) @@ -1858,9 +1836,7 @@ def thunk(): # But it is very slow and it is not sure it will help. 
gc.collect() - _find_bad_optimizations( - order, fgraph.equivalence_tracker.reasons, r_vals - ) + _find_bad_optimizations(order, fgraph._eq_tracker_reasons, r_vals) ##### # Postcondition: the input and output variables are @@ -2042,10 +2018,9 @@ def __init__( # make the fgraph for i in range(mode.stability_patience): - fgraph, additional_outputs, equivalence_tracker = _optcheck_fgraph( + fgraph, additional_outputs = _optcheck_fgraph( inputs, outputs, accept_inplace ) - fgraph.equivalence_tracker = equivalence_tracker with config.change_flags(compute_test_value=config.compute_test_value_opt): optimizer(fgraph) @@ -2057,8 +2032,8 @@ def __init__( if i == 0: fgraph0 = fgraph else: - li = fgraph.equivalence_tracker.event_list - l0 = fgraph0.equivalence_tracker.event_list + li = fgraph._eq_tracker_event_list + l0 = fgraph0._eq_tracker_event_list if li != l0: infolog = StringIO() print("Optimization process is unstable...", file=infolog) diff --git a/aesara/compile/function/types.py b/aesara/compile/function/types.py index cc41190d02..b28d914ecc 100644 --- a/aesara/compile/function/types.py +++ b/aesara/compile/function/types.py @@ -6,7 +6,7 @@ import time import warnings from itertools import chain -from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union import numpy as np from typing_extensions import Literal @@ -142,31 +142,30 @@ class Supervisor(Feature): """ - def __init__(self, protected): - self.fgraph = None - self.protected = list(protected) - - def clone(self): - return type(self)(self.protected) + def __init__(self, protected: Iterable[Variable]): + self.initial_protected = set(protected) def on_attach(self, fgraph): - if hasattr(fgraph, "_supervisor"): - raise AlreadyThere(f"A Supervisor is already attached to {fgraph}.") + if hasattr(fgraph, "_supervisor_protected"): + # Add the protected variables from this `Supervisor` instance, in + # case something is trying to update them by adding another + # `Supervisor` + fgraph._supervisor_protected.update(self.initial_protected) + raise AlreadyThere("Supervisor feature is already present") - if self.fgraph is not None and self.fgraph != fgraph: - raise Exception("This Feature is already associated with a FunctionGraph") + fgraph._supervisor_protected = set(self.initial_protected) - fgraph._supervisor = self - self.fgraph = fgraph + def clone(self): + return type(self)(self.initial_protected) def validate(self, fgraph): if config.cycle_detection == "fast" and hasattr(fgraph, "has_destroyers"): - if fgraph.has_destroyers(self.protected): + if fgraph.has_destroyers(fgraph._supervisor_protected): raise InconsistencyError("Trying to destroy protected variables.") return True if not hasattr(fgraph, "destroyers"): return True - for r in self.protected + list(fgraph.outputs): + for r in chain(fgraph._supervisor_protected, fgraph.outputs): if fgraph.destroyers(r): raise InconsistencyError(f"Trying to destroy a protected variable: {r}") diff --git a/aesara/graph/destroyhandler.py b/aesara/graph/destroyhandler.py index abc4894715..4649057a53 100644 --- a/aesara/graph/destroyhandler.py +++ b/aesara/graph/destroyhandler.py @@ -5,8 +5,11 @@ """ import itertools from collections import OrderedDict, deque +from types import MethodType +from typing import TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, Tuple +from typing import Type as TypingType +from typing import Union, cast -import aesara from aesara.configdefaults import config from 
aesara.graph.basic import Constant from aesara.graph.features import AlreadyThere, Bookkeeper @@ -14,6 +17,11 @@ from aesara.misc.ordered_set import OrderedSet +if TYPE_CHECKING: + from aesara.graph.basic import Apply, Variable + from aesara.graph.fg import FunctionGraph + + class ProtocolError(Exception): """ Raised when FunctionGraph calls DestroyHandler callbacks in @@ -23,7 +31,9 @@ class ProtocolError(Exception): """ -def _contains_cycle(fgraph, orderings): +def _contains_cycle( + fgraph: "FunctionGraph", orderings: Dict["Apply", Set["Apply"]] +) -> bool: """ Function to check if the given graph contains a cycle @@ -76,12 +86,18 @@ def _contains_cycle(fgraph, orderings): # dictionary also runs slower when storing ids than when # storing objects. + # TODO FIXME: This use of mixed `Apply` and `Variable` types is confusing + # and unnecessary. We might need to start using nullary `Apply` nodes for + # atomic variables in order to fix this, though. + # dict mapping an Apply or Variable instance to the number # of its parents (including parents imposed by orderings) # that haven't been visited yet - parent_counts = {} + parent_counts: Dict[Union["Apply", "Variable"], int] = {} # dict mapping an Apply or Variable instance to its children - node_to_children = {} + node_to_children: Dict[ + Union["Apply", "Variable"], List[Union["Apply", "Variable"]] + ] = {} # visitable: A container holding all Variable and Apply instances # that can currently be visited according to the graph topology @@ -94,7 +110,7 @@ def _contains_cycle(fgraph, orderings): # we don't care about the traversal order here as much as we do # in io_toposort because we aren't trying to generate an ordering # on the nodes - visitable = deque() + visitable: Deque[Union["Apply", "Variable"]] = deque() # IG: visitable could in principle be initialized to fgraph.inputs # + fgraph.orphans... if there were an fgraph.orphans structure. @@ -109,7 +125,6 @@ def _contains_cycle(fgraph, orderings): # Pass through all the nodes to build visitable, parent_count, and # node_to_children for var in fgraph.variables: - # this is faster than calling get_parents owner = var.owner # variables don't appear in orderings, so we don't need to worry @@ -125,7 +140,7 @@ def _contains_cycle(fgraph, orderings): parent_counts[var] = 0 for a_n in fgraph.apply_nodes: - parents = list(a_n.inputs) + parents: List[Union["Apply", "Variable"]] = list(a_n.inputs) # This is faster than conditionally extending # IG: I tried using a shared empty_list = [] constructed # outside of the for loop to avoid constructing multiple @@ -175,21 +190,37 @@ def _contains_cycle(fgraph, orderings): return visited != len(parent_counts) -def _build_droot_impact(destroy_handler): - droot = {} # destroyed view + nonview variables -> foundation - impact = {} # destroyed nonview variable -> it + all views of it - root_destroyer = {} # root -> destroyer apply - - for app in destroy_handler.destroyers: +def _build_droot_impact( + fgraph: "FunctionGraph", +) -> Tuple[ + Dict["Variable", "Variable"], + Dict["Variable", Set["Variable"]], + Dict["Variable", "Apply"], +]: + # destroyed view + nonview variables -> foundation + droot: Dict["Variable", "Variable"] = OrderedDict() + # destroyed nonview variable -> it + all views of it + impact: Dict["Variable", Set["Variable"]] = OrderedDict() + # root -> destroyer apply + root_destroyer: Dict["Variable", "Apply"] = OrderedDict() + + # TODO FIXME: How do we get a type interface working for these + # `FunctionGraph` additions? 
+ assert hasattr(fgraph, "_destroyhandler_destroyers") + assert hasattr(fgraph, "view_i") + assert hasattr(fgraph, "view_o") + + for app in fgraph._destroyhandler_destroyers: for output_idx, input_idx_list in app.op.destroy_map.items(): if len(input_idx_list) != 1: raise NotImplementedError() - input_idx = input_idx_list[0] - input = app.inputs[input_idx] + + input_idx: int = input_idx_list[0] + input: "Variable" = app.inputs[input_idx] # Find non-view variable which is ultimately viewed by input. - view_i = destroy_handler.view_i - _r = input + view_i: Dict["Variable", "Variable"] = fgraph.view_i + _r: Optional["Variable"] = input while _r is not None: r = _r _r = view_i.get(r) @@ -197,6 +228,7 @@ def _build_droot_impact(destroy_handler): if input_root in droot: raise InconsistencyError(f"Multiple destroyers of {input_root}") + droot[input_root] = input_root root_destroyer[input_root] = app @@ -204,11 +236,11 @@ def _build_droot_impact(destroy_handler): # an OrderedSet input_impact input_impact = OrderedSet() - q = deque() + q: Deque["Variable"] = deque() q.append(input_root) while len(q) > 0: v = q.popleft() - for n in destroy_handler.view_o.get(v, []): + for n in fgraph.view_o.get(v, []): input_impact.add(n) q.append(n) @@ -216,13 +248,15 @@ def _build_droot_impact(destroy_handler): assert v not in droot droot[v] = input_root - impact[input_root] = input_impact + impact[input_root] = cast(Set["Variable"], input_impact) impact[input_root].add(input_root) return droot, impact, root_destroyer -def fast_inplace_check(fgraph, inputs): +def fast_inplace_check( + fgraph: "FunctionGraph", inputs: List["Variable"] +) -> List["Variable"]: """ Return the variables in inputs that are possible candidate for as inputs of inplace operation. @@ -233,12 +267,11 @@ def fast_inplace_check(fgraph, inputs): Inputs Variable that you want to use as inplace destination. """ - Supervisor = aesara.compile.function.types.Supervisor - protected_inputs = [ - f.protected for f in fgraph._features if isinstance(f, Supervisor) - ] - protected_inputs = sum(protected_inputs, []) # flatten the list - protected_inputs.extend(fgraph.outputs) + assert hasattr(fgraph, "has_destroyers") + + protected_inputs: Iterable["Variable"] = getattr( + fgraph, "_supervisor_protected", () + ) inputs = [ i @@ -250,7 +283,7 @@ def fast_inplace_check(fgraph, inputs): return inputs -class DestroyHandler(Bookkeeper): # noqa +class DestroyHandler(Bookkeeper): """ The DestroyHandler class detects when a graph is impossible to evaluate because of aliasing and destructive operations. @@ -293,42 +326,12 @@ class DestroyHandler(Bookkeeper): # noqa """ - pickle_rm_attr = ["destroyers", "has_destroyers"] - def __init__(self, do_imports_on_attach=True, algo=None): - self.fgraph = None self.do_imports_on_attach = do_imports_on_attach - """ - Maps every variable in the graph to its "foundation" (deepest - ancestor in view chain). - TODO: change name to var_to_vroot. - - """ - self.droot = OrderedDict() - - """ - Maps a variable to all variables that are indirect or direct views of it - (including itself) essentially the inverse of droot. - TODO: do all variables appear in this dict, or only those that are - foundations? - TODO: do only destroyed variables go in here? one old docstring said so. 
-        TODO: rename to x_to_views after reverse engineering what x is
-
-        """
-        self.impact = OrderedDict()
-
-        """
-        If a var is destroyed, then this dict will map
-        droot[var] to the apply node that destroyed var
-        TODO: rename to vroot_to_destroyer
-
-        """
-        self.root_destroyer = OrderedDict()

         if algo is None:
             algo = config.cycle_detection
         self.algo = algo
-        self.fail_validate = OrderedDict()

     def clone(self):
         return type(self)(self.do_imports_on_attach, self.algo)

@@ -350,48 +353,73 @@ def on_attach(self, fgraph):

         """
-        if any(hasattr(fgraph, attr) for attr in ("destroyers", "destroy_handler")):
+        if hasattr(fgraph, "destroy_handler"):
             raise AlreadyThere("DestroyHandler feature is already present")

-        if self.fgraph is not None and self.fgraph != fgraph:
-            raise Exception(
-                "A DestroyHandler instance can only serve one FunctionGraph"
-            )
-
-        # Annotate the FunctionGraph #
-        self.unpickle(fgraph)
         fgraph.destroy_handler = self
-        self.fgraph = fgraph

-        self.destroyers = (
-            OrderedSet()
-        )  # set of Apply instances with non-null destroy_map
-        self.view_i = {}  # variable -> variable used in calculation
-        self.view_o = (
-            {}
-        )  # variable -> set of variables that use this one as a direct input
+        fgraph.fail_validate: Dict["Apply", Exception] = OrderedDict()
+
+        """
+        Maps every variable in the graph to its "foundation" (deepest
+        ancestor in view chain).
+        TODO: change name to var_to_vroot.
+
+        """
+        fgraph.droot: Dict["Variable", "Variable"] = OrderedDict()
+
+        """
+        Maps a variable to all variables that are indirect or direct views of it
+        (including itself) essentially the inverse of droot.
+        TODO: do all variables appear in this dict, or only those that are
+        foundations?
+        TODO: do only destroyed variables go in here? one old docstring said so.
+        TODO: rename to x_to_views after reverse engineering what x is
+
+        """
+        fgraph.impact: Dict["Variable", Set["Variable"]] = OrderedDict()
+
+        """
+        If a var is destroyed, then this dict will map
+        droot[var] to the apply node that destroyed var
+        TODO: rename to vroot_to_destroyer
+
+        """
+        fgraph.root_destroyer: Dict["Variable", "Apply"] = OrderedDict()
+
+        # set of Apply instances with non-null destroy_map
+        fgraph._destroyhandler_destroyers: Set["Apply"] = OrderedSet()
+        # variable -> variable used in calculation
+        fgraph.view_i: Dict["Variable", "Variable"] = OrderedDict()
+        # variable -> set of variables that use this one as a direct input
+        fgraph.view_o: Dict["Variable", Set["Variable"]] = OrderedDict()

         # clients: how many times does an apply use a given variable
-        self.clients = OrderedDict()  # variable -> apply -> ninputs
-        self.stale_droot = True
+        fgraph._destroy_handler_clients: Dict[
+            "Variable", Dict["Apply", int]
+        ] = OrderedDict()
+        fgraph.stale_droot: bool = True
+
+        fgraph.debug_all_apps: Set["Apply"] = set()

-        self.debug_all_apps = set()
         if self.do_imports_on_attach:
-            Bookkeeper.on_attach(self, fgraph)
+            super().on_attach(fgraph)

-    def unpickle(self, fgraph):
-        def get_destroyers_of(r):
-            droot, _, root_destroyer = self.refresh_droot_impact()
+        def get_destroyers_of(
+            fgraph: "FunctionGraph", r: "Variable"
+        ) -> List["Apply"]:
+            droot, _, root_destroyer = self.refresh_droot_impact(fgraph)
             try:
                 return [root_destroyer[droot[r]]]
             except Exception:
                 return []

-        fgraph.destroyers = get_destroyers_of
+        fgraph.destroyers = MethodType(get_destroyers_of, fgraph)

-        def has_destroyers(protected_list):
+        def has_destroyers(
+            fgraph: "FunctionGraph", protected_vars: Iterable["Variable"]
+        ) -> bool:
             if self.algo != "fast":
-                droot, _, root_destroyer = self.refresh_droot_impact()
-                for protected_var in protected_list:
+                droot, _, root_destroyer = self.refresh_droot_impact(fgraph)
+                for protected_var in protected_vars:
                     try:
                         root_destroyer[droot[protected_var]]
                         return True
@@ -399,54 +427,56 @@ def has_destroyers(protected_list):
                         pass
                 return False

-            def recursive_destroys_finder(protected_var):
+            def recursive_destroys_finder(protected_var: "Variable") -> bool:
                 # protected_var is the idx'th input of app.
-                for (app, idx) in fgraph.clients[protected_var]:
+                for app, idx in fgraph.clients[protected_var]:
                     if app == "output":
                         continue
-                    destroy_maps = app.op.destroy_map.values()
-                    # If True means that the apply node, destroys the protected_var.
-                    if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
-                        return True
-                    for var_idx in app.op.view_map.keys():
-                        if idx in app.op.view_map[var_idx]:
-                            # We need to recursively check the destroy_map of all the
-                            # outputs that we have a view_map on.
-                            if recursive_destroys_finder(app.outputs[var_idx]):
-                                return True
+                    else:
+                        # `app` is an `Apply` node here; only the string
+                        # "output" is mixed in with the `Apply` clients
+                        destroy_maps = app.op.destroy_map.values()
+                        # If this is true, the apply node destroys `protected_var`
+                        if idx in [
+                            dmap for sublist in destroy_maps for dmap in sublist
+                        ]:
+                            return True
+                        for var_idx in app.op.view_map.keys():
+                            if idx in app.op.view_map[var_idx]:
+                                # We need to recursively check the destroy_map of all the
+                                # outputs that we have a view_map on.
+                                if recursive_destroys_finder(app.outputs[var_idx]):
+                                    return True
                 return False

-            for protected_var in protected_list:
+            for protected_var in protected_vars:
                 if recursive_destroys_finder(protected_var):
                     return True
             return False

-        fgraph.has_destroyers = has_destroyers
+        fgraph.has_destroyers = MethodType(has_destroyers, fgraph)

-    def refresh_droot_impact(self):
+    def refresh_droot_impact(self, fgraph):
         """
-        Makes sure self.droot, self.impact, and self.root_destroyer are up to
-        date, and returns them (see docstrings for these properties above).
+        Makes sure ``droot``, ``impact``, and ``root_destroyer`` are up to
+        date, and returns them (see the docstrings set in ``on_attach``).
""" - if self.stale_droot: - self.droot, self.impact, self.root_destroyer = _build_droot_impact(self) - self.stale_droot = False - return self.droot, self.impact, self.root_destroyer + if fgraph.stale_droot: + fgraph.droot, fgraph.impact, fgraph.root_destroyer = _build_droot_impact( + fgraph + ) + fgraph.stale_droot = False + return fgraph.droot, fgraph.impact, fgraph.root_destroyer def on_detach(self, fgraph): - if fgraph is not self.fgraph: - raise Exception("detaching wrong fgraph", fgraph) - del self.destroyers - del self.view_i - del self.view_o - del self.clients - del self.stale_droot - assert self.fgraph.destroyer_handler is self - delattr(self.fgraph, "destroyers") - delattr(self.fgraph, "has_destroyers") - delattr(self.fgraph, "destroy_handler") - self.fgraph = None + del fgraph._destroyhandler_destroyers + del fgraph.view_i + del fgraph.view_o + del fgraph._destroy_handler_clients + del fgraph.stale_droot + delattr(fgraph, "destroyers") + delattr(fgraph, "has_destroyers") + delattr(fgraph, "destroy_handler") def fast_destroy(self, fgraph, app, reason): """ @@ -467,12 +497,12 @@ def fast_destroy(self, fgraph, app, reason): for inp_idx in inputs: inp = app.inputs[inp_idx] if getattr(inp.tag, "indestructible", False) or isinstance(inp, Constant): - self.fail_validate[app] = InconsistencyError( + fgraph.fail_validate[app] = InconsistencyError( f"Attempting to destroy indestructible variables: {inp}" ) elif len(fgraph.clients[inp]) > 1: - self.fail_validate[app] = InconsistencyError( - "Destroyed variable has more than one client. " + str(reason) + fgraph.fail_validate[app] = InconsistencyError( + f"Destroyed variable has more than one client. {reason}" ) elif inp.owner: app2 = inp.owner @@ -482,14 +512,14 @@ def fast_destroy(self, fgraph, app, reason): if v: v = v.get(inp_idx2, []) if len(v) > 0: - self.fail_validate[app] = InconsistencyError( - "Destroyed variable has view_map. " + str(reason) + fgraph.fail_validate[app] = InconsistencyError( + f"Destroyed variable has view_map. {reason}" ) elif d: d = d.get(inp_idx2, []) if len(d) > 0: - self.fail_validate[app] = InconsistencyError( - "Destroyed variable has destroy_map. " + str(reason) + fgraph.fail_validate[app] = InconsistencyError( + f"Destroyed variable has destroy_map. {reason}" ) # These 2 assertions are commented since this function is called so many times @@ -498,20 +528,17 @@ def fast_destroy(self, fgraph, app, reason): # assert len(d) <= 1 def on_import(self, fgraph, app, reason): - """ - Add Apply instance to set which must be computed. 
- - """ - if app in self.debug_all_apps: + """Add an `Apply` instance to the set which must be computed.""" + if app in fgraph.debug_all_apps: raise ProtocolError("double import") - self.debug_all_apps.add(app) - # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) + fgraph.debug_all_apps.add(app) + # print 'DH IMPORT', app, id(app), id(self), len(fgraph.debug_all_apps) # If it's a destructive op, add it to our watch list dmap = app.op.destroy_map vmap = app.op.view_map if dmap: - self.destroyers.add(app) + fgraph._destroyhandler_destroyers.add(app) if self.algo == "fast": self.fast_destroy(fgraph, app, reason) @@ -523,40 +550,42 @@ def on_import(self, fgraph, app, reason): ) o = app.outputs[o_idx] i = app.inputs[i_idx_list[0]] - self.view_i[o] = i - self.view_o.setdefault(i, OrderedSet()).add(o) + fgraph.view_i[o] = i + fgraph.view_o.setdefault(i, OrderedSet()).add(o) - # update self.clients + # update fgraph._destroy_handler_clients for i, input in enumerate(app.inputs): - self.clients.setdefault(input, OrderedDict()).setdefault(app, 0) - self.clients[input][app] += 1 + fgraph._destroy_handler_clients.setdefault(input, OrderedDict()).setdefault( + app, 0 + ) + fgraph._destroy_handler_clients[input][app] += 1 for i, output in enumerate(app.outputs): - self.clients.setdefault(output, OrderedDict()) + fgraph._destroy_handler_clients.setdefault(output, OrderedDict()) - self.stale_droot = True + fgraph.stale_droot = True def on_prune(self, fgraph, app, reason): """ Remove Apply instance from set which must be computed. """ - if app not in self.debug_all_apps: + if app not in fgraph.debug_all_apps: raise ProtocolError("prune without import") - self.debug_all_apps.remove(app) + fgraph.debug_all_apps.remove(app) - # UPDATE self.clients + # UPDATE fgraph._destroy_handler_clients for input in set(app.inputs): - del self.clients[input][app] + del fgraph._destroy_handler_clients[input][app] if app.op.destroy_map: - self.destroyers.remove(app) + fgraph._destroyhandler_destroyers.remove(app) # Note: leaving empty client dictionaries in the struct. # Why? It's a pain to remove them. I think they aren't doing any harm, they will be # deleted on_detach(). - # UPDATE self.view_i, self.view_o + # UPDATE fgraph.view_i, fgraph.view_o for o_idx, i_idx_list in app.op.view_map.items(): if len(i_idx_list) > 1: # destroying this output invalidates multiple inputs @@ -564,17 +593,25 @@ def on_prune(self, fgraph, app, reason): o = app.outputs[o_idx] i = app.inputs[i_idx_list[0]] - del self.view_i[o] - - self.view_o[i].remove(o) - if not self.view_o[i]: - del self.view_o[i] - - self.stale_droot = True - if app in self.fail_validate: - del self.fail_validate[app] - - def on_change_input(self, fgraph, app, i, old_r, new_r, reason): + del fgraph.view_i[o] + + fgraph.view_o[i].remove(o) + if not fgraph.view_o[i]: + del fgraph.view_o[i] + + fgraph.stale_droot = True + if app in fgraph.fail_validate: + del fgraph.fail_validate[app] + + def on_change_input( + self, + fgraph, + app, + i, + old_r, + new_r, + reason, + ): """ app.inputs[i] changed from old_r to new_r. @@ -584,18 +621,20 @@ def on_change_input(self, fgraph, app, i, old_r, new_r, reason): # considered 'outputs' of the graph. 
             pass
         else:
-            if app not in self.debug_all_apps:
+            if app not in fgraph.debug_all_apps:
                 raise ProtocolError("change without import")

-            # UPDATE self.clients
-            self.clients[old_r][app] -= 1
-            if self.clients[old_r][app] == 0:
-                del self.clients[old_r][app]
+            # UPDATE fgraph._destroy_handler_clients
+            fgraph._destroy_handler_clients[old_r][app] -= 1
+            if fgraph._destroy_handler_clients[old_r][app] == 0:
+                del fgraph._destroy_handler_clients[old_r][app]

-            self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0)
-            self.clients[new_r][app] += 1
+            fgraph._destroy_handler_clients.setdefault(new_r, OrderedDict()).setdefault(
+                app, 0
+            )
+            fgraph._destroy_handler_clients[new_r][app] += 1

-        # UPDATE self.view_i, self.view_o
+        # UPDATE fgraph.view_i, fgraph.view_o
         for o_idx, i_idx_list in app.op.view_map.items():
             if len(i_idx_list) > 1:
                 # destroying this output invalidates multiple inputs
@@ -606,46 +645,44 @@ def on_change_input(
             if app.inputs[i_idx] is not new_r:
                 raise ProtocolError("wrong new_r on change")

-            self.view_i[output] = new_r
+            fgraph.view_i[output] = new_r

-            self.view_o[old_r].remove(output)
-            if not self.view_o[old_r]:
-                del self.view_o[old_r]
+            fgraph.view_o[old_r].remove(output)
+            if not fgraph.view_o[old_r]:
+                del fgraph.view_o[old_r]

-            self.view_o.setdefault(new_r, OrderedSet()).add(output)
+            fgraph.view_o.setdefault(new_r, OrderedSet()).add(output)

         if self.algo == "fast":
-            if app in self.fail_validate:
-                del self.fail_validate[app]
+            if app in fgraph.fail_validate:
+                del fgraph.fail_validate[app]
             self.fast_destroy(fgraph, app, reason)
-        self.stale_droot = True
+        fgraph.stale_droot = True

-    def validate(self, fgraph):
+    def validate(self, fgraph) -> bool:
         """
-        Return None.
-        Raise InconsistencyError when
-        a) orderings() raises an error
-        b) orderings cannot be topologically sorted.
+        Raise an `InconsistencyError` when
+        a) `FunctionGraph.orderings` raises an error
+        b) `FunctionGraph.orderings` cannot be topologically sorted.

         """
-        if self.destroyers:
+        if fgraph._destroyhandler_destroyers:
             if self.algo == "fast":
-                if self.fail_validate:
-                    app_err_pairs = self.fail_validate
-                    self.fail_validate = OrderedDict()
-                    # self.fail_validate can only be a hint that maybe/probably
+                if fgraph.fail_validate:
+                    app_err_pairs = fgraph.fail_validate
+                    fgraph.fail_validate = OrderedDict()
+                    # fgraph.fail_validate can only be a hint that maybe/probably
                     # there is a cycle. This is because inside replace() we could
                     # record many reasons to not accept a change, but we don't
                     # know which one will fail first inside validate(). Thus, the
                     # graph might have already changed when we raise the
-                    # self.fail_validate error. So before raising the error, we
+                    # fgraph.fail_validate error. So before raising the error, we
                     # double check here.
                     for app in app_err_pairs:
                         if app in fgraph.apply_nodes:
                             self.fast_destroy(fgraph, app, "validate")
-                    if self.fail_validate:
-                        self.fail_validate = app_err_pairs
+                    if fgraph.fail_validate:
+                        fgraph.fail_validate = app_err_pairs
                         raise app_err_pairs[app]
             else:
                 ords = self.orderings(fgraph, ordered=False)
@@ -665,120 +702,124 @@ def validate(self, fgraph):
                 pass
         return True

-    def orderings(self, fgraph, ordered=True):
-        """
-        Return orderings induced by destructive operations.
+    def orderings(self, fgraph, ordered: bool = True) -> Dict["Apply", Set["Apply"]]:
+        """Return orderings induced by destructive operations.
- Raise InconsistencyError when - a) attempting to destroy indestructable variable, or - b) attempting to destroy a value multiple times, or - c) an Apply destroys (illegally) one of its own inputs by aliasing + Raise an `InconsistencyError` when + a) attempting to destroy indestructible variable, or + b) attempting to destroy a value multiple times, or + c) an `Apply` destroys (illegally) one of its own inputs by aliasing """ if ordered: - set_type = OrderedSet - rval = OrderedDict() + set_type = cast(TypingType[Set["Apply"]], OrderedSet) else: set_type = set - rval = dict() - - if self.destroyers: - # BUILD DATA STRUCTURES - # CHECK for multiple destructions during construction of variables - - droot, impact, __ignore = self.refresh_droot_impact() - - # check for destruction of constants - illegal_destroy = [ - r - for r in droot - if getattr(r.tag, "indestructible", False) or isinstance(r, Constant) - ] - if illegal_destroy: - raise InconsistencyError( - f"Attempting to destroy indestructible variables: {illegal_destroy}" + + rval: Dict["Apply", Set["Apply"]] = OrderedDict() + + if not fgraph._destroyhandler_destroyers: + return rval + + # BUILD DATA STRUCTURES + # CHECK for multiple destructions during construction of variables + droot, impact, __ignore = self.refresh_droot_impact(fgraph) + + # check for destruction of constants + illegal_destructions = [ + r + for r in droot + if getattr(r.tag, "indestructible", False) or isinstance(r, Constant) + ] + if illegal_destructions: + raise InconsistencyError( + f"Attempting to destroy indestructible variables: {illegal_destructions}" + ) + + # add destroyed variable clients as computational dependencies + for app in fgraph._destroyhandler_destroyers: + # keep track of clients that should run before the current Apply + root_clients: Set["Apply"] = set_type() + # for each destroyed input... + for output_idx, input_idx_list in app.op.destroy_map.items(): + destroyed_idx = input_idx_list[0] + destroyed_variable = app.inputs[destroyed_idx] + root = droot[destroyed_variable] + root_impact = impact[root] + # we generally want to put all clients of things which depend on root + # as pre-requisites of app. + # But, app is itself one such client! + # App will always be a client of the node we're destroying + # (destroyed_variable, but the tricky thing is when it is also a client of + # *another variable* viewing on the root. Generally this is illegal, (e.g., + # add_inplace(x, x.T). In some special cases though, the in-place op will + # actually be able to work properly with multiple destroyed inputs (e.g, + # add_inplace(x, x). An Op that can still work in this case should declare + # so via the 'destroyhandler_tolerate_same' attribute or + # 'destroyhandler_tolerate_aliased' attribute. + # + # destroyhandler_tolerate_same should be a list of pairs of the form + # [(idx0, idx1), (idx0, idx2), ...] + # The first element of each pair is the input index of a destroyed + # variable. + # The second element of each pair is the index of a different input where + # we will permit exactly the same variable to appear. + # For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed + # input is also allowed to appear as the second argument. + # + # destroyhandler_tolerate_aliased is the same sort of list of + # pairs. + # op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the + # destroyhandler to IGNORE an aliasing between a destroyed + # input idx0 and another input idx1. 
+ # This is generally a bad idea, but it is safe in some + # cases, such as + # - the op reads from the aliased idx1 before modifying idx0 + # - the idx0 and idx1 are guaranteed not to overlap (e.g. + # they are pointed at different rows of a matrix). + # + + # CHECK FOR INPUT ALIASING + # OPT: pre-compute this on import + tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", []) + assert isinstance(tolerate_same, list) + tolerated = { + idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx + } + tolerated.add(destroyed_idx) + tolerate_aliased = getattr( + app.op, "destroyhandler_tolerate_aliased", [] ) + assert isinstance(tolerate_aliased, list) + ignored = { + idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx + } + for i, input in enumerate(app.inputs): + if i in ignored: + continue + if input in root_impact and ( + i not in tolerated or input is not destroyed_variable + ): + raise InconsistencyError( + f"Input aliasing: {app} ({destroyed_idx}, {i})" + ) - # add destroyed variable clients as computational dependencies - for app in self.destroyers: - # keep track of clients that should run before the current Apply - root_clients = set_type() - # for each destroyed input... - for output_idx, input_idx_list in app.op.destroy_map.items(): - destroyed_idx = input_idx_list[0] - destroyed_variable = app.inputs[destroyed_idx] - root = droot[destroyed_variable] - root_impact = impact[root] - # we generally want to put all clients of things which depend on root - # as pre-requisites of app. - # But, app is itself one such client! - # App will always be a client of the node we're destroying - # (destroyed_variable, but the tricky thing is when it is also a client of - # *another variable* viewing on the root. Generally this is illegal, (e.g., - # add_inplace(x, x.T). In some special cases though, the in-place op will - # actually be able to work properly with multiple destroyed inputs (e.g, - # add_inplace(x, x). An Op that can still work in this case should declare - # so via the 'destroyhandler_tolerate_same' attribute or - # 'destroyhandler_tolerate_aliased' attribute. - # - # destroyhandler_tolerate_same should be a list of pairs of the form - # [(idx0, idx1), (idx0, idx2), ...] - # The first element of each pair is the input index of a destroyed - # variable. - # The second element of each pair is the index of a different input where - # we will permit exactly the same variable to appear. - # For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed - # input is also allowed to appear as the second argument. - # - # destroyhandler_tolerate_aliased is the same sort of list of - # pairs. - # op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the - # destroyhandler to IGNORE an aliasing between a destroyed - # input idx0 and another input idx1. - # This is generally a bad idea, but it is safe in some - # cases, such as - # - the op reads from the aliased idx1 before modifying idx0 - # - the idx0 and idx1 are guaranteed not to overlap (e.g. - # they are pointed at different rows of a matrix). 
- # - - # CHECK FOR INPUT ALIASING - # OPT: pre-compute this on import - tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", []) - assert isinstance(tolerate_same, list) - tolerated = { - idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx - } - tolerated.add(destroyed_idx) - tolerate_aliased = getattr( - app.op, "destroyhandler_tolerate_aliased", [] + # add the rule: app must be preceded by all other Apply instances that + # depend on destroyed_input + for r in root_impact: + assert not [ + a + for a, c in fgraph._destroy_handler_clients[r].items() + if not c + ] + root_clients.update( + [a for a, c in fgraph._destroy_handler_clients[r].items() if c] ) - assert isinstance(tolerate_aliased, list) - ignored = { - idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx - } - for i, input in enumerate(app.inputs): - if i in ignored: - continue - if input in root_impact and ( - i not in tolerated or input is not destroyed_variable - ): - raise InconsistencyError( - f"Input aliasing: {app} ({destroyed_idx}, {i})" - ) - - # add the rule: app must be preceded by all other Apply instances that - # depend on destroyed_input - for r in root_impact: - assert not [a for a, c in self.clients[r].items() if not c] - root_clients.update( - [a for a, c in self.clients[r].items() if c] - ) - # app itself is a client of the destroyed inputs, - # but should not run before itself - root_clients.remove(app) - if root_clients: - rval[app] = root_clients + # app itself is a client of the destroyed inputs, + # but should not run before itself + root_clients.remove(app) + if root_clients: + rval[app] = root_clients return rval diff --git a/aesara/graph/features.py b/aesara/graph/features.py index 73a625409f..6216584521 100644 --- a/aesara/graph/features.py +++ b/aesara/graph/features.py @@ -1,10 +1,11 @@ import inspect import sys import time +import types import warnings from collections import OrderedDict -from functools import partial from io import StringIO +from typing import TYPE_CHECKING, Dict, Optional, Set import numpy as np @@ -14,6 +15,10 @@ from aesara.graph.utils import InconsistencyError +if TYPE_CHECKING: + from aesara.graph.basic import Apply + + class AlreadyThere(Exception): """ Raised by a Feature's on_attach callback method if the FunctionGraph @@ -286,7 +291,7 @@ def on_detach(self, fgraph): """ - def on_import(self, fgraph, node, reason): + def on_import(self, fgraph, node: "Apply", reason: Optional[str]): """ Called whenever a node is imported into `fgraph`, which is just before the node is actually connected to the graph. @@ -297,7 +302,15 @@ def on_import(self, fgraph, node, reason): """ - def on_change_input(self, fgraph, node, i, var, new_var, reason=None): + def on_change_input( + self, + fgraph, + node: "Apply", + i: int, + var: "Variable", + new_var: "Variable", + reason: Optional[str] = None, + ): """ Called whenever ``node.inputs[i]`` is changed from `var` to `new_var`. At the moment the callback is done, the change has already taken place. @@ -307,14 +320,14 @@ def on_change_input(self, fgraph, node, i, var, new_var, reason=None): """ - def on_prune(self, fgraph, node, reason): + def on_prune(self, fgraph, node: "Apply", reason: Optional[str]) -> None: """ Called whenever a node is pruned (removed) from the `fgraph`, after it is disconnected from the graph. """ - def orderings(self, fgraph): + def orderings(self, fgraph, ordered: bool = True) -> Dict["Apply", Set["Apply"]]: """ Called by `FunctionGraph.toposort`. 
    It should return a dictionary of ``{node: predecessors}`` where ``predecessors`` is a set of
@@ -341,25 +354,13 @@ def clone(self):


 class Bookkeeper(Feature):
     def on_attach(self, fgraph):
         for node in io_toposort(fgraph.inputs, fgraph.outputs):
-            self.on_import(fgraph, node, "on_attach")
+            self.on_import(fgraph, node, "Bookkeeper.on_attach")

     def on_detach(self, fgraph):
         for node in io_toposort(fgraph.inputs, fgraph.outputs):
             self.on_prune(fgraph, node, "Bookkeeper.detach")


-class GetCheckpoint:
-    def __init__(self, history, fgraph):
-        self.h = history
-        self.fgraph = fgraph
-        self.nb = 0
-
-    def __call__(self):
-        self.h.history[self.fgraph] = []
-        self.nb += 1
-        return self.nb
-
-
 class LambdaExtract:
     def __init__(self, fgraph, node, i, r, reason=None):
         self.fgraph = fgraph
@@ -375,73 +376,64 @@ def __call__(self):


 class History(Feature):
-    """Keep an history of changes to an FunctionGraph.
+    """Keep a history of changes to a `FunctionGraph`.

-    This history can be reverted up to the last checkpoint.. We can
-    revert to only 1 point in the past. This limit was added to lower
-    the memory usage.
+    A `FunctionGraph` can be reverted up to the last checkpoint using this
+    `Feature`. It can revert to only one point in the past. This limit was
+    added to lower memory usage.

     """

-    pickle_rm_attr = ["checkpoint", "revert"]
-
-    def __init__(self):
-        self.history = {}
-
     def on_attach(self, fgraph):
         if hasattr(fgraph, "checkpoint") or hasattr(fgraph, "revert"):
             raise AlreadyThere(
                 "History feature is already present or in"
                 " conflict with another plugin."
             )

-        self.history[fgraph] = []
-        # Don't call unpickle here, as ReplaceValidate.on_attach()
-        # call to History.on_attach() will call the
-        # ReplaceValidate.unpickle and not History.unpickle
-        fgraph.checkpoint = GetCheckpoint(self, fgraph)
-        fgraph.revert = partial(self.revert, fgraph)
+        fgraph._history_is_reverting = False
+        fgraph._history_nb = 0
+        fgraph._history_history = []
+        fgraph.checkpoint = types.MethodType(self.checkpoint, fgraph)
+        fgraph.revert = types.MethodType(self.revert, fgraph)

     def clone(self):
         return type(self)()

-    def unpickle(self, fgraph):
-        fgraph.checkpoint = GetCheckpoint(self, fgraph)
-        fgraph.revert = partial(self.revert, fgraph)
-
     def on_detach(self, fgraph):
-        """
-        Should remove any dynamically added functionality
-        that it installed into the function_graph
-        """
         del fgraph.checkpoint
         del fgraph.revert
-        del self.history[fgraph]
+        del fgraph._history_history
+        del fgraph._history_is_reverting
+        del fgraph._history_nb

     def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
-        if self.history[fgraph] is None:
+        if fgraph._history_is_reverting:
             return
-        h = self.history[fgraph]
-        h.append(LambdaExtract(fgraph, node, i, r, reason))
+        fgraph._history_history.append(LambdaExtract(fgraph, node, i, r, str(reason)))

-    def revert(self, fgraph, checkpoint):
+    @staticmethod
+    def checkpoint(fgraph):
+        fgraph._history_history = []
+        fgraph._history_nb += 1
+        return fgraph._history_nb
+
+    @staticmethod
+    def revert(fgraph, checkpoint):
         """
         Reverts the graph to whatever it was at the provided
         checkpoint (undoes all replacements). A checkpoint at any
-        given time can be obtained using self.checkpoint().
+        given time can be obtained using :meth:`self.checkpoint`.
""" - h = self.history[fgraph] - self.history[fgraph] = None - assert fgraph.checkpoint.nb == checkpoint + h = fgraph._history_history + fgraph._history_is_reverting = True + assert fgraph._history_nb == checkpoint while h: f = h.pop() f() - self.history[fgraph] = h + fgraph._history_is_reverting = False class Validator(Feature): - pickle_rm_attr = ["validate", "consistent"] - def on_attach(self, fgraph): for attr in ("validate", "validate_time"): if hasattr(fgraph, attr): @@ -449,15 +441,8 @@ def on_attach(self, fgraph): "Validator feature is already present or in" " conflict with another plugin." ) - # Don't call unpickle here, as ReplaceValidate.on_attach() - # call to History.on_attach() will call the - # ReplaceValidate.unpickle and not History.unpickle - fgraph.validate = partial(self.validate_, fgraph) - fgraph.consistent = partial(self.consistent_, fgraph) - - def unpickle(self, fgraph): - fgraph.validate = partial(self.validate_, fgraph) - fgraph.consistent = partial(self.consistent_, fgraph) + fgraph.validate = types.MethodType(self.validate_, fgraph) + fgraph.consistent = types.MethodType(self.consistent_, fgraph) def on_detach(self, fgraph): """ @@ -467,7 +452,8 @@ def on_detach(self, fgraph): del fgraph.validate del fgraph.consistent - def validate_(self, fgraph): + @staticmethod + def validate_(fgraph): """ If the caller is replace_all_validate, just raise the exception. replace_all_validate will print out the @@ -499,7 +485,8 @@ def validate_(self, fgraph): fgraph.profile.validate_time += t1 - t0 return ret - def consistent_(self, fgraph): + @staticmethod + def consistent_(fgraph): try: fgraph.validate() return True @@ -508,12 +495,6 @@ def consistent_(self, fgraph): class ReplaceValidate(History, Validator): - pickle_rm_attr = ( - ["replace_validate", "replace_all_validate", "replace_all_validate_remove"] - + History.pickle_rm_attr - + Validator.pickle_rm_attr - ) - def on_attach(self, fgraph): for attr in ( "replace_validate", @@ -525,42 +506,40 @@ def on_attach(self, fgraph): "ReplaceValidate feature is already present" " or in conflict with another plugin." 
            )

-        self._nodes_removed = set()
-        self.fail_validate = False
+        fgraph._replace_nodes_removed = set()
+        fgraph._replace_validate_failed = False
+
         History.on_attach(self, fgraph)
         Validator.on_attach(self, fgraph)
-        self.unpickle(fgraph)

-    def clone(self):
-        return type(self)()
-
-    def unpickle(self, fgraph):
-        History.unpickle(self, fgraph)
-        Validator.unpickle(self, fgraph)
-        fgraph.replace_validate = partial(self.replace_validate, fgraph)
-        fgraph.replace_all_validate = partial(self.replace_all_validate, fgraph)
-        fgraph.replace_all_validate_remove = partial(
+        fgraph.replace_validate = types.MethodType(self.replace_validate, fgraph)
+        fgraph.replace_all_validate = types.MethodType(
+            self.replace_all_validate, fgraph
+        )
+        fgraph.replace_all_validate_remove = types.MethodType(
             self.replace_all_validate_remove, fgraph
         )

+    def clone(self):
+        return type(self)()
+
     def on_detach(self, fgraph):
-        """
-        Should remove any dynamically added functionality
-        that it installed into the function_graph
-        """
         History.on_detach(self, fgraph)
         Validator.on_detach(self, fgraph)
-        del self._nodes_removed
+        del fgraph._replace_nodes_removed
+        del fgraph._replace_validate_failed
         del fgraph.replace_validate
         del fgraph.replace_all_validate
         del fgraph.replace_all_validate_remove

-    def replace_validate(self, fgraph, r, new_r, reason=None, **kwargs):
-        self.replace_all_validate(fgraph, [(r, new_r)], reason=reason, **kwargs)
+    @staticmethod
+    def replace_validate(fgraph, r, new_r, reason=None, **kwargs):
+        ReplaceValidate.replace_all_validate(
+            fgraph, [(r, new_r)], reason=reason, **kwargs
+        )

-    def replace_all_validate(
-        self, fgraph, replacements, reason=None, verbose=None, **kwargs
-    ):
+    @staticmethod
+    def replace_all_validate(fgraph, replacements, reason=None, verbose=None, **kwargs):
         chk = fgraph.checkpoint()

         if verbose is None:
@@ -615,8 +594,9 @@ def replace_all_validate(
         # The return is needed by replace_all_validate_remove
         return chk

+    @staticmethod
     def replace_all_validate_remove(
-        self, fgraph, replacements, remove, reason=None, warn=True, **kwargs
+        fgraph, replacements, remove, reason=None, warn=True, **kwargs
     ):
         """
         As replace_all_validate, but revert the replacement if any of the
         nodes in `remove` are still in the graph afterward.

         """
         chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
-        self._nodes_removed.update(remove)
+        fgraph._replace_nodes_removed.update(remove)
         for rm in remove:
             if rm in fgraph.apply_nodes or rm in fgraph.variables:
                 fgraph.revert(chk)
@@ -644,78 +624,53 @@ def __getstate__(self):
         return d

     def on_import(self, fgraph, node, reason):
-        if node in self._nodes_removed:
-            self.fail_validate = True
+        if node in fgraph._replace_nodes_removed:
+            fgraph._replace_validate_failed = True

     def validate(self, fgraph):
-        if self.fail_validate:
-            self.fail_validate = False
+        if fgraph._replace_validate_failed:
+            fgraph._replace_validate_failed = False
             raise InconsistencyError("Trying to reintroduce a removed node")


 class NodeFinder(Bookkeeper):
-    def __init__(self):
-        self.fgraph = None
-        self.d = {}
-
     def on_attach(self, fgraph):
         if hasattr(fgraph, "get_nodes"):
             raise AlreadyThere("NodeFinder is already present")

-        if self.fgraph is not None and self.fgraph != fgraph:
-            raise Exception("A NodeFinder instance can only serve one FunctionGraph.")
+        fgraph._finder_ops_to_nodes = {}

-        self.fgraph = fgraph
-        fgraph.get_nodes = partial(self.query, fgraph)
-        Bookkeeper.on_attach(self, fgraph)
+        def query(self, op):
+            # Return a copy so that callers cannot mutate the index
+            return list(self._finder_ops_to_nodes.get(op, []))
+
+        fgraph.get_nodes
= types.MethodType(query, fgraph) + super().on_attach(fgraph) def clone(self): return type(self)() def on_detach(self, fgraph): - """ - Should remove any dynamically added functionality - that it installed into the function_graph - """ - if self.fgraph is not fgraph: - raise Exception( - "This NodeFinder instance was not attached to the" " provided fgraph." - ) - self.fgraph = None del fgraph.get_nodes - Bookkeeper.on_detach(self, fgraph) + del fgraph._finder_ops_to_nodes def on_import(self, fgraph, node, reason): try: - self.d.setdefault(node.op, []).append(node) - except TypeError: # node.op is unhashable + fgraph._finder_ops_to_nodes.setdefault(node.op, []).append(node) + except TypeError: + # In case the `Op` is unhashable return - except Exception as e: - print("OFFENDING node", type(node), type(node.op), file=sys.stderr) - try: - print("OFFENDING node hash", hash(node.op), file=sys.stderr) - except Exception: - print("OFFENDING node not hashable", file=sys.stderr) - raise e def on_prune(self, fgraph, node, reason): try: - nodes = self.d[node.op] - except TypeError: # node.op is unhashable + nodes = fgraph._finder_ops_to_nodes[node.op] + except TypeError: + # In case the `Op` is unhashable return + nodes.remove(node) - if not nodes: - del self.d[node.op] - def query(self, fgraph, op): - try: - all = self.d.get(op, []) - except TypeError: - raise TypeError( - f"{op} in unhashable and cannot be queried by the optimizer" - ) - all = list(all) - return all + if not nodes: + del fgraph._finder_ops_to_nodes[node.op] class PrintListener(Feature): @@ -782,7 +737,6 @@ def validate(self, fgraph): return True for out in tuple(fgraph.outputs[i] for i in self.protected_out_ids): - node = out.owner if node is None: diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 8bb214421b..6db46af48e 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -1,6 +1,7 @@ """A container for specifying and manipulating a graph with distinct inputs and outputs.""" import time from collections import OrderedDict +from types import MethodType from typing import ( TYPE_CHECKING, Any, @@ -699,26 +700,19 @@ def attach_feature(self, feature: Feature) -> None: """Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback.""" # Filter out literally identical `Feature`s if feature in self._features: - return # the feature is already present + return # Filter out functionally identical `Feature`s. 
- # `Feature`s may use their `on_attach` method to raise - # `AlreadyThere` if they detect that some - # installed `Feature` does the same thing already - attach = getattr(feature, "on_attach", None) - if attach is not None: - try: - attach(self) - except AlreadyThere: - return + # `Feature`s may use their `on_attach` method to raise `AlreadyThere` + # if they detect that some installed `Feature` does the same thing + # already + try: + feature.on_attach(self) + except AlreadyThere: + return + self.execute_callbacks_times.setdefault(feature, 0.0) - # It would be nice if we could require a specific class instead of - # a "workalike" so we could do actual error checking - # if not isinstance(feature, Feature): - # raise TypeError("Expected Feature instance, got "+\ - # str(type(feature))) - # Add the feature self._features.append(feature) def remove_feature(self, feature: Feature) -> None: @@ -733,9 +727,8 @@ def remove_feature(self, feature: Feature) -> None: self._features.remove(feature) except ValueError: return - detach = getattr(feature, "on_detach", None) - if detach is not None: - detach(self) + + feature.on_detach(self) def execute_callbacks(self, name: str, *args, **kwargs) -> None: """Execute callbacks. @@ -935,22 +928,18 @@ def clone_get_equiv( return e, equiv def __getstate__(self): - # This is needed as some features introduce instance methods - # This is not picklable - d = self.__dict__.copy() - for feature in self._features: - for attr in getattr(feature, "pickle_rm_attr", []): - del d[attr] - # XXX: The `Feature` `DispatchingFeature` takes functions as parameter - # and they can be lambda functions, making them unpicklable. + # Remove methods that were attached by features + self_dict = { + k: v for k, v in self.__dict__.items() if not isinstance(v, MethodType) + } - # execute_callbacks_times have reference to optimizer, and they can't - # be pickled as the decorators with parameters aren't pickable. - if "execute_callbacks_times" in d: - del d["execute_callbacks_times"] + # `execute_callbacks_times` holds references to optimizers, so they + # can't be pickled + if "execute_callbacks_times" in self_dict: + del self_dict["execute_callbacks_times"] - return d + return self_dict def __setstate__(self, dct): self.__dict__.update(dct) diff --git a/aesara/graph/type.py b/aesara/graph/type.py index e08f40e09a..a5e7f7fcc6 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union from typing_extensions import TypeAlias @@ -262,14 +262,33 @@ def values_eq_approx(cls, a: D, b: D) -> bool: return cls.values_eq(a, b) -class HasDataType: +class HasDataType(ABC): """A mixin for a type that has a :attr:`dtype` attribute.""" dtype: str + @classmethod + def __subclasshook__(cls, C): + if cls is HasDataType: + if any("dtype" in B.__dict__ for B in C.__mro__): + return True + return NotImplemented + -class HasShape: +class HasShape(ABC): """A mixin for a type that has :attr:`shape` and :attr:`ndim` attributes.""" ndim: int shape: Tuple[Optional[int], ...] 
+
+    @classmethod
+    def __subclasshook__(cls, C):
+        if cls is HasShape:
+            has_shape, has_ndim = False, False
+            for B in C.__mro__:
+                has_shape = has_shape or ("shape" in B.__dict__)
+                has_ndim = has_ndim or ("ndim" in B.__dict__)
+
+                if has_shape and has_ndim:
+                    return True
+        return NotImplemented
diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py
index dd84184344..d7206a1067 100644
--- a/aesara/tensor/basic.py
+++ b/aesara/tensor/basic.py
@@ -69,6 +69,8 @@

 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from aesara.tensor import TensorLike


@@ -260,8 +262,11 @@ def _obj_is_wrappable_as_tensor(x):


 def get_scalar_constant_value(
-    orig_v, elemwise=True, only_process_constants=False, max_recur=10
-):
+    orig_v,
+    elemwise: bool = True,
+    only_process_constants: bool = False,
+    max_recur: int = 10,
+) -> "NDArray":
     """Return the constant scalar(0-D) value underlying variable `v`.

     If `v` is the output of dimshuffles, fills, allocs, etc,
diff --git a/aesara/tensor/rewriting/basic.py b/aesara/tensor/rewriting/basic.py
index 831324b27b..8d4d8e1990 100644
--- a/aesara/tensor/rewriting/basic.py
+++ b/aesara/tensor/rewriting/basic.py
@@ -326,7 +326,7 @@ def dimshuffled_alloc(i):
             return False

         input_shapes = [
-            tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
+            tuple(fgraph.shape_feature.get_shape(fgraph, i, j) for j in range(i.type.ndim))
             for i in node.inputs
         ]
         bcasted_shape = broadcast_shape(
@@ -1022,7 +1022,7 @@ def local_useless_switch(fgraph, node):
             out = correct_out

             input_shapes = [
-                tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
+                tuple(shape_feature.get_shape(fgraph, inp, i) for i in range(inp.type.ndim))
                 for inp in node.inputs
             ]
diff --git a/aesara/tensor/rewriting/elemwise.py b/aesara/tensor/rewriting/elemwise.py
index 91123de506..69cbc8857a 100644
--- a/aesara/tensor/rewriting/elemwise.py
+++ b/aesara/tensor/rewriting/elemwise.py
@@ -4,7 +4,6 @@
 from typing import Optional
 from warnings import warn

-import aesara
 import aesara.scalar.basic as aes
 from aesara import compile
 from aesara.compile.mode import get_target_language
@@ -118,13 +117,9 @@ def apply(self, fgraph):
         else:
             update_outs = []

-        protected_inputs = [
-            f.protected
-            for f in fgraph._features
-            if isinstance(f, aesara.compile.function.types.Supervisor)
-        ]
-        protected_inputs = sum(protected_inputs, [])  # flatten the list
-        protected_inputs.extend(fgraph.outputs)
+        protected_inputs = getattr(fgraph, "_supervisor_protected", set())
+        protected_inputs.update(fgraph.outputs)
+
         for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
             op = node.op
             if not isinstance(op, self.op):
@@ -184,7 +179,6 @@ def apply(self, fgraph):
             raised_warning = not verbose

             for candidate_output in candidate_outputs:
-
                 # If the output of the node can be established as an update
                 # output of the fgraph, visit the candidate_inputs in an order
                 # that will improve the chances of making the node operate
@@ -193,7 +187,6 @@ def apply(self, fgraph):
                 sorted_candidate_inputs = candidate_inputs

                 if candidate_out_var in update_outs:
-
                     # The candidate output is an update. Sort the
                     # variables in candidate_inputs in the following order:
                     # - Vars corresponding to the actual updated input
@@ -225,12 +218,10 @@ def apply(self, fgraph):
                             hasattr(fgraph, "destroy_handler")
                             and inp.owner
                             and any(
-                                fgraph.destroy_handler.root_destroyer.get(up_inp, None)
-                                is inp.owner
+                                fgraph.root_destroyer.get(up_inp, None) is inp.owner
                                 for up_inp in updated_inputs
                             )
                         ):
-
                             # the candidate input is a variable computed
                             # inplace on the updated input via a sequence of
                             # one or more inplace operations
diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py
index 87d77b1322..e8dc56b47e 100644
--- a/aesara/tensor/rewriting/shape.py
+++ b/aesara/tensor/rewriting/shape.py
@@ -1,6 +1,15 @@
 import traceback
-from io import StringIO
-from typing import Optional
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 from typing import cast as type_cast
 from warnings import warn

@@ -17,7 +26,7 @@
     copy_stack_trace,
     node_rewriter,
 )
-from aesara.graph.utils import InconsistencyError, get_variable_trace_string
+from aesara.graph.utils import InconsistencyError
 from aesara.tensor.basic import (
     MakeVector,
     as_tensor_variable,
@@ -47,10 +56,23 @@
     unbroadcast,
 )
 from aesara.tensor.subtensor import Subtensor, get_idx_list
-from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes
+from aesara.tensor.type import HasShape, TensorType, discrete_dtypes, integer_dtypes
 from aesara.tensor.type_other import NoneConst


+if TYPE_CHECKING:
+    from numpy.typing import ArrayLike
+
+    from aesara.graph.basic import Apply
+    from aesara.tensor.var import TensorVariable
+
+    InputShapesType = List[Optional[Tuple[Variable, ...]]]
+    OutputShapesType = List[Optional[Tuple[Variable, ...]]]
+    ShapeInferFunctionType = Callable[
+        [FunctionGraph, "Apply", InputShapesType], OutputShapesType
+    ]
+
+
 class ShapeFeature(Feature):
     r"""A `Feature` that tracks shape information in a graph.

@@ -73,81 +95,43 @@ class ShapeFeature(Feature):
     graph nodes have multiple clients. Many rewrites refuse to work on nodes
     with multiple clients.

-    Lifting is done by using an `.infer_shape` function if one is
-    present, or else using a conservative default. An Op that
-    supports shape-lifting should define a infer_shape(self, fgraph, node,
-    input_shapes) function. The argument input_shapes is a tuple of
-    tuples... there is an interior tuple for each input to the node.
-    The tuple has as many elements as dimensions. The element in
-    position i of tuple j represents the i'th shape component of the
-    j'th input. The function should return a tuple of tuples. One
-    output tuple for each node.output. Again, the i'th element of the
-    j'th output tuple represents the output[j].shape[i] of the
-    function. If an output is not a TensorType, then None should be
-    returned instead of a tuple for that output.
-
-    For example the infer_shape for a matrix-matrix product would accept
-    input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).
-
-    Inferring the shape of internal nodes in the graph is important
-    for doing size-driven rewrites. If we know how big various
-    intermediate results will be, we can estimate the cost of many Ops
-    accurately, and generate c-code that is specific [e.g. unrolled]
-    to particular sizes.
-
-    In cases where you cannot figure out the shape, raise a ShapeError.
-
-    Notes
-    -----
-    Right now there is only the ConvOp that could really take
-    advantage of this shape inference, but it is worth it even
-    just for the ConvOp. All that's necessary to do shape
-    inference is 1) to mark shared inputs as having a particular
-    shape, either via a .tag or some similar hacking; and 2) to
-    add an optional In() argument to promise that inputs will
-    have a certain shape (or even to have certain shapes in
-    certain dimensions).
-
-    We can't automatically infer the shape of shared variables as they can
-    change of shape during the execution by default.
-
-    To use this shape information in rewrites, use the
-    ``shape_of`` dictionary.
-
-    For example:
+    Lifting is done by using an :meth:`Op.infer_shape` method if one is
+    present, or else using a conservative default.

-    .. code-block:: python
+    Inferring the shape of internal nodes in the graph is important for doing
+    size-driven rewrites. If we know how big various intermediate results will
+    be, we can estimate the cost of many `Op`\s accurately, and generate code
+    that is specific (e.g. unrolled) to particular sizes.

-        try:
-            shape_of = fgraph.shape_feature.shape_of
-        except AttributeError:
-            # This can happen when the mode doesn't include the ShapeFeature.
-            return
+    In cases where `ShapeFeature` cannot figure out the shape, it raises a
+    `ShapeError`.

-        shape_of_output_zero = shape_of[node.output[0]]
+    .. note::

-    The ``shape_of_output_zero`` symbol will contain a tuple, whose
-    elements are either integers or symbolic integers.
+        We can't automatically infer the shape of shared variables, as they
+        can change shape during execution by default.

-    TODO: check to see if the symbols are necessarily
-    non-constant... or are integer literals sometimes Aesara
-    constants?? That would be confusing.
+    To use the shape information gathered by a `FunctionGraph`-attached
+    `ShapeFeature` in rewrites, use the :meth:`ShapeFeature.get_shape` method.

     """

+    lscalar_one = constant(1, dtype="int64", ndim=0)
+
-    def get_node_infer_shape(self, node):
+    def get_node_infer_shape(
+        self, fgraph: FunctionGraph, node: "Apply"
+    ) -> "OutputShapesType":
         try:
-            shape_infer = node.op.infer_shape
+            shape_infer: "ShapeInferFunctionType" = node.op.infer_shape
         except AttributeError:
             shape_infer = self.default_infer_shape

         try:
             o_shapes = shape_infer(
-                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
+                fgraph, node, [self.shape_of[r] for r in node.inputs]
             )
         except ShapeError:
             o_shapes = self.default_infer_shape(
-                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
+                fgraph, node, [self.shape_of[r] for r in node.inputs]
             )
         except NotImplementedError as e:
             raise NotImplementedError(
@@ -168,77 +152,104 @@ def get_node_infer_shape(self, node):
             else:
                 warn(msg)
             o_shapes = self.default_infer_shape(
-                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
+                fgraph, node, [self.shape_of[r] for r in node.inputs]
             )

         return o_shapes

-    def get_shape(self, var, idx):
-        """Rewrites can call this to get a `Shape_i`.
+    def get_shape(self, fgraph: FunctionGraph, var: Variable, idx: int) -> Variable:
+        """Get the shape of `var` at index `idx`.

-        It is better to call this then use directly ``shape_of[var][idx]``
-        as this method should update `shape_of` if needed.
+        It is better to call this than use ``ShapeFeature.shape_of[var][idx]``,
+        since this method will update `ShapeFeature.shape_of` when needed.

        TODO: Up to now, we don't update it in all cases. Update in all cases.
+ """ - r = self.shape_of[var][idx] + var_shape = self.shape_of[var] + assert var_shape is not None + + var_idx_shape = var_shape[idx] + if ( - r.owner - and isinstance(r.owner.op, Shape_i) - and r.owner.inputs[0] not in self.fgraph.variables + var_idx_shape.owner + and isinstance(var_idx_shape.owner.op, Shape_i) + and var_idx_shape.owner.inputs[0] not in fgraph.variables ): assert var.owner node = var.owner - # recur on inputs + + # Recurse on inputs + # TODO FIXME: Remove the recursion here. for i in node.inputs: - if getattr(i.type, "ndim", None) > 0: - self.get_shape(i, 0) - o_shapes = self.get_node_infer_shape(node) + if isinstance(i.type, HasShape): + self.get_shape(fgraph, i, 0) + + o_shapes = self.get_node_infer_shape(fgraph, node) assert len(o_shapes) == len(node.outputs) # Only change the variables and dimensions that would introduce # extra computation for new_shps, out in zip(o_shapes, node.outputs): - if not hasattr(out.type, "ndim"): + if not isinstance(out.type, HasShape): continue - merged_shps = list(self.shape_of[out]) + out_shape = self.shape_of[out] + assert out_shape is not None + + merged_shps = list(out_shape) + changed = False for i in range(out.type.ndim): n_r = merged_shps[i] if ( n_r.owner and isinstance(n_r.owner.op, Shape_i) - and n_r.owner.inputs[0] not in self.fgraph.variables + and n_r.owner.inputs[0] not in fgraph.variables ): changed = True + + assert new_shps is not None + merged_shps[i] = new_shps[i] + if changed: self.set_shape(out, merged_shps, override=True) - r = self.shape_of[var][idx] - return r - def shape_ir(self, i, r): - """Return symbolic r.shape[i] for tensor variable r, int i.""" - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - return constant(r.type.shape[i], dtype="int64") + var_shape = self.shape_of[var] + assert var_shape is not None + + var_idx_shape = var_shape[idx] + + return var_idx_shape + + def shape_ir(self, i: int, r: Variable) -> Variable: + r"""Return symbolic `r.shape[i]`.""" + if isinstance(r.type, HasShape) and r.type.shape[i] is not None: + return constant(r.type.shape[i], dtype="int64", ndim=0) else: # Do not call make_node for test_value s = Shape_i(i)(r) + + assert isinstance(s, Variable) + try: - s = get_scalar_constant_value(s) + s = constant(get_scalar_constant_value(s), dtype="int64", ndim=0) except NotScalarConstantError: pass + return s - def shape_tuple(self, r): + def shape_tuple(self, r: Variable) -> Optional[Tuple[Variable, ...]]: """Return a tuple of symbolic shape vars for tensor variable r.""" - if not hasattr(r.type, "ndim"): + if not isinstance(r.type, HasShape): # This happen for NoneConst. return None return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) - def default_infer_shape(self, fgraph, node, i_shapes): + def default_infer_shape( + self, fgraph: FunctionGraph, node: "Apply", i_shapes: "InputShapesType" + ) -> "OutputShapesType": """Return a list of shape tuple or None for the outputs of node. This function is used for Ops that don't implement infer_shape. @@ -254,52 +265,41 @@ def default_infer_shape(self, fgraph, node, i_shapes): rval.append(None) return rval - def unpack(self, s_i, var): - """Return a symbolic integer scalar for the shape element s_i. + def to_symbolic_int( + self, s_i: Union[int, float, np.integer, "ArrayLike", Variable] + ) -> Variable: + """Return a symbolic integer scalar for the shape element `s_i`. - The s_i argument was produced by the infer_shape() of an Op subclass. 
+        TODO: Re-evaluate the need for this, since it's effectively eager
+        canonicalization.

-        var: the variable that correspond to s_i. This is just for
-        error reporting.
+        Parameters
+        ----------
+        s_i
+            The `s_i` argument is assumed to be produced by an :meth:`Op.infer_shape`.

         """
-        assert s_i is not None
-
         if s_i == 1:
             return self.lscalar_one
-        if isinstance(s_i, float) and int(s_i) == s_i:
-            s_i = int(s_i)
-        if isinstance(s_i, (np.integer, int)) or (
+
+        if isinstance(s_i, (float, int, np.integer)) or (
             isinstance(s_i, np.ndarray) and s_i.ndim == 0
         ):
-            # this shape is a constant
-            if s_i < 0:
-                msg = "There is a negative shape in the graph!"
-                msg += get_variable_trace_string(var)
-                # The rest of the pipeline don't handle correctly this
-                # case. So we have 2 choices, stop compilation or
-                # consider the shape as unknown. As we have more
-                # chance to give the stack trace here then later, I
-                # choose that options as it would give better error
-                # message.
-                raise AssertionError(msg)
-            return constant(s_i, dtype="int64")
-        if isinstance(s_i, (tuple, list)):
-            # this dimension is the same as many of the inputs
-            # which tells us that if one of the inputs is known,
-            # the others all become known.
-            # TODO: should be implemented in Elemwise, and Dot
-            #
-            # worst case, we loop over shape_of and replace things
-            raise NotImplementedError(s_i)
-
-        # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
+            assert int(s_i) == s_i and s_i >= 0
+            return constant(s_i, dtype="int64", ndim=0)
+
+        assert isinstance(s_i, Variable)
+
+        # TODO FIXME: This is eager canonicalization; we should let the
+        # relevant canonicalization passes do their job and not perform the
+        # same logic manually.
         if (
             s_i.owner
             and isinstance(s_i.owner.op, Subtensor)
             and s_i.owner.inputs[0].owner
             and isinstance(s_i.owner.inputs[0].owner.op, Shape)
         ):
+            # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
             assert s_i.type.ndim == 0
             assert len(s_i.owner.op.idx_list) == 1

@@ -312,32 +312,33 @@
             try:
                 i = get_scalar_constant_value(idx)
             except NotScalarConstantError:
-                pass
+                return s_i
             else:
                 # Executed only if no exception was raised
                 x = s_i.owner.inputs[0].owner.inputs[0]
                 # x should already have been imported, and should be in shape_of.
-                s_i = self.shape_of[x][i]
+                s_x = self.shape_of[x]
+                assert s_x is not None
+                s_i = s_x[i]

-        if s_i.type.dtype in integer_dtypes:
-            if getattr(s_i.type, "ndim", 0):
-                raise TypeError("Shape element must be scalar", s_i)
-            return s_i
-        else:
-            raise TypeError(
-                "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None)
-            )
+        if s_i.type.dtype not in integer_dtypes or getattr(s_i.type, "ndim", 0) != 0:
+            raise TypeError(f"Shape element {s_i} must be an integer scalar")
+
+        return s_i

-    def set_shape(self, r, s, override=False):
+    def set_shape(
+        self, r: Variable, s: Optional[Sequence[Variable]], override: bool = False
+    ) -> None:
         """Assign the shape `s` to previously un-shaped variable `r`.

         Parameters
         ----------
-        r : a variable
-        s : None or a tuple of symbolic integers
-        override : If False, it mean r is a new object in the fgraph.
-            If True, it mean r is already in the fgraph and we want to
-            override its shape.
+        r
+        s
+        override
+            If ``False``, it means `r` is a new, unseen term.
+            If ``True``, it means `r` is assumed to have already been seen and
+            we want to override its shape.
""" if not override: @@ -349,36 +350,36 @@ def set_shape(self, r, s, override=False): raise TypeError("shapes must be tuple/list", (r, s)) if r.type.ndim != len(s): - sio = StringIO() - aesara.printing.debugprint(r, file=sio, print_type=True) raise AssertionError( - f"Something inferred a shape with {len(s)} dimensions " - f"for a variable with {int(r.type.ndim)} dimensions" - f" for the variable:\n{sio.getvalue()}" + f"A shape with {len(s)} dimensions was inferred for {r}: " + f"a variable with {int(r.type.ndim)} dimensions." ) - shape_vars = [] + shape_vars: Tuple[Variable, ...] = () for i in range(r.type.ndim): - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - shape_vars.append(constant(r.type.shape[i], dtype="int64")) + if isinstance(r.type, HasShape) and r.type.shape[i] is not None: + shape_vars += (constant(r.type.shape[i], dtype="int64", ndim=0),) else: - shape_vars.append(self.unpack(s[i], r)) + shape_vars += (self.to_symbolic_int(s[i]),) + assert all( - not hasattr(r.type, "shape") + not isinstance(r.type, HasShape) or r.type.shape[i] != 1 or self.lscalar_one.equals(shape_vars[i]) or self.lscalar_one.equals(extract_constant(shape_vars[i])) for i in range(r.type.ndim) ) + self.shape_of[r] = tuple(shape_vars) + for sv in shape_vars: self.shape_of_reverse_index.setdefault(sv, set()).add(r) - def update_shape(self, r, other_r): - """Replace shape of r by shape of other_r. + def update_shape(self, r: Variable, other_r: Variable) -> None: + """Replace shape of `r` by shape of `other_r`. - If, on some dimensions, the shape of other_r is not informative, - keep the shape of r on those dimensions. + If, on some dimensions, the shape of `other_r` is not informative, keep + the shape of `r` on those dimensions. """ # other_r should already have a shape @@ -395,6 +396,7 @@ def update_shape(self, r, other_r): # If no info is known on r's shape, use other_shape self.set_shape(r, other_shape) return + if ( other_r.owner and r.owner @@ -407,12 +409,17 @@ def update_shape(self, r, other_r): return # Merge other_shape with r_shape, giving the priority to other_shape - merged_shape = [] + merged_shape: Tuple[Variable, ...] = () for i, ps in enumerate(other_shape): - if r_shape is None and other_shape: - merged_shape.append(other_shape[i]) - elif ( - ps.owner + if r_shape is None: + merged_shape += (ps,) + continue + + rs = r_shape[i] + if ( + # TODO FIXME: This is another instance of eager + # canonicalization that we need to address. + ps.owner is not None and isinstance(getattr(ps.owner, "op", None), Shape_i) and ps.owner.op.i == i and ps.owner.inputs[0] in (r, other_r) @@ -421,20 +428,21 @@ def update_shape(self, r, other_r): # For now, we consider 2 cases of uninformative other_shape[i]: # - Shape_i(i)(other_r); # - Shape_i(i)(r). - merged_shape.append(r_shape[i]) - elif isinstance(r_shape[i], (Constant, int)): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(r_shape[i]) - elif isinstance(other_shape[i], (Constant, int)): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. 
-                merged_shape.append(other_shape[i])
-            elif other_shape[i] == r_shape[i]:
-                # This mean the shape is equivalent
-                # We do not want to do the ancestor check in those cases
-                merged_shape.append(r_shape[i])
-            elif r_shape[i] in ancestors([other_shape[i]]):
+                merged_shape += (rs,)
+            elif isinstance(rs, Constant):
+                # We always prefer constants
+                merged_shape += (rs,)
+            elif isinstance(ps, Constant):
+                merged_shape += (ps,)
+            elif ps == rs:
+                # The shapes are equivalent. We do not want to do the ancestor
+                # check in those cases
+                merged_shape += (rs,)
+            elif (
+                # TODO FIXME: This could be unnecessarily costly.
+                rs
+                in ancestors([ps])
+            ):
                 # Another case where we want to use r_shape[i] is when
                 # other_shape[i] actually depends on r_shape[i]. In that case,
                 # we do not want to substitute an expression with another that
@@ -442,12 +450,13 @@ def update_shape(self, r, other_r):
                 # to cycles: if (in the future) r_shape[i] gets replaced by an
                 # expression of other_shape[i], other_shape[i] may end up
                 # depending on itself.
-                merged_shape.append(r_shape[i])
+                merged_shape += (rs,)
             else:
-                merged_shape.append(other_shape[i])
+                merged_shape += (ps,)
+
         assert all(
             (
-                not hasattr(r.type, "shape")
+                not isinstance(r.type, HasShape)
                 or r.type.shape[i] != 1
                 and other_r.type.shape[i] != 1
             )
@@ -457,73 +466,69 @@ def update_shape(self, r, other_r):
             )
             for i in range(r.type.ndim)
         )
-        self.shape_of[r] = tuple(merged_shape)
-        for sv in self.shape_of[r]:
+
+        self.shape_of[r] = merged_shape
+        for sv in merged_shape:
             self.shape_of_reverse_index.setdefault(sv, set()).add(r)

-    def set_shape_i(self, r, i, s_i):
+    def set_shape_i(self, r: Variable, i: int, s_i: Variable) -> None:
         """Replace element i of shape_of[r] by s_i"""
-        assert r in self.shape_of
+        prev_shape = self.shape_of[r]
+        assert prev_shape is not None
+
         # prev_shape is a tuple, so we cannot change it inplace,
         # so we build another one.
-        new_shape = []
+        new_shape: Tuple[Variable, ...] = ()
         for j, s_j in enumerate(prev_shape):
             if j == i:
-                new_shape.append(self.unpack(s_i, r))
+                new_shape += (self.to_symbolic_int(s_i),)
             else:
-                new_shape.append(s_j)
+                new_shape += (s_j,)
+
         assert all(
-            not hasattr(r.type, "shape")
+            not isinstance(r.type, HasShape)
             or r.type.shape[idx] != 1
             or self.lscalar_one.equals(new_shape[idx])
             or self.lscalar_one.equals(extract_constant(new_shape[idx]))
             for idx in range(r.type.ndim)
         )
-        self.shape_of[r] = tuple(new_shape)
-        for sv in self.shape_of[r]:
+
+        self.shape_of[r] = new_shape
+
+        for sv in new_shape:
             self.shape_of_reverse_index.setdefault(sv, set()).add(r)

-    def init_r(self, r):
+    def init_r(self, r: Variable) -> None:
         """Register r's shape in the shape_of dictionary."""
         if r not in self.shape_of:
             self.set_shape(r, self.shape_tuple(r))

-    def make_vector_shape(self, r):
-        return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64")
+    def make_vector_shape(self, r: Variable) -> "TensorVariable":
+        r_shape = self.shape_of[r]
+        assert r_shape is not None
+        return as_tensor_variable(r_shape, ndim=1, dtype="int64")

     def on_attach(self, fgraph):
         if hasattr(fgraph, "shape_feature"):
             raise AlreadyThere("This FunctionGraph already has a ShapeFeature")

-        if hasattr(self, "fgraph") and self.fgraph != fgraph:
-            raise Exception("This ShapeFeature is already attached to a graph")
-
-        self.fgraph = fgraph
-
         fgraph.shape_feature = self
-        # Must be local to the object as otherwise we reuse the same
-        # variable for multiple fgraph!
-        self.lscalar_one = constant(1, dtype="int64")
+
         assert self.lscalar_one.type.dtype == "int64"

-        self.fgraph = fgraph
-        # Variable -> tuple(scalars) or None  (All tensor vars map to tuple)
-        self.shape_of = {}
-        # Variable ->
-        self.scheduled = {}
-        # shape var -> graph v
-        self.shape_of_reverse_index = {}
+        self.shape_of: Dict[Variable, Optional[Tuple[Variable, ...]]] = {}
+        self.scheduled: Dict["Apply", Variable] = {}
+        self.shape_of_reverse_index: Dict[Variable, Set[Variable]] = {}

         for node in fgraph.toposort():
             self.on_import(fgraph, node, reason="on_attach")

     def on_detach(self, fgraph):
-        self.shape_of = {}
-        self.scheduled = {}
-        self.shape_of_reverse_index = {}
-        self.fgraph = None
+        self.shape_of.clear()
+        self.scheduled.clear()
+        self.shape_of_reverse_index.clear()
+
         del fgraph.shape_feature

     def on_import(self, fgraph, node, reason):
@@ -537,7 +542,7 @@ def on_import(self, fgraph, node, reason):
                 # make sure we have shapes for the inputs
                 self.init_r(r)

-        o_shapes = self.get_node_infer_shape(node)
+        o_shapes = self.get_node_infer_shape(fgraph, node)

         # this is packed information
         # an element of o_shapes is either None or a tuple
@@ -573,7 +578,7 @@ def on_import(self, fgraph, node, reason):
                     assert str(d.dtype) != "uint64", node
                     new_shape += sh[len(new_shape) : i + 1]
                     if isinstance(d, Constant):
-                        casted_d = constant(d.data, dtype="int64")
+                        casted_d = constant(d.data, dtype="int64", ndim=0)
                     else:
                         casted_d = cast(d, "int64")
                     new_shape[i] = casted_d
@@ -684,10 +689,10 @@ def same_shape(
         return False

     if dim_x is not None:
-        sx = [sx[dim_x]]
+        sx = (sx[dim_x],)

     if dim_y is not None:
-        sy = [sy[dim_y]]
+        sy = (sy[dim_y],)

     if len(sx) != len(sy):
         return False
@@ -711,10 +716,10 @@ def same_shape(
     )
     canon_shapes = canon_shapes_fg.outputs

-    sx = canon_shapes[: len(sx)]
-    sy = canon_shapes[len(sx) :]
+    sx_ = canon_shapes[: len(sx)]
+    sy_ = canon_shapes[len(sx) :]

-    for dx, dy in zip(sx, sy):
+    for dx, dy in zip(sx_, sy_):
         if not equal_computations([dx], [dy]):
             return False

@@ -884,7 +889,7 @@ def local_useless_reshape(fgraph, node):

         # Match shape_of[input][dim] or its constant equivalent
         if shape_feature:
-            inpshp_i = shape_feature.get_shape(inp, dim)
+            inpshp_i = shape_feature.get_shape(fgraph, inp, dim)
             if inpshp_i == outshp_i or (
                 extract_constant(inpshp_i, only_process_constants=1)
                 == extract_constant(outshp_i, only_process_constants=1)
diff --git a/aesara/typed_list/type.py b/aesara/typed_list/type.py
index 83e8a40f24..f62b1446d3 100644
--- a/aesara/typed_list/type.py
+++ b/aesara/typed_list/type.py
@@ -146,3 +146,4 @@ def c_code_cache_version(self):

     dtype = property(lambda self: self.ttype)
     ndim = property(lambda self: self.ttype.ndim + 1)
+    shape = property(lambda self: (None,) + self.ttype.shape)
diff --git a/doc/config.rst b/doc/config.rst
index a6782b9679..73b3064862 100644
--- a/doc/config.rst
+++ b/doc/config.rst
@@ -188,27 +188,27 @@ import ``aesara`` and print the config variable, as in:

 .. attribute:: cycle_detection

-    String value, either ``regular`` or ``fast```
+    String value, either ``"regular"`` or ``"fast"``

-    Default: ``regular``
+    Default: ``"regular"``

-    If :attr:`cycle_detection` is set to ``regular``, most in-place operations are allowed,
-    but graph compilation is slower. If :attr:`cycle_detection` is set to ``faster``,
+    If :attr:`cycle_detection` is set to ``"regular"``, most in-place operations are allowed,
+    but graph compilation is slower. If :attr:`cycle_detection` is set to ``"fast"``,
     less in-place operations are allowed, but graph compilation is faster.

 .. attribute:: check_stack_trace

-    String value, either ``off``, ``log``, ``warn``, ``raise``
+    String value, either ``"off"``, ``"log"``, ``"warn"``, ``"raise"``

-    Default: ``off``
+    Default: ``"off"``

     This is a flag for checking stack traces during graph rewriting.
-    If :attr:`check_stack_trace` is set to ``off``, no check is performed on the
-    stack trace. If :attr:`check_stack_trace` is set to ``log`` or ``warn``, a
+    If :attr:`check_stack_trace` is set to ``"off"``, no check is performed on the
+    stack trace. If :attr:`check_stack_trace` is set to ``"log"`` or ``"warn"``, a
     dummy stack trace is inserted that indicates which rewrite inserted the
-    variable that had an empty stack trace, but, when ``warn`` is set, a warning
+    variable that had an empty stack trace, but, when ``"warn"`` is set, a warning
     is also printed.
-    If :attr:`check_stack_trace` is set to ``raise``, an exception is raised if a
+    If :attr:`check_stack_trace` is set to ``"raise"``, an exception is raised if a
     stack trace is missing.

 .. attribute:: openmp
diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py
index da057baded..b148310fc8 100644
--- a/tests/compile/function/test_types.py
+++ b/tests/compile/function/test_types.py
@@ -8,12 +8,13 @@
 from aesara.compile import shared
 from aesara.compile.debugmode import DebugMode, InvalidValueError
 from aesara.compile.function import function
-from aesara.compile.function.types import UnusedInputError
+from aesara.compile.function.types import Supervisor, UnusedInputError
 from aesara.compile.io import In, Out
 from aesara.compile.mode import Mode, get_default_mode
 from aesara.compile.ops import update_placeholder
 from aesara.configdefaults import config
 from aesara.graph.basic import Constant
+from aesara.graph.fg import FunctionGraph
 from aesara.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
 from aesara.graph.utils import MissingInputError
 from aesara.link.vm import VMLinker
@@ -33,6 +34,7 @@
     vector,
 )
 from aesara.utils import exc_message
+from tests.graph.utils import MyVariable, op1


 def PatternOptimizer(p1, p2, ign=True):
@@ -1285,3 +1287,38 @@ def test_update_placeholder():
     # The second update shouldn't be present
     assert len(f1.maker.fgraph.outputs) == 2
     assert f1.maker.fgraph.update_mapping == {1: 3}
+
+
+class TestSupervisor:
+    def test_basic(self):
+        var1 = MyVariable("var1")
+        var2 = MyVariable("var2")
+        var3 = op1(var2, var1)
+        fg = FunctionGraph([var1, var2], [var3], clone=False)
+
+        hf = Supervisor([var1])
+        fg.attach_feature(hf)
+
+        assert fg._supervisor_protected == {var1}
+
+        # Make sure we can update the protected variables by
+        # adding another `Supervisor`
+        hf = Supervisor([var2])
+        fg.attach_feature(hf)
+
+        assert fg._supervisor_protected == {var1, var2}
+
+    def test_pickle(self):
+        var1 = MyVariable("var1")
+        var2 = MyVariable("var2")
+        var3 = op1(var2, var1)
+        fg = FunctionGraph([var1, var2], [var3], clone=False)
+
+        hf = Supervisor([var1])
+        fg.attach_feature(hf)
+
+        fg_pkld = pickle.dumps(fg)
+        fg_unpkld = pickle.loads(fg_pkld)
+
+        assert any(isinstance(ft, Supervisor) for ft in fg_unpkld._features)
+        assert all(hasattr(fg_unpkld, attr) for attr in ("_supervisor_protected",))
diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py
index 3470284e66..8f9b002d35 100644
--- a/tests/graph/test_destroyhandler.py
+++ b/tests/graph/test_destroyhandler.py
@@ -1,10 +1,11 @@
+import pickle
 from copy import copy

 import pytest

 from aesara.configdefaults import config
 from aesara.graph.basic import Apply, Constant, Variable, clone
-from aesara.graph.destroyhandler import DestroyHandler
+from aesara.graph.destroyhandler import DestroyHandler, fast_inplace_check
 from aesara.graph.features import ReplaceValidate
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
@@ -121,9 +122,9 @@ def inputs():
     return x, y, z


-def create_fgraph(inputs, outputs, validate=True):
+def create_fgraph(inputs, outputs, validate=True, algo=None):
     e = FunctionGraph(inputs, outputs, clone=False)
-    e.attach_feature(DestroyHandler())
+    e.attach_feature(DestroyHandler(algo=algo))
     e.attach_feature(ReplaceValidate())
     if validate:
         e.validate()
@@ -144,15 +145,23 @@ def test_misc():
     e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
     g = create_fgraph([x, y, z], [e])
     assert g.consistent()
+
     OpKeyPatternNodeRewriter((transpose_view, (transpose_view, "x")), "x").rewrite(g)
+
     assert str(g) == "FunctionGraph(x)"
+
     new_e = add(x, y)
     g.replace_validate(x, new_e)
     assert str(g) == "FunctionGraph(Add(x, y))"
+
     g.replace(new_e, dot(add_in_place(x, y), transpose_view(x)))
     assert str(g) == "FunctionGraph(Dot(AddInPlace(x, y), TransposeView(x)))"
     assert not g.consistent()
+
+    (dh,) = [f for f in g._features if isinstance(f, DestroyHandler)]
+    g.remove_feature(dh)
+    assert not hasattr(g, "destroyers")
+

 @assertFailure_fast
 def test_aliased_inputs_replacement():
@@ -173,14 +182,15 @@ def test_aliased_inputs_replacement():
     assert g.consistent()


-def test_indestructible():
+@pytest.mark.parametrize("algo", [None, "fast"])
+def test_indestructible(algo):
     x, y, z = inputs()
     x.tag.indestructible = True
     x = copy(x)
     # checking if indestructible survives the copy!
     assert x.tag.indestructible
     e = add_in_place(x, y)
-    g = create_fgraph([x, y, z], [e], False)
+    g = create_fgraph([x, y, z], [e], False, algo=algo)
     assert not g.consistent()
     g.replace_validate(e, add(x, y))
     assert g.consistent()
@@ -460,3 +470,47 @@ def test_multiple_inplace():
     ).rewrite(g)
     assert g.consistent()
     assert fail.failures == 1
+
+
+def test_pickle():
+    x, y, z = inputs()
+    tv = transpose_view(x)
+    e = add_in_place(x, tv)
+    fg = create_fgraph([x, y], [e], False)
+    assert not fg.consistent()
+
+    fg_pkld = pickle.dumps(fg)
+    fg_unpkld = pickle.loads(fg_pkld)
+
+    assert any(isinstance(ft, DestroyHandler) for ft in fg_unpkld._features)
+    assert all(hasattr(fg_unpkld, attr) for attr in ("_destroyhandler_destroyers",))
+
+
+def test_fast_inplace_check():
+
+    x, y = MyVariable("x"), MyVariable("y")
+    e = add_in_place(x, y)
+    fg = FunctionGraph(outputs=[e], clone=False)
+    fg.attach_feature(DestroyHandler())
+
+    res = fast_inplace_check(fg, fg.inputs)
+    assert res == [y]
+
+
+def test_fast_destroy():
+    """Make sure `DestroyHandler.fast_destroy` catches basic inconsistencies."""
+    x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
+
+    w = add_in_place(x, dot(y, x))
+    with pytest.raises(InconsistencyError):
+        create_fgraph([x, y], [w], algo="fast")
+
+    w = add_in_place(x, y)
+    w = add_in_place(w, z)
+    with pytest.raises(InconsistencyError):
+        create_fgraph([x, y, z], [w], algo="fast")
+
+    w = transpose_view(x)
+    w = add_in_place(w, y)
+    with pytest.raises(InconsistencyError):
+        create_fgraph([x, y], [w], algo="fast")
diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py
index bd5044fb29..57ec87f961 100644
--- a/tests/graph/test_features.py
+++ b/tests/graph/test_features.py
@@ -1,11 +1,13 @@
+import pickle
+
 import pytest

 from aesara.graph.basic import Apply, Variable
-from aesara.graph.features import Feature, NodeFinder, ReplaceValidate
+from aesara.graph.features import Feature, History, NodeFinder, ReplaceValidate
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
 from aesara.graph.type import Type
-from tests.graph.utils import MyVariable, op1
+from tests.graph.utils import MyVariable, MyVariable2, op1, op2


 class TestNodeFinder:
@@ -60,17 +62,12 @@ def perform(self, *args, **kwargs):
         def MyVariable(name):
             return Variable(MyType(name), None, None)

-        def inputs():
-            x = MyVariable("x")
-            y = MyVariable("y")
-            z = MyVariable("z")
-            return x, y, z
-
-        x, y, z = inputs()
+        x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
         e0 = dot(y, z)
         e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))

         g = FunctionGraph([x, y, z], [e], clone=False)
-        g.attach_feature(NodeFinder())
+        nf = NodeFinder()
+        g.attach_feature(nf)

         assert hasattr(g, "get_nodes")
         for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
@@ -86,8 +83,39 @@ def inputs():
             if len([t for t in g.get_nodes(type)]) != num:
                 raise Exception("Expected: %i times %s" % (num, type))

+        g.remove_feature(nf)
+        assert not hasattr(g, "get_nodes")
+        assert not hasattr(g, "_finder_ops_to_nodes")
+

 class TestReplaceValidate:
+    def test_basic(self):
+        var1 = MyVariable("var1")
+        var2 = MyVariable("var2")
+        var3 = op1(var2, var1)
+        fg = FunctionGraph([var1, var2], [var3], clone=False)
+
+        # One should already be attached
+        (rv_feature,) = fg._features
+        assert isinstance(rv_feature, ReplaceValidate)
+
+        fg.attach_feature(ReplaceValidate())
+
+        assert hasattr(fg, "_replace_nodes_removed")
+        assert hasattr(fg, "_replace_validate_failed")
+
+        rv_feature.replace_all_validate(fg, [(var3, var1)])
+        assert var3 not in fg.variables
+
+        # This `Variable` has a different `Type`
+        var4 = MyVariable2("var4")
+        with pytest.raises(TypeError):
+            rv_feature.replace_all_validate(fg, [(var1, var4)])
+
+        fg.remove_feature(rv_feature)
+        assert not hasattr(fg, "_replace_nodes_removed")
+        assert not hasattr(fg, "_replace_validate_failed")
+
     def test_verbose(self, capsys):
         var1 = MyVariable("var1")
         var2 = MyVariable("var2")
@@ -120,3 +148,78 @@ def validate(self, *args):

         capres = capsys.readouterr()
         assert "rewriting: validate failed on node Op1.0" in capres.out
+
+    def test_pickle(self):
+        var1 = MyVariable("var1")
+        var2 = MyVariable("var2")
+        var3 = op1(var2, var1)
+        fg = FunctionGraph([var1, var2], [var3], clone=False)
+
+        rv_feature = ReplaceValidate()
+        fg.attach_feature(rv_feature)
+
+        fg_pkld = pickle.dumps(fg)
+        fg_unpkld = pickle.loads(fg_pkld)
+
+        assert ReplaceValidate in set(type(ft) for ft in fg_unpkld._features)
+        assert all(
+            hasattr(fg_unpkld, attr)
+            for attr in (
+                "replace_validate",
+                "replace_all_validate",
+                "replace_all_validate_remove",
+                "checkpoint",
+                "revert",
+                "validate",
+                "consistent",
+            )
+        )
+
+
+class TestHistory:
+    def test_basic(self):
+        var1 = MyVariable("var1")
+        var2 = MyVariable("var2")
+        var3 = op1(var2, var1)
+        fg = FunctionGraph([var1, var2], [var3], clone=False)
+
+        hf = History()
+        fg.attach_feature(hf)
+
+        assert hasattr(fg, "_history_is_reverting")
+        assert hasattr(fg, "_history_history")
+
+        chkpnt = fg.checkpoint()
+
+        fg.replace_all([(var3, op2(var2, var1))])
+        assert var3 not in fg.variables
+
+        assert fg._history_history
+
+        fg.revert(chkpnt)
+        assert not fg._history_is_reverting
+
+        assert not fg._history_history
+        assert var3 in fg.variables
+
+    def test_pickle(self):
+        var1 = MyVariable("var1")
+        var2 = MyVariable("var2")
+        var3 = op1(var2, var1)
+        fg = FunctionGraph([var1, var2], [var3], clone=False)
+
+        hf = History()
+        fg.attach_feature(hf)
+
+        fg_pkld = pickle.dumps(fg)
+        fg_unpkld = pickle.loads(fg_pkld)
+
+        assert any(isinstance(ft, History) for ft in fg_unpkld._features)
+        assert all(
+            hasattr(fg_unpkld, attr)
+            for attr in (
+                "checkpoint",
+                "revert",
+                "validate",
+                "consistent",
+            )
+        )
diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py
index e93829d6d2..67075be055 100644
--- a/tests/tensor/test_shape.py
+++ b/tests/tensor/test_shape.py
@@ -7,6 +7,7 @@
 from aesara.configdefaults import config
 from aesara.graph.basic import Variable
 from aesara.graph.fg import FunctionGraph
+from aesara.graph.op import Op
 from aesara.graph.type import Type
 from aesara.misc.safe_asarray import _asarray
 from aesara.scalar.basic import ScalarConstant
@@ -46,7 +47,6 @@
 )
 from aesara.tensor.type_other import NoneConst
 from aesara.tensor.var import TensorVariable
-from aesara.typed_list import make_list
 from tests import unittest_tools as utt
 from tests.graph.utils import MyType2
 from tests.tensor.utils import eval_outputs, random
@@ -560,12 +560,37 @@ def test_reshape(self):

 @config.change_flags(compute_test_value="raise")
 def test_nonstandard_shapes():
-    a = tensor3(config.floatX)
+    """Make sure shape inference works when `Op.infer_shape` isn't implemented.
+
+    This also checks that the `HasShape` abstract mixin works when it isn't
+    explicitly used in a `Type` class definition.
+
+    """
+    a = tensor3(name="a", dtype=config.floatX)
     a.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)
-    b = tensor3(config.floatX)
+    b = tensor3(name="b", dtype=config.floatX)
     b.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)

-    tl = make_list([a, b])
+    class ListType(Type):
+        def filter(self, data, **kwargs):
+            return data
+
+        ndim = 4
+        shape = (None,) * 4
+
+    class MakeList(Op):
+        itypes = [a.type, b.type]
+        otypes = [ListType()]
+
+        def perform(self, node, inputs, outputs):
+            outputs[0][0] = list(inputs)
+
+    make_list = MakeList()
+    tl = make_list(a, b)
+
+    # from aesara.typed_list import make_list
+    #
+    # tl = make_list([a, b])

     tl_shape = shape(tl)
     assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4))
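
The `__subclasshook__` additions to `HasDataType` and `HasShape` above make the mixins structural: any class whose MRO defines the relevant attributes passes `isinstance`/`issubclass` checks without explicitly inheriting from them, which is what the `test_nonstandard_shapes` test relies on. A minimal sketch of that behavior, assuming only the `aesara.graph.type` module from this diff; the `DuckType` class below is illustrative and not part of Aesara:

    from typing import Optional, Tuple

    from aesara.graph.type import HasDataType, HasShape


    class DuckType:
        """A class that never subclasses the mixins explicitly."""

        # These must be attribute assignments (not bare annotations) so the
        # names land in the class `__dict__`, which is what the hooks inspect.
        dtype: str = "float64"
        ndim: int = 2
        shape: Tuple[Optional[int], ...] = (None, None)


    # Both checks now pass structurally, via `ABCMeta` consulting
    # `__subclasshook__` on each `issubclass`/`isinstance` call.
    assert issubclass(DuckType, HasDataType)
    assert issubclass(DuckType, HasShape)
    assert isinstance(DuckType(), HasShape)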
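Likewise, because `ShapeFeature` no longer stores `self.fgraph`, callers pass the graph to `get_shape` explicitly, exactly as the updated rewrites in `aesara/tensor/rewriting/basic.py` and `shape.py` do. A minimal usage sketch under the post-refactor API; the graph built here is illustrative:

    import aesara.tensor as at
    from aesara.graph.fg import FunctionGraph
    from aesara.tensor.rewriting.shape import ShapeFeature

    x = at.matrix("x")
    y = x.sum(axis=0)

    fgraph = FunctionGraph([x], [y], clone=False)
    # `on_attach` sets `fgraph.shape_feature` and runs shape inference
    # for every node already in the graph
    fgraph.attach_feature(ShapeFeature())

    # The `FunctionGraph` is now an explicit argument instead of being
    # read from a stored `self.fgraph`
    dim0 = fgraph.shape_feature.get_shape(fgraph, y, 0)  # symbolic y.shape[0]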