From b96c3e6ac4a22e2414bfa49f6d1820764b78f6a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 19 Apr 2023 16:25:05 +0200 Subject: [PATCH 1/3] feat(common): add a pure python egraph implementation --- ibis/common/collections.py | 225 +++++++++- ibis/common/egraph.py | 582 ++++++++++++++++++++++++++ ibis/common/graph.py | 3 + ibis/common/grounds.py | 5 +- ibis/common/tests/test_collections.py | 72 +++- ibis/common/tests/test_egraph.py | 454 ++++++++++++++++++++ ibis/common/tests/test_validators.py | 5 + ibis/expr/operations/core.py | 3 - 8 files changed, 1339 insertions(+), 10 deletions(-) create mode 100644 ibis/common/egraph.py create mode 100644 ibis/common/tests/test_egraph.py diff --git a/ibis/common/collections.py b/ibis/common/collections.py index 804fa98b1645..3a56872afd9b 100644 --- a/ibis/common/collections.py +++ b/ibis/common/collections.py @@ -1,12 +1,12 @@ from __future__ import annotations from types import MappingProxyType -from typing import Any, Hashable, Mapping, TypeVar +from typing import Any, Hashable, Iterable, Iterator, Mapping, Set, TypeVar from public import public from typing_extensions import Self -K = TypeVar("K") +K = TypeVar("K", bound=Hashable) V = TypeVar("V") @@ -208,4 +208,225 @@ def __repr__(self): return f"{self.__class__.__name__}({super().__repr__()})" +class DisjointSet(Mapping[K, Set[K]]): + """Disjoint set data structure. + + Also known as union-find data structure. It is a data structure that keeps + track of a set of elements partitioned into a number of disjoint (non-overlapping) + subsets. It provides near-constant-time operations to add new sets, to merge + existing sets, and to determine whether elements are in the same set. + + Parameters + ---------- + data : + Initial data to add to the disjoint set. + + Examples + -------- + >>> ds = DisjointSet() + >>> ds.add(1) + 1 + >>> ds.add(2) + 2 + >>> ds.add(3) + 3 + >>> ds.union(1, 2) + True + >>> ds.union(2, 3) + True + >>> ds.find(1) + 1 + >>> ds.find(2) + 1 + >>> ds.find(3) + 1 + >>> ds.union(1, 3) + False + """ + + __slots__ = ("_parents", "_classes") + + def __init__(self, data: Iterable[K] | None = None): + self._parents = {} + self._classes = {} + if data is not None: + for id in data: + self.add(id) + + def __contains__(self, id) -> bool: + """Check if the given id is in the disjoint set. + + Parameters + ---------- + id : + The id to check. + + Returns + ------- + ined: + True if the id is in the disjoint set, False otherwise. + """ + return id in self._parents + + def __getitem__(self, id) -> set[K]: + """Get the set of ids that are in the same class as the given id. + + Parameters + ---------- + id : + The id to get the class for. + + Returns + ------- + class: + The set of ids that are in the same class as the given id, including + the given id. + """ + id = self._parents[id] + return self._classes[id] + + def __iter__(self) -> Iterator[K]: + """Iterate over the ids in the disjoint set.""" + return iter(self._parents) + + def __len__(self) -> int: + """Get the number of ids in the disjoint set.""" + return len(self._parents) + + def __eq__(self, other: Self[K]) -> bool: + """Check if the disjoint set is equal to another disjoint set. + + Parameters + ---------- + other : + The other disjoint set to compare to. + + Returns + ------- + equal: + True if the disjoint sets are equal, False otherwise. + """ + if not isinstance(other, DisjointSet): + return NotImplemented + return self._parents == other._parents + + def add(self, id: K) -> K: + """Add a new id to the disjoint set. + + If the id is not in the disjoint set, it will be added to the disjoint set + along with a new class containing only the given id. + + Parameters + ---------- + id : + The id to add to the disjoint set. + + Returns + ------- + id: + The id that was added to the disjoint set. + """ + if id in self._parents: + return self._parents[id] + self._parents[id] = id + self._classes[id] = {id} + return id + + def find(self, id: K) -> K: + """Find the root of the class that the given id is in. + + Also called as the canonicalized id or the representative id. + + Parameters + ---------- + id : + The id to find the canonicalized id for. + + Returns + ------- + id: + The canonicalized id for the given id. + """ + return self._parents[id] + + def union(self, id1, id2) -> bool: + """Merge the classes that the given ids are in. + + If the ids are already in the same class, this will return False. Otherwise + it will merge the classes and return True. + + Parameters + ---------- + id1 : + The first id to merge the classes for. + id2 : + The second id to merge the classes for. + + Returns + ------- + merged: + True if the classes were merged, False otherwise. + """ + # Find the root of each class + id1 = self._parents[id1] + id2 = self._parents[id2] + if id1 == id2: + return False + + # Merge the smaller eclass into the larger one, aka. union-find by size + class1 = self._classes[id1] + class2 = self._classes[id2] + if len(class1) >= len(class2): + id1, id2 = id2, id1 + class1, class2 = class2, class1 + + # Update the parent pointers, this is called path compression but done + # during the union operation to keep the find operation minimal + for id in class1: + self._parents[id] = id2 + + # Do the actual merging and clear the other eclass + class2 |= class1 + class1.clear() + + return True + + def connected(self, id1, id2): + """Check if the given ids are in the same class. + + True if both ids have the same canonicalized id, False otherwise. + + Parameters + ---------- + id1 : + The first id to check. + id2 : + The second id to check. + + Returns + ------- + connected: + True if the ids are connected, False otherwise. + """ + return self._parents[id1] == self._parents[id2] + + def verify(self): + """Verify that the disjoint set is not corrupted. + + Check that each id's canonicalized id's class. In general corruption + should not happen if the public API is used, but this is a sanity check + to make sure that the internal data structures are not corrupted. + + Returns + ------- + verified: + True if the disjoint set is not corrupted, False otherwise. + """ + for id in self._parents: + if id not in self._classes[self._parents[id]]: + raise RuntimeError( + f"DisjointSet is corrupted: {id} is not in its class" + ) + + public(frozendict=FrozenDict, dotdict=DotDict) diff --git a/ibis/common/egraph.py b/ibis/common/egraph.py new file mode 100644 index 000000000000..53259d26a8df --- /dev/null +++ b/ibis/common/egraph.py @@ -0,0 +1,582 @@ +from __future__ import annotations + +import collections +import itertools +import math +from typing import Any + +from ibis.common.collections import DisjointSet +from ibis.common.graph import Node +from ibis.util import promote_list + + +class Slotted: + """A lightweight alternative to `ibis.common.grounds.Concrete`. + + This class is used to create immutable dataclasses with slots and a precomputed + hash value for quicker dictionary lookups. + """ + + __slots__ = ('__precomputed_hash__',) + + def __init__(self, *args): + for name, value in itertools.zip_longest(self.__slots__, args): + object.__setattr__(self, name, value) + object.__setattr__(self, "__precomputed_hash__", hash(args)) + + def __eq__(self, other): + if self is other: + return True + if type(self) is not type(other): + return NotImplemented + for name in self.__slots__: + if getattr(self, name) != getattr(other, name): + return False + return True + + def __hash__(self): + return self.__precomputed_hash__ + + def __setattr__(self, name, value): + raise AttributeError("Can't set attributes on immutable ENode instance") + + +class Variable(Slotted): + """A named capture in a pattern. + + Parameters + ---------- + name : str + The name of the variable. + """ + + __slots__ = ("name",) + + def __repr__(self): + return f"${self.name}" + + def substitute(self, egraph, enode, subst): + """Substitute the variable with the corresponding value in the substitution. + + Parameters + ---------- + egraph : EGraph + The egraph instance. + enode : ENode + The matched enode. + subst : dict + The substitution dictionary. + + Returns + ------- + value : Any + The substituted value. + """ + return subst[self.name] + + +# Pattern corresponsds to a selection which is flattened to a join of selections +class Pattern(Slotted): + """A non-ground term, tree of enodes possibly containing variables. + + This class is used to represent a pattern in a query. The pattern is almost + identical to an ENode, except that it can contain variables. + + Parameters + ---------- + head : type + The head or python type of the ENode to match against. + args : tuple + The arguments of the pattern. The arguments can be enodes, patterns, + variables or leaf values. + name : str, optional + The name of the pattern which is used to refer to it in a rewrite rule. + """ + + __slots__ = ("head", "args", "name") + + # TODO(kszucs): consider to raise if the pattern matches none + def __init__(self, head, args, name=None, conditions=None): + super().__init__(head, tuple(args), name) + + def matches_none(self): + """Evaluate whether the pattern is guaranteed to match nothing. + + This can be evaluated before the matching loop starts, so eventually can + be eliminated from the flattened query. + """ + return len(self.head.__argnames__) != len(self.args) + + def matches_all(self): + """Evaluate whether the pattern is guaranteed to match everything. + + This can be evaluated before the matching loop starts, so eventually can + be eliminated from the flattened query. + """ + return not self.matches_none() and all( + isinstance(arg, Variable) for arg in self.args + ) + + def __repr__(self): + argstring = ", ".join(map(repr, self.args)) + return f"P{self.head.__name__}({argstring})" + + def __rshift__(self, rhs): + """Syntax sugar to create a rewrite rule.""" + return Rewrite(self, rhs) + + def __rmatmul__(self, name): + """Syntax sugar to create a named pattern.""" + return self.__class__(self.head, self.args, name) + + def to_enode(self): + """Convert the pattern to an ENode. + + None of the arguments can be a pattern or a variable. + + Returns + ------- + enode : ENode + The pattern converted to an ENode. + """ + # TODO(kszucs): ensure that self is a ground term + return ENode(self.head, self.args) + + def flatten(self, var=None, counter=None): + """Recursively flatten the pattern to a join of selections. + + `Pattern(Add, (Pattern(Mul, ($x, 1)), $y))` is turned into a join of + selections by introducing auxilary variables where each selection gets + executed as a dictionary lookup. + + In SQL terms this is equivalent to the following query: + SELECT m.0 AS $x, a.1 AS $y FROM Add a JOIN Mul m ON a.0 = m.id WHERE m.1 = 1 + + Parameters + ---------- + var : Variable + The variable to assign to the flattened pattern. + counter : Iterator[int] + The counter to generate unique variable names for auxilary variables + connecting the selections. + + Yields + ------ + (var, pattern) : tuple[Variable, Pattern] + The variable and the flattened pattern where the flattened pattern + cannot contain any patterns just variables. + """ + # TODO(kszucs): convert a pattern to a query object instead by flattening it + counter = counter or itertools.count() + + if var is None: + if self.name is None: + var = Variable(next(counter)) + else: + var = Variable(self.name) + + args = [] + for arg in self.args: + if isinstance(arg, Pattern): + if arg.name is None: + aux = Variable(next(counter)) + else: + aux = Variable(arg.name) + yield from arg.flatten(aux, counter) + args.append(aux) + else: + args.append(arg) + + yield (var, Pattern(self.head, args)) + + def substitute(self, egraph, enode, subst): + """Substitute the variables in the pattern with the corresponding values. + + Parameters + ---------- + egraph : EGraph + The egraph instance. + enode : ENode + The matched enode. + subst : dict + The substitution dictionary. + + Returns + ------- + enode : ENode + The substituted pattern which is a ground term aka. an ENode. + """ + args = [] + for arg in self.args: + if isinstance(arg, (Variable, Pattern)): + arg = arg.substitute(egraph, enode, subst) + args.append(arg) + return ENode(self.head, tuple(args)) + + +class DynamicApplier(Slotted): + """A dynamic applier which calls a function to compute the result.""" + + __slots__ = ("func",) + + def substitute(self, egraph, enode, subst): + kwargs = {k: v for k, v in subst.items() if isinstance(k, str)} + result = self.func(egraph, enode, **kwargs) + return result.to_enode() if isinstance(result, Pattern) else result + + +class Rewrite(Slotted): + """A rewrite rule which matches a pattern and applies a pattern or a function.""" + + __slots__ = ("matcher", "applier") + + def __init__(self, matcher, applier): + if callable(applier): + applier = DynamicApplier(applier) + elif not isinstance(applier, (Pattern, Variable)): + raise TypeError( + "applier must be a Pattern or a Variable returning an ENode" + ) + super().__init__(matcher, applier) + + def __repr__(self): + return f"{self.lhs} >> {self.rhs}" + + +class ENode(Slotted, Node): + """A ground term which is a node in the EGraph, called ENode. + + Parameters + ---------- + head : type + The type of the Node the ENode represents. + args : tuple + The arguments of the ENode which are either ENodes or leaf values. + """ + + __slots__ = ("head", "args") + + def __init__(self, head, args): + super().__init__(head, tuple(args)) + + @property + def __argnames__(self): + """Implementation for the `ibis.common.graph.Node` protocol.""" + return self.head.__argnames__ + + @property + def __args__(self): + """Implementation for the `ibis.common.graph.Node` protocol.""" + return self.args + + def __repr__(self): + argstring = ", ".join(map(repr, self.args)) + return f"E{self.head.__name__}({argstring})" + + def __lt__(self, other): + return False + + @classmethod + def from_node(cls, node: Any): + """Convert an `ibis.common.graph.Node` to an `ENode`.""" + + def mapper(node, _, **kwargs): + return cls(node.__class__, kwargs.values()) + + return node.map(mapper)[node] + + def to_node(self): + """Convert the ENode back to an `ibis.common.graph.Node`.""" + + def mapper(node, _, **kwargs): + return node.head(**kwargs) + + return self.map(mapper)[self] + + +# TODO: move every E* into the Egraph so its API only uses Nodes +# TODO: track whether the egraph is saturated or not +# TODO: support parent classes in etables (Join <= InnerJoin) + + +class EGraph: + __slots__ = ("_nodes", "_etables", "_eclasses") + + def __init__(self): + # store the nodes before converting them to enodes, so we can spare the initial + # node traversal and omit the creation of enodes + self._nodes = {} + # map enode heads to their eclass ids and their arguments, this is required for + # the relational e-matching (Node => dict[type, tuple[Union[ENode, Any], ...]]) + self._etables = collections.defaultdict(dict) + # map enodes to their eclass, this is the heart of the egraph + self._eclasses = DisjointSet() + + def __repr__(self): + return f"EGraph({self._eclasses})" + + def _as_enode(self, node: Node) -> ENode: + """Convert a node to an enode.""" + # order is important here since ENode is a subclass of Node + if isinstance(node, ENode): + return node + elif isinstance(node, Node): + return self._nodes.get(node) or ENode.from_node(node) + else: + raise TypeError(node) + + def add(self, node: Node) -> ENode: + """Add a node to the egraph. + + The node is converted to an enode and added to the egraph. If the enode is + already present in the egraph, then the canonical enode is returned. + + Parameters + ---------- + node : + The node to add to the egraph. + + Returns + ------- + enode : + The canonical enode. + """ + enode = self._as_enode(node) + if enode in self._eclasses: + return self._eclasses.find(enode) + + args = [] + for arg in enode.args: + if isinstance(arg, ENode): + args.append(self.add(arg)) + else: + args.append(arg) + + enode = ENode(enode.head, args) + self._eclasses.add(enode) + self._etables[enode.head][enode] = tuple(args) + + return enode + + def union(self, node1: Node, node2: Node) -> ENode: + """Union two nodes in the egraph. + + The nodes are converted to enodes which must be present in the egraph. + The eclasses of the nodes are merged and the canonical enode is returned. + + Parameters + ---------- + node1 : + The first node to union. + node2 : + The second node to union. + + Returns + ------- + enode : + The canonical enode. + """ + enode1 = self._as_enode(node1) + enode2 = self._as_enode(node2) + return self._eclasses.union(enode1, enode2) + + def _match_args(self, args, patargs): + """Match the arguments of an enode against a pattern's arguments. + + An enode matches a pattern if each of the arguments are: + - both leaf values and equal + - both enodes and in the same eclass + - an enode and a variable, in which case the variable gets bound to the enode + + Parameters + ---------- + args : tuple + The arguments of the enode. Since an enode is a ground term, the arguments + are either enodes or leaf values. + patargs : tuple + The arguments of the pattern. Since a pattern is a flat term (flattened + using auxilliary variables), the arguments are either variables or leaf + values. + + Returns + ------- + dict[str, Any] : + The mapping of variable names to enodes or leaf values. + """ + subst = {} + for arg, patarg in zip(args, patargs): + if isinstance(patarg, Variable): + if patarg.name is None: + pass + elif isinstance(arg, ENode): + subst[patarg.name] = self._eclasses.find(arg) + else: + subst[patarg.name] = arg + elif isinstance(arg, ENode): + if self._eclasses.find(arg) != self._eclasses.find(arg): + return None + elif patarg != arg: + return None + return subst + + def match(self, pattern: Pattern) -> dict[ENode, dict[str, Any]]: + """Match a pattern in the egraph. + + The pattern is converted to a conjunctive query (list of flat patterns) and + matched against the relations represented by the egraph. This is called the + relational e-matching. + + Parameters + ---------- + pattern : + The pattern to match in the egraph. + + Returns + ------- + matches : + A dictionary mapping the matched enodes to their substitutions. + """ + # patterns could be reordered to match on the most selective one first + patterns = dict(reversed(list(pattern.flatten()))) + if any(pat.matches_none() for pat in patterns.values()): + return {} + + # extract the first pattern + (auxvar, pattern), *rest = patterns.items() + matches = {} + + # match the first pattern and create the initial substitutions + rel = self._etables[pattern.head] + for enode, args in rel.items(): + if (subst := self._match_args(args, pattern.args)) is not None: + subst[auxvar.name] = enode + matches[enode] = subst + + # match the rest of the patterns and extend the substitutions + for auxvar, pattern in rest: + rel = self._etables[pattern.head] + tmp = {} + for enode, subst in matches.items(): + if args := rel.get(subst[auxvar.name]): + if (newsubst := self._match_args(args, pattern.args)) is not None: + tmp[enode] = {**subst, **newsubst} + matches = tmp + + return matches + + def apply(self, rewrites: list[Rewrite]) -> int: + """Apply the given rewrites to the egraph. + + Iteratively match the patterns and apply the rewrites to the graph. The returned + number of changes is the number of eclasses that were merged. This is the + number of changes made to the egraph. The egraph is saturated if the number of + changes is zero. + + Parameters + ---------- + rewrites : + A list of rewrites to apply. + + Returns + ------- + n_changes + The number of changes made to the egraph. + """ + n_changes = 0 + for rewrite in promote_list(rewrites): + for match, subst in self.match(rewrite.matcher).items(): + enode = rewrite.applier.substitute(self, match, subst) + enode = self.add(enode) + n_changes += self._eclasses.union(match, enode) + return n_changes + + def run(self, rewrites: list[Rewrite], n: int = 10) -> bool: + """Run the match-apply cycles for the given number of iterations. + + Parameters + ---------- + rewrites : + A list of rewrites to apply. + n : + The number of iterations to run. + + Returns + ------- + saturated : + True if the egraph is saturated, False otherwise. + """ + return any(not self.apply(rewrites) for _i in range(n)) + + # TODO(kszucs): investigate whether the costs and best enodes could be maintained + # during the union operations after each match-apply cycle + def extract(self, node: Node) -> Node: + """Extract a node from the egraph. + + The node is converted to an enode which recursively gets converted to an + enode having the lowest cost according to equivalence classes. Currently + the cost function is hardcoded as the depth of the enode. + + Parameters + ---------- + node : + The node to extract from the egraph. + + Returns + ------- + node : + The extracted node. + """ + enode = self._as_enode(node) + enode = self._eclasses.find(enode) + costs = {en: (math.inf, None) for en in self._eclasses.keys()} + + def enode_cost(enode): + cost = 1 + for arg in enode.args: + if isinstance(arg, ENode): + cost += costs[arg][0] + else: + cost += 1 + return cost + + changed = True + while changed: + changed = False + for en, enodes in self._eclasses.items(): + new_cost = min((enode_cost(en), en) for en in enodes) + if costs[en][0] != new_cost[0]: + changed = True + costs[en] = new_cost + + def extract(en): + if not isinstance(en, ENode): + return en + best = costs[en][1] + args = tuple(extract(a) for a in best.args) + return best.head(*args) + + return extract(enode) + + def equivalent(self, node1: Node, node2: Node) -> bool: + """Check if two nodes are equivalent. + + The nodes are converted to enodes and checked for equivalence: they are + equivalent if they are in the same equivalence class. + + Parameters + ---------- + node1 : + The first node. + node2 : + The second node. + + Returns + ------- + equivalent : + True if the nodes are equivalent, False otherwise. + """ + enode1 = self._as_enode(node1) + enode2 = self._as_enode(node2) + enode1 = self._eclasses.find(enode1) + enode2 = self._eclasses.find(enode2) + return enode1 == enode2 diff --git a/ibis/common/graph.py b/ibis/common/graph.py index ae1e2f69ead8..c41495c7ff69 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -25,6 +25,9 @@ def __argnames__(self) -> Sequence: def __children__(self, filter=None): return tuple(_flatten_collections(self.__args__, filter or Node)) + def __rich_repr__(self): + return zip(self.__argnames__, self.__args__) + def map(self, fn, filter=None): results = {} for node in Graph.from_bfs(self, filter=filter).toposort(): diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py index 3851c95eaa90..fe0fedc81d4b 100644 --- a/ibis/common/grounds.py +++ b/ibis/common/grounds.py @@ -27,10 +27,7 @@ def __call__(cls, *args, **kwargs) -> Base: class Base(metaclass=BaseMeta): __slots__ = ('__weakref__',) - - @classmethod - def __create__(cls, *args, **kwargs) -> Base: - return type.__call__(cls, *args, **kwargs) + __create__ = classmethod(type.__call__) class AnnotableMeta(BaseMeta): diff --git a/ibis/common/tests/test_collections.py b/ibis/common/tests/test_collections.py index d7b0b389b3f4..c24129481f49 100644 --- a/ibis/common/tests/test_collections.py +++ b/ibis/common/tests/test_collections.py @@ -2,7 +2,7 @@ import pytest -from ibis.common.collections import DotDict, FrozenDict, MapSet +from ibis.common.collections import DisjointSet, DotDict, FrozenDict, MapSet from ibis.tests.util import assert_pickle_roundtrip @@ -221,3 +221,73 @@ def test_frozendict(): assert hash(d) assert_pickle_roundtrip(d) + + +def test_disjoint_set(): + ds = DisjointSet() + ds.add(1) + ds.add(2) + ds.add(3) + ds.add(4) + + ds1 = DisjointSet([1, 2, 3, 4]) + assert ds == ds1 + assert ds[1] == {1} + assert ds[2] == {2} + assert ds[3] == {3} + assert ds[4] == {4} + + assert ds.union(1, 2) is True + assert ds[1] == {1, 2} + assert ds[2] == {1, 2} + assert ds.union(2, 3) is True + assert ds[1] == {1, 2, 3} + assert ds[2] == {1, 2, 3} + assert ds[3] == {1, 2, 3} + assert ds.union(1, 3) is False + assert ds[4] == {4} + assert ds != ds1 + assert 1 in ds + assert 2 in ds + assert 5 not in ds + + assert ds.find(1) == 1 + assert ds.find(2) == 1 + assert ds.find(3) == 1 + assert ds.find(4) == 4 + + assert ds.connected(1, 2) is True + assert ds.connected(1, 3) is True + assert ds.connected(1, 4) is False + + # test mapping api get + assert ds.get(1) == {1, 2, 3} + assert ds.get(4) == {4} + assert ds.get(5) is None + assert ds.get(5, 5) == 5 + assert ds.get(5, default=5) == 5 + + # test mapping api keys + assert set(ds.keys()) == {1, 2, 3, 4} + assert set(ds) == {1, 2, 3, 4} + + # test mapping api values + assert tuple(ds.values()) == ({1, 2, 3}, {1, 2, 3}, {1, 2, 3}, {4}) + + # test mapping api items + assert tuple(ds.items()) == ( + (1, {1, 2, 3}), + (2, {1, 2, 3}), + (3, {1, 2, 3}), + (4, {4}), + ) + + # check that the disjoint set doesn't get corrupted by adding an existing element + ds.verify() + ds.add(1) + ds.verify() + + with pytest.raises(RuntimeError, match="DisjointSet is corrupted"): + ds._parents[1] = 1 + ds._classes[1] = {1} + ds.verify() diff --git a/ibis/common/tests/test_egraph.py b/ibis/common/tests/test_egraph.py new file mode 100644 index 000000000000..22315b2821b2 --- /dev/null +++ b/ibis/common/tests/test_egraph.py @@ -0,0 +1,454 @@ +import itertools +from typing import Any, Tuple + +import pytest + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.common.collections import DisjointSet +from ibis.common.egraph import EGraph, ENode, Pattern, Rewrite, Variable +from ibis.common.graph import Graph, Node +from ibis.common.grounds import Concrete +from ibis.util import promote_tuple + + +class PatternNamespace: + def __init__(self, module): + self.module = module + + def __getattr__(self, name): + klass = getattr(self.module, name) + + def pattern(*args): + return Pattern(klass, args) + + return pattern + + +p = PatternNamespace(ops) + +one = ibis.literal(1) +two = one * 2 +two_ = one + one +two__ = ibis.literal(2) +three = one + two +six = three * two_ +seven = six + 1 +seven_ = seven * 1 +eleven = seven_ + 4 + +a, b, c = Variable('a'), Variable('b'), Variable('c') +x, y, z = Variable('x'), Variable('y'), Variable('z') + + +class Base(Concrete, Node): + def __class_getitem__(self, args): + args = promote_tuple(args) + return Pattern(self, args) + + +class Lit(Base): + value: Any + + +class Add(Base): + x: Any + y: Any + + +class Mul(Base): + x: Any + y: Any + + +def test_enode(): + node = ENode(1, (2, 3)) + assert node == ENode(1, (2, 3)) + assert node != ENode(1, [2, 4]) + assert node != ENode(1, [2, 3, 4]) + assert node != ENode(1, [2]) + assert hash(node) == hash(ENode(1, (2, 3))) + assert hash(node) != hash(ENode(1, (2, 4))) + + with pytest.raises(AttributeError, match="immutable"): + node.head = 2 + with pytest.raises(AttributeError, match="immutable"): + node.args = (2, 3) + + +def test_enode_roundtrip(): + class MyNode(Concrete, Node): + a: int + b: int + c: str + + # create e-node from node + node = MyNode(a=1, b=2, c="3") + enode = ENode.from_node(node) + assert enode == ENode(MyNode, (1, 2, "3")) + + # reconstruct node from e-node + node_ = enode.to_node() + assert node_ == node + + +def test_enode_roundtrip_with_variadic_arg(): + class MyNode(Concrete, Node): + a: int + b: Tuple[int, ...] + + # create e-node from node + node = MyNode(a=1, b=(2, 3)) + enode = ENode.from_node(node) + assert enode == ENode(MyNode, (1, (2, 3))) + + # reconstruct node from e-node + node_ = enode.to_node() + assert node_ == node + + +def test_enode_roundtrip_with_nested_arg(): + class MyInt(Concrete, Node): + value: int + + class MyNode(Concrete, Node): + a: int + b: Tuple[MyInt, ...] + + # create e-node from node + node = MyNode(a=1, b=(MyInt(value=2), MyInt(value=3))) + enode = ENode.from_node(node) + assert enode == ENode(MyNode, (1, (ENode(MyInt, (2,)), ENode(MyInt, (3,))))) + + # reconstruct node from e-node + node_ = enode.to_node() + assert node_ == node + + +def test_disjoint_set_with_enode(): + class MyNode(Concrete, Node): + pass + + class MyLit(MyNode): + value: int + + class MyAdd(MyNode): + a: MyNode + b: MyNode + + class MyMul(MyNode): + a: MyNode + b: MyNode + + # number postfix highlights the depth of the node + one = MyLit(value=1) + two = MyLit(value=2) + two1 = MyAdd(a=one, b=one) + three1 = MyAdd(a=one, b=two) + six2 = MyMul(a=three1, b=two1) + seven2 = MyAdd(a=six2, b=one) + + # expected enodes postfixed with an underscore + one_ = ENode(MyLit, (1,)) + two_ = ENode(MyLit, (2,)) + three_ = ENode(MyLit, (3,)) + two1_ = ENode(MyAdd, (one_, one_)) + three1_ = ENode(MyAdd, (one_, two_)) + six2_ = ENode(MyMul, (three1_, two1_)) + seven2_ = ENode(MyAdd, (six2_, one_)) + + enode = ENode.from_node(seven2) + assert enode == seven2_ + + assert enode.to_node() == seven2 + + ds = DisjointSet() + for enode in Graph.from_bfs(seven2_): + ds.add(enode) + assert ds.find(enode) == enode + + # merging identical nodes should return False + assert ds.union(three1_, three1_) is False + assert ds.find(three1_) == three1_ + assert ds[three1_] == {three1_} + + # now merge a (1 + 2) and (3) nodes, but first add `three_` to the set + ds.add(three_) + assert ds.union(three1_, three_) is True + assert ds.find(three1_) == three1_ + assert ds.find(three_) == three1_ + assert ds[three_] == {three_, three1_} + + +def test_pattern(): + Pattern._counter = itertools.count() + + p = Pattern(ops.Literal, (1, dt.int8)) + assert p.head == ops.Literal + assert p.args == (1, dt.int8) + assert p.name is None + + p = "name" @ Pattern(ops.Literal, (1, dt.int8)) + assert p.head == ops.Literal + assert p.args == (1, dt.int8) + assert p.name == "name" + + +def test_pattern_flatten(): + # using auto-generated names + one = Pattern(ops.Literal, (1, dt.int8)) + two = Pattern(ops.Literal, (2, dt.int8)) + three = Pattern(ops.Add, (one, two)) + + result = dict(three.flatten()) + expected = { + Variable(0): Pattern(ops.Add, (Variable(1), Variable(2))), + Variable(2): Pattern(ops.Literal, (2, dt.int8)), + Variable(1): Pattern(ops.Literal, (1, dt.int8)), + } + assert result == expected + + # using user-provided names which helps capturing variables + one = "one" @ Pattern(ops.Literal, (1, dt.int8)) + two = "two" @ Pattern(ops.Literal, (2, dt.int8)) + three = "three" @ Pattern(ops.Add, (one, two)) + + result = tuple(three.flatten()) + expected = ( + (Variable("one"), Pattern(ops.Literal, (1, dt.int8))), + (Variable("two"), Pattern(ops.Literal, (2, dt.int8))), + (Variable("three"), Pattern(ops.Add, (Variable("one"), Variable("two")))), + ) + assert result == expected + + +def test_egraph_match_simple(): + eg = EGraph() + eg.add(eleven.op()) + + pat = p.Multiply(a, "lit" @ p.Literal(1, dt.int8)) + res = eg.match(pat) + + enode = ENode.from_node(seven_.op()) + matches = res[enode] + assert matches['a'] == ENode.from_node(seven.op()) + assert matches['lit'] == ENode.from_node(one.op()) + + +def test_egraph_match_wrong_argnum(): + two = one + one + four = two + two + + eg = EGraph() + eg.add(four.op()) + + # here we have an extra `2` among the literal's arguments + pat = p.Add(a, p.Add(p.Literal(1, dt.int8, 2), b)) + res = eg.match(pat) + + assert res == {} + + pat = p.Add(a, p.Add(p.Literal(1, dt.int8), b)) + res = eg.match(pat) + + expected = { + ENode.from_node(four.op()): { + 0: ENode.from_node(four.op()), + 1: ENode.from_node(two.op()), + 2: ENode.from_node(one.op()), + 'a': ENode.from_node(two.op()), + 'b': ENode.from_node(one.op()), + } + } + assert res == expected + + +def test_egraph_match_nested(): + node = eleven.op() + enode = ENode.from_node(node) + + eg = EGraph() + eg.add(enode) + + result = eg.match(p.Multiply(a, p.Literal(1, b))) + matched = ENode.from_node(seven_.op()) + + expected = { + matched: { + 0: matched, + 1: ENode.from_node(one.op()), + 'a': ENode.from_node(seven.op()), + 'b': dt.int8, + } + } + assert result == expected + + +def test_egraph_apply_nested(): + node = eleven.op() + enode = ENode.from_node(node) + + eg = EGraph() + eg.add(enode) + + r3 = p.Multiply(a, p.Literal(1, dt.int8)) >> a + eg.apply(r3) + + result = eg.extract(seven_.op()) + expected = seven.op() + assert result == expected + + +def test_egraph_extract_simple(): + eg = EGraph() + eg.add(eleven.op()) + + res = eg.extract(one.op()) + assert res == one.op() + + +def test_egraph_extract_minimum_cost(): + eg = EGraph() + eg.add(two.op()) # 1 * 2 + eg.add(two_.op()) # 1 + 1 + eg.add(two__.op()) # 2 + assert eg.extract(two.op()) == two.op() + + eg.union(two.op(), two_.op()) + assert eg.extract(two.op()) in {two.op(), two_.op()} + + eg.union(two.op(), two__.op()) + assert eg.extract(two.op()) == two__.op() + + eg.union(two.op(), two__.op()) + assert eg.extract(two.op()) == two__.op() + + +def test_egraph_rewrite_to_variable(): + eg = EGraph() + eg.add(eleven.op()) + + # rule with a variable on the right-hand side + rule = Rewrite(p.Multiply(a, "lit" @ p.Literal(1, dt.int8)), a) + eg.apply(rule) + assert eg.equivalent(seven_.op(), seven.op()) + + +def test_egraph_rewrite_to_constant_raises(): + node = (one * 0).op() + + eg = EGraph() + eg.add(node) + + # rule with a constant on the right-hand side + with pytest.raises(TypeError): + Rewrite(p.Multiply(a, "lit" @ p.Literal(0, dt.int8)), 0) + + +def test_egraph_rewrite_to_pattern(): + eg = EGraph() + eg.add(three.op()) + + # rule with a pattern on the right-hand side + rule = Rewrite(p.Multiply(a, "lit" @ p.Literal(2, dt.int8)), p.Add(a, a)) + eg.apply(rule) + assert eg.equivalent(two.op(), two_.op()) + + +def test_egraph_rewrite_dynamic(): + def applier(egraph, match, a, mul, times): + return p.Add(a, a).to_enode() + + node = (one * 2).op() + + eg = EGraph() + eg.add(node) + + # rule with a dynamic pattern on the right-hand side + rule = Rewrite( + "mul" @ p.Multiply(a, p.Literal(Variable("times"), dt.int8)), applier + ) + eg.apply(rule) + + assert eg.extract(node) in {two.op(), two_.op()} + + +def test_egraph_rewrite_commutative(): + rules = [ + Mul[a, b] >> Mul[b, a], + Mul[a, Lit[1]] >> a, + ] + node = Mul(Lit(2), Mul(Lit(1), Lit(3))) + expected = {Mul(Lit(2), Lit(3)), Mul(Lit(3), Lit(2))} + + egraph = EGraph() + egraph.add(node) + egraph.run(rules, 200) + best = egraph.extract(node) + + assert best in expected + + +@pytest.mark.parametrize( + ('node', 'expected'), + [(Mul(Lit(0), Lit(42)), Lit(0)), (Add(Lit(0), Mul(Lit(1), Lit(2))), Lit(2))], +) +def test_egraph_rewrite(node, expected): + rules = [ + Add[a, b] >> Add[b, a], + Mul[a, b] >> Mul[b, a], + Add[a, Lit[0]] >> a, + Mul[a, Lit[0]] >> Lit[0], + Mul[a, Lit[1]] >> a, + ] + egraph = EGraph() + egraph.add(node) + egraph.run(rules, 100) + best = egraph.extract(node) + + assert best == expected + + +def is_equal(a, b, rules, iters=7): + egraph = EGraph() + id_a = egraph.add(a) + id_b = egraph.add(b) + egraph.run(rules, iters) + return egraph.equivalent(id_a, id_b) + + +def test_math_associate_adds(benchmark): + math_rules = [Add[a, b] >> Add[b, a], Add[a, Add[b, c]] >> Add[Add[a, b], c]] + + expr_a = Add(1, Add(2, Add(3, Add(4, Add(5, Add(6, 7)))))) + expr_b = Add(7, Add(6, Add(5, Add(4, Add(3, Add(2, 1)))))) + assert is_equal(expr_a, expr_b, math_rules, iters=500) + + expr_a = Add(6, Add(Add(1, 5), Add(0, Add(4, Add(2, 3))))) + expr_b = Add(6, Add(Add(4, 5), Add(Add(0, 2), Add(3, 1)))) + assert is_equal(expr_a, expr_b, math_rules, iters=500) + + benchmark(is_equal, expr_a, expr_b, math_rules, iters=500) + + +def replace_add(egraph, enode, **kwargs): + node = egraph.extract(enode) + enode = egraph.add(node) + return enode + + +def test_dynamic_rewrite(): + rules = [Rewrite(Add[x, Mul[z, y]], replace_add)] + node = Add(1, Mul(2, 3)) + + egraph = EGraph() + egraph.add(node) + egraph.run(rules, 100) + best = egraph.extract(node) + + assert best == node + + +def test_dynamic_condition(): + pass diff --git a/ibis/common/tests/test_validators.py b/ibis/common/tests/test_validators.py index 8f368546bf72..e203f1cdc272 100644 --- a/ibis/common/tests/test_validators.py +++ b/ibis/common/tests/test_validators.py @@ -35,6 +35,7 @@ mapping_of, min_, pair_of, + ref, sequence_of, str_, tuple_of, @@ -43,6 +44,10 @@ T = TypeVar("T") +def test_ref(): + assert ref("b", this={"a": 1, "b": 2}) == 2 + + @pytest.mark.parametrize( ('validator', 'value', 'expected'), [ diff --git a/ibis/expr/operations/core.py b/ibis/expr/operations/core.py index 0e4f5dc734a4..94e258ef10b8 100644 --- a/ibis/expr/operations/core.py +++ b/ibis/expr/operations/core.py @@ -35,9 +35,6 @@ def to_expr(self): # Avoid custom repr for performance reasons __repr__ = object.__repr__ - def __rich_repr__(self): - return zip(self.__argnames__, self.__args__) - @public class Named(ABC): From 0ed6dea3707c363f6b88f7172ecc2b438fd7597c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 20 Apr 2023 11:04:03 +0200 Subject: [PATCH 2/3] refactor(common): move `ibis.collections.DisjointSet` to `ibis.common.egraph` --- ibis/common/collections.py | 223 +------------------------ ibis/common/egraph.py | 229 +++++++++++++++++++++++++- ibis/common/tests/test_collections.py | 72 +------- ibis/common/tests/test_egraph.py | 73 +++++++- 4 files changed, 300 insertions(+), 297 deletions(-) diff --git a/ibis/common/collections.py b/ibis/common/collections.py index 3a56872afd9b..2793042b011d 100644 --- a/ibis/common/collections.py +++ b/ibis/common/collections.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import MappingProxyType -from typing import Any, Hashable, Iterable, Iterator, Mapping, Set, TypeVar +from typing import Any, Hashable, Mapping, TypeVar from public import public from typing_extensions import Self @@ -208,225 +208,4 @@ def __repr__(self): return f"{self.__class__.__name__}({super().__repr__()})" -class DisjointSet(Mapping[K, Set[K]]): - """Disjoint set data structure. - - Also known as union-find data structure. It is a data structure that keeps - track of a set of elements partitioned into a number of disjoint (non-overlapping) - subsets. It provides near-constant-time operations to add new sets, to merge - existing sets, and to determine whether elements are in the same set. - - Parameters - ---------- - data : - Initial data to add to the disjoint set. - - Examples - -------- - >>> ds = DisjointSet() - >>> ds.add(1) - 1 - >>> ds.add(2) - 2 - >>> ds.add(3) - 3 - >>> ds.union(1, 2) - True - >>> ds.union(2, 3) - True - >>> ds.find(1) - 1 - >>> ds.find(2) - 1 - >>> ds.find(3) - 1 - >>> ds.union(1, 3) - False - """ - - __slots__ = ("_parents", "_classes") - - def __init__(self, data: Iterable[K] | None = None): - self._parents = {} - self._classes = {} - if data is not None: - for id in data: - self.add(id) - - def __contains__(self, id) -> bool: - """Check if the given id is in the disjoint set. - - Parameters - ---------- - id : - The id to check. - - Returns - ------- - ined: - True if the id is in the disjoint set, False otherwise. - """ - return id in self._parents - - def __getitem__(self, id) -> set[K]: - """Get the set of ids that are in the same class as the given id. - - Parameters - ---------- - id : - The id to get the class for. - - Returns - ------- - class: - The set of ids that are in the same class as the given id, including - the given id. - """ - id = self._parents[id] - return self._classes[id] - - def __iter__(self) -> Iterator[K]: - """Iterate over the ids in the disjoint set.""" - return iter(self._parents) - - def __len__(self) -> int: - """Get the number of ids in the disjoint set.""" - return len(self._parents) - - def __eq__(self, other: Self[K]) -> bool: - """Check if the disjoint set is equal to another disjoint set. - - Parameters - ---------- - other : - The other disjoint set to compare to. - - Returns - ------- - equal: - True if the disjoint sets are equal, False otherwise. - """ - if not isinstance(other, DisjointSet): - return NotImplemented - return self._parents == other._parents - - def add(self, id: K) -> K: - """Add a new id to the disjoint set. - - If the id is not in the disjoint set, it will be added to the disjoint set - along with a new class containing only the given id. - - Parameters - ---------- - id : - The id to add to the disjoint set. - - Returns - ------- - id: - The id that was added to the disjoint set. - """ - if id in self._parents: - return self._parents[id] - self._parents[id] = id - self._classes[id] = {id} - return id - - def find(self, id: K) -> K: - """Find the root of the class that the given id is in. - - Also called as the canonicalized id or the representative id. - - Parameters - ---------- - id : - The id to find the canonicalized id for. - - Returns - ------- - id: - The canonicalized id for the given id. - """ - return self._parents[id] - - def union(self, id1, id2) -> bool: - """Merge the classes that the given ids are in. - - If the ids are already in the same class, this will return False. Otherwise - it will merge the classes and return True. - - Parameters - ---------- - id1 : - The first id to merge the classes for. - id2 : - The second id to merge the classes for. - - Returns - ------- - merged: - True if the classes were merged, False otherwise. - """ - # Find the root of each class - id1 = self._parents[id1] - id2 = self._parents[id2] - if id1 == id2: - return False - - # Merge the smaller eclass into the larger one, aka. union-find by size - class1 = self._classes[id1] - class2 = self._classes[id2] - if len(class1) >= len(class2): - id1, id2 = id2, id1 - class1, class2 = class2, class1 - - # Update the parent pointers, this is called path compression but done - # during the union operation to keep the find operation minimal - for id in class1: - self._parents[id] = id2 - - # Do the actual merging and clear the other eclass - class2 |= class1 - class1.clear() - - return True - - def connected(self, id1, id2): - """Check if the given ids are in the same class. - - True if both ids have the same canonicalized id, False otherwise. - - Parameters - ---------- - id1 : - The first id to check. - id2 : - The second id to check. - - Returns - ------- - connected: - True if the ids are connected, False otherwise. - """ - return self._parents[id1] == self._parents[id2] - - def verify(self): - """Verify that the disjoint set is not corrupted. - - Check that each id's canonicalized id's class. In general corruption - should not happen if the public API is used, but this is a sanity check - to make sure that the internal data structures are not corrupted. - - Returns - ------- - verified: - True if the disjoint set is not corrupted, False otherwise. - """ - for id in self._parents: - if id not in self._classes[self._parents[id]]: - raise RuntimeError( - f"DisjointSet is corrupted: {id} is not in its class" - ) - - public(frozendict=FrozenDict, dotdict=DotDict) diff --git a/ibis/common/egraph.py b/ibis/common/egraph.py index 53259d26a8df..8e0e26ade400 100644 --- a/ibis/common/egraph.py +++ b/ibis/common/egraph.py @@ -3,12 +3,237 @@ import collections import itertools import math -from typing import Any +from collections.abc import Iterable, Iterator, Mapping, Set +from typing import Any, Hashable, TypeVar + +from typing_extensions import Self -from ibis.common.collections import DisjointSet from ibis.common.graph import Node from ibis.util import promote_list +K = TypeVar("K", bound=Hashable) + + +class DisjointSet(Mapping[K, Set[K]]): + """Disjoint set data structure. + + Also known as union-find data structure. It is a data structure that keeps + track of a set of elements partitioned into a number of disjoint (non-overlapping) + subsets. It provides near-constant-time operations to add new sets, to merge + existing sets, and to determine whether elements are in the same set. + + Parameters + ---------- + data : + Initial data to add to the disjoint set. + + Examples + -------- + >>> ds = DisjointSet() + >>> ds.add(1) + 1 + >>> ds.add(2) + 2 + >>> ds.add(3) + 3 + >>> ds.union(1, 2) + True + >>> ds.union(2, 3) + True + >>> ds.find(1) + 1 + >>> ds.find(2) + 1 + >>> ds.find(3) + 1 + >>> ds.union(1, 3) + False + """ + + __slots__ = ("_parents", "_classes") + + def __init__(self, data: Iterable[K] | None = None): + self._parents = {} + self._classes = {} + if data is not None: + for id in data: + self.add(id) + + def __contains__(self, id) -> bool: + """Check if the given id is in the disjoint set. + + Parameters + ---------- + id : + The id to check. + + Returns + ------- + ined: + True if the id is in the disjoint set, False otherwise. + """ + return id in self._parents + + def __getitem__(self, id) -> set[K]: + """Get the set of ids that are in the same class as the given id. + + Parameters + ---------- + id : + The id to get the class for. + + Returns + ------- + class: + The set of ids that are in the same class as the given id, including + the given id. + """ + id = self._parents[id] + return self._classes[id] + + def __iter__(self) -> Iterator[K]: + """Iterate over the ids in the disjoint set.""" + return iter(self._parents) + + def __len__(self) -> int: + """Get the number of ids in the disjoint set.""" + return len(self._parents) + + def __eq__(self, other: Self[K]) -> bool: + """Check if the disjoint set is equal to another disjoint set. + + Parameters + ---------- + other : + The other disjoint set to compare to. + + Returns + ------- + equal: + True if the disjoint sets are equal, False otherwise. + """ + if not isinstance(other, DisjointSet): + return NotImplemented + return self._parents == other._parents + + def add(self, id: K) -> K: + """Add a new id to the disjoint set. + + If the id is not in the disjoint set, it will be added to the disjoint set + along with a new class containing only the given id. + + Parameters + ---------- + id : + The id to add to the disjoint set. + + Returns + ------- + id: + The id that was added to the disjoint set. + """ + if id in self._parents: + return self._parents[id] + self._parents[id] = id + self._classes[id] = {id} + return id + + def find(self, id: K) -> K: + """Find the root of the class that the given id is in. + + Also called as the canonicalized id or the representative id. + + Parameters + ---------- + id : + The id to find the canonicalized id for. + + Returns + ------- + id: + The canonicalized id for the given id. + """ + return self._parents[id] + + def union(self, id1, id2) -> bool: + """Merge the classes that the given ids are in. + + If the ids are already in the same class, this will return False. Otherwise + it will merge the classes and return True. + + Parameters + ---------- + id1 : + The first id to merge the classes for. + id2 : + The second id to merge the classes for. + + Returns + ------- + merged: + True if the classes were merged, False otherwise. + """ + # Find the root of each class + id1 = self._parents[id1] + id2 = self._parents[id2] + if id1 == id2: + return False + + # Merge the smaller eclass into the larger one, aka. union-find by size + class1 = self._classes[id1] + class2 = self._classes[id2] + if len(class1) >= len(class2): + id1, id2 = id2, id1 + class1, class2 = class2, class1 + + # Update the parent pointers, this is called path compression but done + # during the union operation to keep the find operation minimal + for id in class1: + self._parents[id] = id2 + + # Do the actual merging and clear the other eclass + class2 |= class1 + class1.clear() + + return True + + def connected(self, id1, id2): + """Check if the given ids are in the same class. + + True if both ids have the same canonicalized id, False otherwise. + + Parameters + ---------- + id1 : + The first id to check. + id2 : + The second id to check. + + Returns + ------- + connected: + True if the ids are connected, False otherwise. + """ + return self._parents[id1] == self._parents[id2] + + def verify(self): + """Verify that the disjoint set is not corrupted. + + Check that each id's canonicalized id's class. In general corruption + should not happen if the public API is used, but this is a sanity check + to make sure that the internal data structures are not corrupted. + + Returns + ------- + verified: + True if the disjoint set is not corrupted, False otherwise. + """ + for id in self._parents: + if id not in self._classes[self._parents[id]]: + raise RuntimeError( + f"DisjointSet is corrupted: {id} is not in its class" + ) + class Slotted: """A lightweight alternative to `ibis.common.grounds.Concrete`. diff --git a/ibis/common/tests/test_collections.py b/ibis/common/tests/test_collections.py index c24129481f49..d7b0b389b3f4 100644 --- a/ibis/common/tests/test_collections.py +++ b/ibis/common/tests/test_collections.py @@ -2,7 +2,7 @@ import pytest -from ibis.common.collections import DisjointSet, DotDict, FrozenDict, MapSet +from ibis.common.collections import DotDict, FrozenDict, MapSet from ibis.tests.util import assert_pickle_roundtrip @@ -221,73 +221,3 @@ def test_frozendict(): assert hash(d) assert_pickle_roundtrip(d) - - -def test_disjoint_set(): - ds = DisjointSet() - ds.add(1) - ds.add(2) - ds.add(3) - ds.add(4) - - ds1 = DisjointSet([1, 2, 3, 4]) - assert ds == ds1 - assert ds[1] == {1} - assert ds[2] == {2} - assert ds[3] == {3} - assert ds[4] == {4} - - assert ds.union(1, 2) is True - assert ds[1] == {1, 2} - assert ds[2] == {1, 2} - assert ds.union(2, 3) is True - assert ds[1] == {1, 2, 3} - assert ds[2] == {1, 2, 3} - assert ds[3] == {1, 2, 3} - assert ds.union(1, 3) is False - assert ds[4] == {4} - assert ds != ds1 - assert 1 in ds - assert 2 in ds - assert 5 not in ds - - assert ds.find(1) == 1 - assert ds.find(2) == 1 - assert ds.find(3) == 1 - assert ds.find(4) == 4 - - assert ds.connected(1, 2) is True - assert ds.connected(1, 3) is True - assert ds.connected(1, 4) is False - - # test mapping api get - assert ds.get(1) == {1, 2, 3} - assert ds.get(4) == {4} - assert ds.get(5) is None - assert ds.get(5, 5) == 5 - assert ds.get(5, default=5) == 5 - - # test mapping api keys - assert set(ds.keys()) == {1, 2, 3, 4} - assert set(ds) == {1, 2, 3, 4} - - # test mapping api values - assert tuple(ds.values()) == ({1, 2, 3}, {1, 2, 3}, {1, 2, 3}, {4}) - - # test mapping api items - assert tuple(ds.items()) == ( - (1, {1, 2, 3}), - (2, {1, 2, 3}), - (3, {1, 2, 3}), - (4, {4}), - ) - - # check that the disjoint set doesn't get corrupted by adding an existing element - ds.verify() - ds.add(1) - ds.verify() - - with pytest.raises(RuntimeError, match="DisjointSet is corrupted"): - ds._parents[1] = 1 - ds._classes[1] = {1} - ds.verify() diff --git a/ibis/common/tests/test_egraph.py b/ibis/common/tests/test_egraph.py index 22315b2821b2..11f93f014662 100644 --- a/ibis/common/tests/test_egraph.py +++ b/ibis/common/tests/test_egraph.py @@ -6,13 +6,82 @@ import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.common.collections import DisjointSet -from ibis.common.egraph import EGraph, ENode, Pattern, Rewrite, Variable +from ibis.common.egraph import DisjointSet, EGraph, ENode, Pattern, Rewrite, Variable from ibis.common.graph import Graph, Node from ibis.common.grounds import Concrete from ibis.util import promote_tuple +def test_disjoint_set(): + ds = DisjointSet() + ds.add(1) + ds.add(2) + ds.add(3) + ds.add(4) + + ds1 = DisjointSet([1, 2, 3, 4]) + assert ds == ds1 + assert ds[1] == {1} + assert ds[2] == {2} + assert ds[3] == {3} + assert ds[4] == {4} + + assert ds.union(1, 2) is True + assert ds[1] == {1, 2} + assert ds[2] == {1, 2} + assert ds.union(2, 3) is True + assert ds[1] == {1, 2, 3} + assert ds[2] == {1, 2, 3} + assert ds[3] == {1, 2, 3} + assert ds.union(1, 3) is False + assert ds[4] == {4} + assert ds != ds1 + assert 1 in ds + assert 2 in ds + assert 5 not in ds + + assert ds.find(1) == 1 + assert ds.find(2) == 1 + assert ds.find(3) == 1 + assert ds.find(4) == 4 + + assert ds.connected(1, 2) is True + assert ds.connected(1, 3) is True + assert ds.connected(1, 4) is False + + # test mapping api get + assert ds.get(1) == {1, 2, 3} + assert ds.get(4) == {4} + assert ds.get(5) is None + assert ds.get(5, 5) == 5 + assert ds.get(5, default=5) == 5 + + # test mapping api keys + assert set(ds.keys()) == {1, 2, 3, 4} + assert set(ds) == {1, 2, 3, 4} + + # test mapping api values + assert tuple(ds.values()) == ({1, 2, 3}, {1, 2, 3}, {1, 2, 3}, {4}) + + # test mapping api items + assert tuple(ds.items()) == ( + (1, {1, 2, 3}), + (2, {1, 2, 3}), + (3, {1, 2, 3}), + (4, {4}), + ) + + # check that the disjoint set doesn't get corrupted by adding an existing element + ds.verify() + ds.add(1) + ds.verify() + + with pytest.raises(RuntimeError, match="DisjointSet is corrupted"): + ds._parents[1] = 1 + ds._classes[1] = {1} + ds.verify() + + class PatternNamespace: def __init__(self, module): self.module = module From 8595f7b7b098651ae15fc56d78c1ec48b62024cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 20 Apr 2023 13:30:39 +0200 Subject: [PATCH 3/3] refactor(common): add sanity checks for creating ENodes and Patterns --- ibis/common/egraph.py | 42 +++++++++++++++----------------- ibis/common/tests/test_egraph.py | 2 +- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/ibis/common/egraph.py b/ibis/common/egraph.py index 8e0e26ade400..8dd6435d55c3 100644 --- a/ibis/common/egraph.py +++ b/ibis/common/egraph.py @@ -3,8 +3,7 @@ import collections import itertools import math -from collections.abc import Iterable, Iterator, Mapping, Set -from typing import Any, Hashable, TypeVar +from typing import Any, Hashable, Iterable, Iterator, Mapping, Set, TypeVar from typing_extensions import Self @@ -277,6 +276,11 @@ class Variable(Slotted): __slots__ = ("name",) + def __init__(self, name: str): + if name is None: + raise ValueError("Variable name cannot be None") + super().__init__(name) + def __repr__(self): return f"${self.name}" @@ -322,6 +326,8 @@ class Pattern(Slotted): # TODO(kszucs): consider to raise if the pattern matches none def __init__(self, head, args, name=None, conditions=None): + # TODO(kszucs): ensure that args are either patterns, variables or leaf values + assert all(not isinstance(arg, (ENode, Node)) for arg in args) super().__init__(head, tuple(args), name) def matches_none(self): @@ -354,19 +360,6 @@ def __rmatmul__(self, name): """Syntax sugar to create a named pattern.""" return self.__class__(self.head, self.args, name) - def to_enode(self): - """Convert the pattern to an ENode. - - None of the arguments can be a pattern or a variable. - - Returns - ------- - enode : ENode - The pattern converted to an ENode. - """ - # TODO(kszucs): ensure that self is a ground term - return ENode(self.head, self.args) - def flatten(self, var=None, counter=None): """Recursively flatten the pattern to a join of selections. @@ -447,7 +440,9 @@ class DynamicApplier(Slotted): def substitute(self, egraph, enode, subst): kwargs = {k: v for k, v in subst.items() if isinstance(k, str)} result = self.func(egraph, enode, **kwargs) - return result.to_enode() if isinstance(result, Pattern) else result + if not isinstance(result, ENode): + raise TypeError(f"applier must return an ENode, got {type(result)}") + return result class Rewrite(Slotted): @@ -482,6 +477,8 @@ class ENode(Slotted, Node): __slots__ = ("head", "args") def __init__(self, head, args): + # TODO(kszucs): ensure that it is a ground term, this check should be removed + assert all(not isinstance(arg, (Pattern, Variable)) for arg in args) super().__init__(head, tuple(args)) @property @@ -631,15 +628,16 @@ def _match_args(self, args, patargs): subst = {} for arg, patarg in zip(args, patargs): if isinstance(patarg, Variable): - if patarg.name is None: - pass - elif isinstance(arg, ENode): + if isinstance(arg, ENode): subst[patarg.name] = self._eclasses.find(arg) else: subst[patarg.name] = arg - elif isinstance(arg, ENode): - if self._eclasses.find(arg) != self._eclasses.find(arg): - return None + # TODO(kszucs): this is not needed since patarg is either a variable or a + # leaf value due to the pattern flattening, though we may choose to + # support this in the future + # elif isinstance(arg, ENode): + # if self._eclasses.find(arg) != self._eclasses.find(arg): + # return None elif patarg != arg: return None return subst diff --git a/ibis/common/tests/test_egraph.py b/ibis/common/tests/test_egraph.py index 11f93f014662..12cf528b6d07 100644 --- a/ibis/common/tests/test_egraph.py +++ b/ibis/common/tests/test_egraph.py @@ -427,7 +427,7 @@ def test_egraph_rewrite_to_pattern(): def test_egraph_rewrite_dynamic(): def applier(egraph, match, a, mul, times): - return p.Add(a, a).to_enode() + return ENode(ops.Add, (a, a)) node = (one * 2).op()