Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (arith) adding an arith canonicalisation pattern #2094

Merged
merged 19 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/filecheck/transforms/arith-add-immediate-zero.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: xdsl-opt %s -p canonicalize | filecheck %s

func.func @hello(%n : i32) -> i32 {
%two = arith.constant 0 : i32
%three = arith.constant 0 : i32
%res = arith.addi %two, %n : i32
%res2 = arith.addi %three, %res : i32
func.return %res : i32
}


//CHECK: builtin.module {
// CHECK-NEXT: func.func @hello(%n : i32) -> i32 {
// CHECK-NEXT: %two = arith.constant 0 : i32
// CHECK-NEXT: %three = arith.constant 0 : i32
// CHECK-NEXT: func.return %n : i32
// CHECK-NEXT: }
// CHECK-NEXT: }


16 changes: 16 additions & 0 deletions tests/filecheck/transforms/individual_rewrite.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN:xdsl-opt %s -p 'apply-individual-rewrite{matched_operation_index=4 operation_name="riscv.add" pattern_name="AddImmediates"}'| filecheck %s

%a = riscv.li 1 : () -> !riscv.reg<>
%b = riscv.li 2 : () -> !riscv.reg<>
%c = riscv.li 3 : () -> !riscv.reg<>
%d = riscv.add %a, %b : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
%e = riscv.add %b, %c : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>


//CHECK: builtin.module {
// CHECK-NEXT: %a = riscv.li 1 : () -> !riscv.reg<>
// CHECK-NEXT: %b = riscv.li 2 : () -> !riscv.reg<>
// CHECK-NEXT: %c = riscv.li 3 : () -> !riscv.reg<>
// CHECK-NEXT: %d = riscv.li 3 : () -> !riscv.reg<>
// CHECK-NEXT: %e = riscv.add %b, %c : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: }
2 changes: 2 additions & 0 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from xdsl.interactive.app import InputApp
from xdsl.ir import Block, Region
from xdsl.transforms import (
individual_rewrite,
mlir_opt,
printf_to_llvm,
scf_parallel_loop_tiling,
Expand Down Expand Up @@ -266,6 +267,7 @@ async def test_buttons():

condensed_list = tuple(
(
individual_rewrite.IndividualRewrite,
convert_arith_to_riscv.ConvertArithToRiscvPass,
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
stencil_global_to_local.DistributeStencilPass,
Expand Down
13 changes: 11 additions & 2 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
result_def,
)
from xdsl.parser import Parser
from xdsl.pattern_rewriter import RewritePattern
from xdsl.printer import Printer
from xdsl.traits import ConstantLike, Pure
from xdsl.traits import ConstantLike, HasCanonicalisationPatternsTrait, Pure
from xdsl.utils.deprecation import deprecated
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
Expand Down Expand Up @@ -263,11 +264,19 @@ def print(self, printer: Printer):
IntegerBinaryOp = BinaryOperation[IntegerType]


class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalisationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.arith import AddImmediateZero

return (AddImmediateZero(),)


@irdl_op_definition
class Addi(SignlessIntegerBinaryOp):
name = "arith.addi"

traits = frozenset([Pure()])
traits = frozenset([Pure(), AddiOpHasCanonicalizationPatternsTrait()])


@irdl_op_definition
Expand Down
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def get_lower_halo_to_mpi():

return stencil_global_to_local.LowerHaloToMPI

def get_individual_rewrite():
from xdsl.transforms.individual_rewrite import IndividualRewrite

return IndividualRewrite

def get_lower_affine():
from xdsl.transforms import lower_affine

Expand Down Expand Up @@ -489,6 +494,7 @@ def get_test_lower_linalg_to_snitch():
"frontend-desymrefy": get_desymrefy,
"gpu-map-parallel-loops": get_gpu_map_parallel_loops,
"hls-convert-stencil-to-ll-mlir": get_hls_convert_stencil_to_ll_mlir,
"apply-individual-rewrite": get_individual_rewrite,
"lower-affine": get_lower_affine,
"lower-hls": get_lower_hls,
"lower-mpi": get_lower_mpi,
Expand Down
18 changes: 18 additions & 0 deletions xdsl/transforms/canonicalization_patterns/arith.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from xdsl.dialects import arith
from xdsl.dialects.builtin import IntegerAttr
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)


class AddImmediateZero(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.Addi, rewriter: PatternRewriter) -> None:
if (
isinstance(op.lhs.owner, arith.Constant)
and isinstance(value := op.lhs.owner.value, IntegerAttr)
and value.value.data == 0
):
rewriter.replace_matched_op([], [op.rhs])
68 changes: 68 additions & 0 deletions xdsl/transforms/individual_rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import dataclass

from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import PatternRewriter, RewritePattern
from xdsl.tools.command_line_tool import get_all_dialects
from xdsl.traits import HasCanonicalisationPatternsTrait

REWRITE_BY_NAMES: dict[str, dict[str, RewritePattern]] = {
op.name: {
pattern.__class__.__name__: pattern
for pattern in trait.get_canonicalization_patterns()
}
for dialect in get_all_dialects().values()
for op in dialect().operations
if (trait := op.get_trait(HasCanonicalisationPatternsTrait)) is not None
}
"""
Returns a dictionary representing all possible rewrites. Keys are operation names, and
values are dictionaries. In the inner dictionary, the keys are names of patterns
associated with each operation, and the values are the corresponding RewritePattern
instances.
"""


@dataclass
class IndividualRewrite(ModulePass):
"""
Module pass representing the application of an individual rewrite pattern to a module.

Matches the operation at the provided index within the module and applies the rewrite
pattern specified by the operation and pattern names.
"""

name = "apply-individual-rewrite"

matched_operation_index: int | None = None
operation_name: str | None = None
pattern_name: str | None = None

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
assert self.matched_operation_index is not None
assert self.operation_name is not None
assert self.pattern_name is not None

matched_operation_list = list(op.walk())
if self.matched_operation_index >= len(matched_operation_list):
raise ValueError("Matched operation index out of range.")

matched_operation = list(op.walk())[self.matched_operation_index]
rewriter = PatternRewriter(matched_operation)

rewrite_dictionary = REWRITE_BY_NAMES.get(self.operation_name)
if rewrite_dictionary is None:
raise ValueError(
f"Operation name {self.operation_name} not found in the rewrite dictionary."
)

pattern = rewrite_dictionary.get(self.pattern_name)
if pattern is None:
raise ValueError(
f"Pattern name {self.pattern_name} not found for the provided operation name."
)

pattern.match_and_rewrite(matched_operation, rewriter)
if not rewriter.has_done_action:
raise ValueError("Invalid rewrite at current location.")
Comment on lines +66 to +68
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, this feels like a case that should have a warning instead of an exception. It feels like this might make it hard to interact with this pass through scripts and stuff.
Though, I guess people consider exceptional control flow for basic stuff acceptable in Python, so this is just a side comment.

Loading