diff --git a/tests/filecheck/transforms/eqsat-create-eclasses.mlir b/tests/filecheck/transforms/eqsat-create-eclasses.mlir new file mode 100644 index 0000000000..d7a45d9c37 --- /dev/null +++ b/tests/filecheck/transforms/eqsat-create-eclasses.mlir @@ -0,0 +1,13 @@ +// RUN: xdsl-opt -p eqsat-create-eclasses %s | filecheck %s + +func.func @test(%x : index) -> (index) { + %c2 = arith.constant 2 : index + func.return %c2 : index +} + +// CHECK: func.func @test(%x : index) -> index { +// CHECK-NEXT: %x_1 = eqsat.eclass %x : index +// CHECK-NEXT: %c2 = arith.constant 2 : index +// CHECK-NEXT: %c2_1 = eqsat.eclass %c2 : index +// CHECK-NEXT: func.return %c2_1 : index +// CHECK-NEXT: } diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index ac20b02ca8..149c7eb720 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -420,6 +420,11 @@ def get_stencil_bufferize(): return stencil_bufferize.StencilBufferize + def get_eqsat_create_eclasses(): + from xdsl.transforms import eqsat_create_eclasses + + return eqsat_create_eclasses.EqsatCreateEclasses + def get_stencil_shape_minimize(): from xdsl.transforms import stencil_shape_minimize @@ -507,6 +512,7 @@ def get_stencil_shape_minimize(): "stencil-bufferize": get_stencil_bufferize, "stencil-shape-minimize": get_stencil_shape_minimize, "test-lower-linalg-to-snitch": get_test_lower_linalg_to_snitch, + "eqsat-create-eclasses": get_eqsat_create_eclasses, } diff --git a/xdsl/transforms/eqsat_create_eclasses.py b/xdsl/transforms/eqsat_create_eclasses.py new file mode 100644 index 0000000000..f6a05979b1 --- /dev/null +++ b/xdsl/transforms/eqsat_create_eclasses.py @@ -0,0 +1,84 @@ +from xdsl.context import MLContext +from xdsl.dialects import builtin, eqsat, func +from xdsl.ir import Block +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.rewriter import InsertPoint, Rewriter +from xdsl.utils.exceptions import DiagnosticException + + +def insert_eclass_ops(block: Block): + # Insert eqsat.eclass for each operation + for op in block.ops: + results = op.results + + # Skip special ops such as return ops + if isinstance(op, func.Return): + continue + + if len(results) != 1: + raise DiagnosticException("Ops with non-single results not handled") + + eclass_op = eqsat.EClassOp(results[0]) + insertion_point = InsertPoint.after(op) + Rewriter.insert_op(eclass_op, insertion_point) + results[0].replace_by_if( + eclass_op.results[0], lambda u: not isinstance(u.operation, eqsat.EClassOp) + ) + + # Insert eqsat.eclass for each arg + for arg in block.args: + eclass_op = eqsat.EClassOp(arg) + insertion_point = InsertPoint.at_start(block) + Rewriter.insert_op(eclass_op, insertion_point) + arg.replace_by_if( + eclass_op.results[0], lambda u: not isinstance(u.operation, eqsat.EClassOp) + ) + + +class InsertEclassOps(RewritePattern): + """ + Inserts a `eqsat.eclass` after each operation except module op and function op. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): + insert_eclass_ops(op.body.block) + + +class EqsatCreateEclasses(ModulePass): + """ + Create initial eclasses from an MLIR program. + + Input example: + ```mlir + func.func @test(%a : index, %b : index) -> (index) { + %c = arith.addi %a, %b : index + func.return %c : index + } + ``` + Output example: + ```mlir + func.func @test(%a : index, %b : index) -> (index) { + %a_eq = eqsat.eclass %a : index + %b_eq = eqsat.eclass %b : index + %c = arith.addi %a_eq, %b_eq : index + %c_eq = eqsat.eclass %c : index + func.return %c_eq : index + } + ``` + """ + + name = "eqsat-create-eclasses" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + PatternRewriteWalker( + GreedyRewritePatternApplier([InsertEclassOps()]), + apply_recursively=False, + ).rewrite_module(op)