diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index 90dd6ce6a1..a7830d91c2 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -18,6 +18,7 @@ i64, ) from xdsl.ir import Block, Operation, SSAValue +from xdsl.irdl import BaseAttr from xdsl.parser import Parser from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -26,6 +27,7 @@ PatternRewriteWalker, RewritePattern, TypeConversionPattern, + attr_constr_rewrite_pattern, attr_type_rewrite_pattern, op_type_rewrite_pattern, ) @@ -1713,3 +1715,37 @@ def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter): walker = PatternRewriteWalker(Rewrite()) with pytest.raises(ValueError, match=re.escape(expected)): walker.rewrite_module(module) + + +def test_attr_constr_rewrite_pattern(): + prog = """\ +"builtin.module"() ({ + "func.func"() <{"function_type" = (memref<2x4xui16>) -> (), "sym_name" = "main", "sym_visibility" = "private"}> ({ + ^bb0(%arg0 : memref<2x4xui16>): + "func.return"() : () -> () + }) : () -> () +}) : () -> () +""" + expected_prog = """\ +"builtin.module"() ({ + "func.func"() <{"function_type" = (memref<2x4xindex>) -> (), "sym_name" = "main", "sym_visibility" = "private"}> ({ + ^0(%arg0 : memref<2x4xindex>): + "func.return"() : () -> () + }) : () -> () +}) : () -> () +""" + + class IndexConversion(TypeConversionPattern): + @attr_constr_rewrite_pattern(BaseAttr(IntegerType)) + def convert_type(self, typ: IntegerType) -> IndexType: + return IndexType() + + rewrite_and_compare( + prog, + expected_prog, + PatternRewriteWalker(IndexConversion(recursive=True)), + op_inserted=1, + op_removed=1, + op_replaced=1, + op_modified=1, + ) diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index 7a49a071da..866b5e7f8b 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -22,8 +22,10 @@ Region, SSAValue, ) +from xdsl.irdl import GenericAttrConstraint, base from xdsl.rewriter import InsertPoint, Rewriter from xdsl.utils.hints import isa +from xdsl.utils.isattr import isattr @dataclass(eq=False) @@ -551,8 +553,34 @@ def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter): _ConvertedT = TypeVar("_ConvertedT", bound=Attribute) +def attr_constr_rewrite_pattern( + constr: GenericAttrConstraint[_AttributeT], +) -> Callable[ + [Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None]], + Callable[[_TypeConversionPatternT, Attribute], Attribute | None], +]: + """ + This function is intended to be used as a decorator on a TypeConversionPattern + method. It uses the passed constraint to match on a specific attribute type before + calling the decorated function. + """ + + def wrapper( + func: Callable[[_TypeConversionPatternT, _AttributeT], _ConvertedT | None], + ): + @wraps(func) + def impl(self: _TypeConversionPatternT, typ: Attribute) -> Attribute | None: + if isattr(typ, constr): + return func(self, typ) + return None + + return impl + + return wrapper + + def attr_type_rewrite_pattern( - func: Callable[[_TypeConversionPatternT, _AttributeT], _ConvertedT | None], + func: Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None], ) -> Callable[[_TypeConversionPatternT, Attribute], Attribute | None]: """ This function is intended to be used as a decorator on a TypeConversionPattern @@ -561,14 +589,8 @@ def attr_type_rewrite_pattern( """ params = list(inspect.signature(func).parameters.values()) expected_type: type[_AttributeT] = params[-1].annotation - - @wraps(func) - def impl(self: _TypeConversionPatternT, typ: Attribute) -> Attribute | None: - if isa(typ, expected_type): - return func(self, typ) - return None - - return impl + constr = base(expected_type) + return attr_constr_rewrite_pattern(constr)(func) @dataclass(eq=False, repr=False)