diff --git a/tests/filecheck/transforms/arith-add-immediate-zero.mlir b/tests/filecheck/transforms/arith-add-immediate-zero.mlir new file mode 100644 index 0000000000..603e84c7a5 --- /dev/null +++ b/tests/filecheck/transforms/arith-add-immediate-zero.mlir @@ -0,0 +1,20 @@ +// RUN: xdsl-opt %s -p canonicalize | filecheck %s + +func.func @hello(%n : i32) -> i32 { + %two = arith.constant 0 : i32 + %three = arith.constant 0 : i32 + %res = arith.addi %two, %n : i32 + %res2 = arith.addi %three, %res : i32 + func.return %res : i32 +} + + +//CHECK: builtin.module { +// CHECK-NEXT: func.func @hello(%n : i32) -> i32 { +// CHECK-NEXT: %two = arith.constant 0 : i32 +// CHECK-NEXT: %three = arith.constant 0 : i32 +// CHECK-NEXT: func.return %n : i32 +// CHECK-NEXT: } +// CHECK-NEXT: } + + diff --git a/tests/filecheck/transforms/individual_rewrite.mlir b/tests/filecheck/transforms/individual_rewrite.mlir new file mode 100644 index 0000000000..8f36475eec --- /dev/null +++ b/tests/filecheck/transforms/individual_rewrite.mlir @@ -0,0 +1,16 @@ +// RUN:xdsl-opt %s -p 'apply-individual-rewrite{matched_operation_index=4 operation_name="riscv.add" pattern_name="AddImmediates"}'| filecheck %s + +%a = riscv.li 1 : () -> !riscv.reg<> +%b = riscv.li 2 : () -> !riscv.reg<> +%c = riscv.li 3 : () -> !riscv.reg<> +%d = riscv.add %a, %b : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> +%e = riscv.add %b, %c : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> + + +//CHECK: builtin.module { +// CHECK-NEXT: %a = riscv.li 1 : () -> !riscv.reg<> +// CHECK-NEXT: %b = riscv.li 2 : () -> !riscv.reg<> +// CHECK-NEXT: %c = riscv.li 3 : () -> !riscv.reg<> +// CHECK-NEXT: %d = riscv.li 3 : () -> !riscv.reg<> +// CHECK-NEXT: %e = riscv.add %b, %c : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> +// CHECK-NEXT: } diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index 3ea455c6d8..a28c82d0d4 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -17,6 +17,7 @@ from xdsl.interactive.app import InputApp from xdsl.ir import Block, Region from xdsl.transforms import ( + individual_rewrite, mlir_opt, printf_to_llvm, scf_parallel_loop_tiling, @@ -266,6 +267,7 @@ async def test_buttons(): condensed_list = tuple( ( + individual_rewrite.IndividualRewrite, convert_arith_to_riscv.ConvertArithToRiscvPass, convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass, stencil_global_to_local.DistributeStencilPass, diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 15425ef26a..3407aac5b5 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -34,8 +34,9 @@ result_def, ) from xdsl.parser import Parser +from xdsl.pattern_rewriter import RewritePattern from xdsl.printer import Printer -from xdsl.traits import ConstantLike, Pure +from xdsl.traits import ConstantLike, HasCanonicalisationPatternsTrait, Pure from xdsl.utils.deprecation import deprecated from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa @@ -263,11 +264,19 @@ def print(self, printer: Printer): IntegerBinaryOp = BinaryOperation[IntegerType] +class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalisationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.arith import AddImmediateZero + + return (AddImmediateZero(),) + + @irdl_op_definition class Addi(SignlessIntegerBinaryOp): name = "arith.addi" - traits = frozenset([Pure()]) + traits = frozenset([Pure(), AddiOpHasCanonicalizationPatternsTrait()]) @irdl_op_definition diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index b339a67e0d..85b2dcb121 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -333,6 +333,11 @@ def get_lower_halo_to_mpi(): return stencil_global_to_local.LowerHaloToMPI + def get_individual_rewrite(): + from xdsl.transforms.individual_rewrite import IndividualRewrite + + return IndividualRewrite + def get_lower_affine(): from xdsl.transforms import lower_affine @@ -489,6 +494,7 @@ def get_test_lower_linalg_to_snitch(): "frontend-desymrefy": get_desymrefy, "gpu-map-parallel-loops": get_gpu_map_parallel_loops, "hls-convert-stencil-to-ll-mlir": get_hls_convert_stencil_to_ll_mlir, + "apply-individual-rewrite": get_individual_rewrite, "lower-affine": get_lower_affine, "lower-hls": get_lower_hls, "lower-mpi": get_lower_mpi, diff --git a/xdsl/transforms/canonicalization_patterns/arith.py b/xdsl/transforms/canonicalization_patterns/arith.py new file mode 100644 index 0000000000..13f816b935 --- /dev/null +++ b/xdsl/transforms/canonicalization_patterns/arith.py @@ -0,0 +1,18 @@ +from xdsl.dialects import arith +from xdsl.dialects.builtin import IntegerAttr +from xdsl.pattern_rewriter import ( + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) + + +class AddImmediateZero(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: arith.Addi, rewriter: PatternRewriter) -> None: + if ( + isinstance(op.lhs.owner, arith.Constant) + and isinstance(value := op.lhs.owner.value, IntegerAttr) + and value.value.data == 0 + ): + rewriter.replace_matched_op([], [op.rhs]) diff --git a/xdsl/transforms/individual_rewrite.py b/xdsl/transforms/individual_rewrite.py new file mode 100644 index 0000000000..2bf30dd32b --- /dev/null +++ b/xdsl/transforms/individual_rewrite.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass + +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir import MLContext +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import PatternRewriter, RewritePattern +from xdsl.tools.command_line_tool import get_all_dialects +from xdsl.traits import HasCanonicalisationPatternsTrait + +REWRITE_BY_NAMES: dict[str, dict[str, RewritePattern]] = { + op.name: { + pattern.__class__.__name__: pattern + for pattern in trait.get_canonicalization_patterns() + } + for dialect in get_all_dialects().values() + for op in dialect().operations + if (trait := op.get_trait(HasCanonicalisationPatternsTrait)) is not None +} +""" +Returns a dictionary representing all possible rewrites. Keys are operation names, and +values are dictionaries. In the inner dictionary, the keys are names of patterns +associated with each operation, and the values are the corresponding RewritePattern +instances. +""" + + +@dataclass +class IndividualRewrite(ModulePass): + """ + Module pass representing the application of an individual rewrite pattern to a module. + + Matches the operation at the provided index within the module and applies the rewrite + pattern specified by the operation and pattern names. + """ + + name = "apply-individual-rewrite" + + matched_operation_index: int | None = None + operation_name: str | None = None + pattern_name: str | None = None + + def apply(self, ctx: MLContext, op: ModuleOp) -> None: + assert self.matched_operation_index is not None + assert self.operation_name is not None + assert self.pattern_name is not None + + matched_operation_list = list(op.walk()) + if self.matched_operation_index >= len(matched_operation_list): + raise ValueError("Matched operation index out of range.") + + matched_operation = list(op.walk())[self.matched_operation_index] + rewriter = PatternRewriter(matched_operation) + + rewrite_dictionary = REWRITE_BY_NAMES.get(self.operation_name) + if rewrite_dictionary is None: + raise ValueError( + f"Operation name {self.operation_name} not found in the rewrite dictionary." + ) + + pattern = rewrite_dictionary.get(self.pattern_name) + if pattern is None: + raise ValueError( + f"Pattern name {self.pattern_name} not found for the provided operation name." + ) + + pattern.match_and_rewrite(matched_operation, rewriter) + if not rewriter.has_done_action: + raise ValueError("Invalid rewrite at current location.")