Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewriting: add attr_constr_rewrite_pattern #3439

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
i64,
)
from xdsl.ir import Block, Operation, SSAValue
from xdsl.irdl import BaseAttr
from xdsl.parser import Parser
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand All @@ -26,6 +27,7 @@
PatternRewriteWalker,
RewritePattern,
TypeConversionPattern,
attr_constr_rewrite_pattern,
attr_type_rewrite_pattern,
op_type_rewrite_pattern,
)
Expand Down Expand Up @@ -1713,3 +1715,37 @@ def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter):
walker = PatternRewriteWalker(Rewrite())
with pytest.raises(ValueError, match=re.escape(expected)):
walker.rewrite_module(module)


def test_attr_constr_rewrite_pattern():
prog = """\
"builtin.module"() ({
"func.func"() <{"function_type" = (memref<2x4xui16>) -> (), "sym_name" = "main", "sym_visibility" = "private"}> ({
^bb0(%arg0 : memref<2x4xui16>):
"func.return"() : () -> ()
}) : () -> ()
}) : () -> ()
"""
expected_prog = """\
"builtin.module"() ({
"func.func"() <{"function_type" = (memref<2x4xindex>) -> (), "sym_name" = "main", "sym_visibility" = "private"}> ({
^0(%arg0 : memref<2x4xindex>):
"func.return"() : () -> ()
}) : () -> ()
}) : () -> ()
"""

class IndexConversion(TypeConversionPattern):
@attr_constr_rewrite_pattern(BaseAttr(IntegerType))
def convert_type(self, typ: IntegerType) -> IndexType:
return IndexType()

rewrite_and_compare(
prog,
expected_prog,
PatternRewriteWalker(IndexConversion(recursive=True)),
op_inserted=1,
op_removed=1,
op_replaced=1,
op_modified=1,
)
40 changes: 31 additions & 9 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
Region,
SSAValue,
)
from xdsl.irdl import GenericAttrConstraint, base
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr


@dataclass(eq=False)
Expand Down Expand Up @@ -551,8 +553,34 @@ def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
_ConvertedT = TypeVar("_ConvertedT", bound=Attribute)


def attr_constr_rewrite_pattern(
constr: GenericAttrConstraint[_AttributeT],
) -> Callable[
[Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None]],
Callable[[_TypeConversionPatternT, Attribute], Attribute | None],
]:
"""
This function is intended to be used as a decorator on a TypeConversionPattern
method. It uses the passed constraint to match on a specific attribute type before
calling the decorated function.
"""

def wrapper(
func: Callable[[_TypeConversionPatternT, _AttributeT], _ConvertedT | None],
):
@wraps(func)
def impl(self: _TypeConversionPatternT, typ: Attribute) -> Attribute | None:
if isattr(typ, constr):
return func(self, typ)
return None

return impl

return wrapper


def attr_type_rewrite_pattern(
func: Callable[[_TypeConversionPatternT, _AttributeT], _ConvertedT | None],
func: Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None],
) -> Callable[[_TypeConversionPatternT, Attribute], Attribute | None]:
"""
This function is intended to be used as a decorator on a TypeConversionPattern
Expand All @@ -561,14 +589,8 @@ def attr_type_rewrite_pattern(
"""
params = list(inspect.signature(func).parameters.values())
expected_type: type[_AttributeT] = params[-1].annotation

@wraps(func)
def impl(self: _TypeConversionPatternT, typ: Attribute) -> Attribute | None:
if isa(typ, expected_type):
return func(self, typ)
return None

return impl
constr = base(expected_type)
return attr_constr_rewrite_pattern(constr)(func)


@dataclass(eq=False, repr=False)
Expand Down
Loading