Skip to content

Commit

Permalink
refactor(common): add sanity checks for creating ENodes and Patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Apr 20, 2023
1 parent 0ed6dea commit 8595f7b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
42 changes: 20 additions & 22 deletions ibis/common/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import collections
import itertools
import math
from collections.abc import Iterable, Iterator, Mapping, Set
from typing import Any, Hashable, TypeVar
from typing import Any, Hashable, Iterable, Iterator, Mapping, Set, TypeVar

from typing_extensions import Self

Expand Down Expand Up @@ -277,6 +276,11 @@ class Variable(Slotted):

__slots__ = ("name",)

def __init__(self, name: str):
if name is None:
raise ValueError("Variable name cannot be None")
super().__init__(name)

def __repr__(self):
return f"${self.name}"

Expand Down Expand Up @@ -322,6 +326,8 @@ class Pattern(Slotted):

# TODO(kszucs): consider to raise if the pattern matches none
def __init__(self, head, args, name=None, conditions=None):
# TODO(kszucs): ensure that args are either patterns, variables or leaf values
assert all(not isinstance(arg, (ENode, Node)) for arg in args)
super().__init__(head, tuple(args), name)

def matches_none(self):
Expand Down Expand Up @@ -354,19 +360,6 @@ def __rmatmul__(self, name):
"""Syntax sugar to create a named pattern."""
return self.__class__(self.head, self.args, name)

def to_enode(self):
"""Convert the pattern to an ENode.
None of the arguments can be a pattern or a variable.
Returns
-------
enode : ENode
The pattern converted to an ENode.
"""
# TODO(kszucs): ensure that self is a ground term
return ENode(self.head, self.args)

def flatten(self, var=None, counter=None):
"""Recursively flatten the pattern to a join of selections.
Expand Down Expand Up @@ -447,7 +440,9 @@ class DynamicApplier(Slotted):
def substitute(self, egraph, enode, subst):
kwargs = {k: v for k, v in subst.items() if isinstance(k, str)}
result = self.func(egraph, enode, **kwargs)
return result.to_enode() if isinstance(result, Pattern) else result
if not isinstance(result, ENode):
raise TypeError(f"applier must return an ENode, got {type(result)}")
return result


class Rewrite(Slotted):
Expand Down Expand Up @@ -482,6 +477,8 @@ class ENode(Slotted, Node):
__slots__ = ("head", "args")

def __init__(self, head, args):
# TODO(kszucs): ensure that it is a ground term, this check should be removed
assert all(not isinstance(arg, (Pattern, Variable)) for arg in args)
super().__init__(head, tuple(args))

@property
Expand Down Expand Up @@ -631,15 +628,16 @@ def _match_args(self, args, patargs):
subst = {}
for arg, patarg in zip(args, patargs):
if isinstance(patarg, Variable):
if patarg.name is None:
pass
elif isinstance(arg, ENode):
if isinstance(arg, ENode):
subst[patarg.name] = self._eclasses.find(arg)
else:
subst[patarg.name] = arg
elif isinstance(arg, ENode):
if self._eclasses.find(arg) != self._eclasses.find(arg):
return None
# TODO(kszucs): this is not needed since patarg is either a variable or a
# leaf value due to the pattern flattening, though we may choose to
# support this in the future
# elif isinstance(arg, ENode):
# if self._eclasses.find(arg) != self._eclasses.find(arg):
# return None
elif patarg != arg:
return None
return subst
Expand Down
2 changes: 1 addition & 1 deletion ibis/common/tests/test_egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def test_egraph_rewrite_to_pattern():

def test_egraph_rewrite_dynamic():
def applier(egraph, match, a, mul, times):
return p.Add(a, a).to_enode()
return ENode(ops.Add, (a, a))

node = (one * 2).op()

Expand Down

0 comments on commit 8595f7b

Please sign in to comment.