Skip to content

Commit

Permalink
transformations: (eqsat) add pass to convert non-eclass functions to …
Browse files Browse the repository at this point in the history
…eclass (xdslproject#3189)

This PR addresses xdslproject#3170:

- [x] Added initial front end pass `eqsat-create-eclasses` for the
minimal example
- [x] Added an initial test case for the pass
  • Loading branch information
jianyicheng authored and EdmundGoodman committed Dec 6, 2024
1 parent 8a9dd05 commit 7fb573d
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/filecheck/transforms/eqsat-create-eclasses.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: xdsl-opt -p eqsat-create-eclasses %s | filecheck %s

func.func @test(%x : index) -> (index) {
%c2 = arith.constant 2 : index
func.return %c2 : index
}

// CHECK: func.func @test(%x : index) -> index {
// CHECK-NEXT: %x_1 = eqsat.eclass %x : index
// CHECK-NEXT: %c2 = arith.constant 2 : index
// CHECK-NEXT: %c2_1 = eqsat.eclass %c2 : index
// CHECK-NEXT: func.return %c2_1 : index
// CHECK-NEXT: }
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,11 @@ def get_stencil_bufferize():

return stencil_bufferize.StencilBufferize

def get_eqsat_create_eclasses():
from xdsl.transforms import eqsat_create_eclasses

return eqsat_create_eclasses.EqsatCreateEclasses

def get_stencil_shape_minimize():
from xdsl.transforms import stencil_shape_minimize

Expand Down Expand Up @@ -507,6 +512,7 @@ def get_stencil_shape_minimize():
"stencil-bufferize": get_stencil_bufferize,
"stencil-shape-minimize": get_stencil_shape_minimize,
"test-lower-linalg-to-snitch": get_test_lower_linalg_to_snitch,
"eqsat-create-eclasses": get_eqsat_create_eclasses,
}


Expand Down
84 changes: 84 additions & 0 deletions xdsl/transforms/eqsat_create_eclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from xdsl.context import MLContext
from xdsl.dialects import builtin, eqsat, func
from xdsl.ir import Block
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.utils.exceptions import DiagnosticException


def insert_eclass_ops(block: Block):
# Insert eqsat.eclass for each operation
for op in block.ops:
results = op.results

# Skip special ops such as return ops
if isinstance(op, func.Return):
continue

if len(results) != 1:
raise DiagnosticException("Ops with non-single results not handled")

eclass_op = eqsat.EClassOp(results[0])
insertion_point = InsertPoint.after(op)
Rewriter.insert_op(eclass_op, insertion_point)
results[0].replace_by_if(
eclass_op.results[0], lambda u: not isinstance(u.operation, eqsat.EClassOp)
)

# Insert eqsat.eclass for each arg
for arg in block.args:
eclass_op = eqsat.EClassOp(arg)
insertion_point = InsertPoint.at_start(block)
Rewriter.insert_op(eclass_op, insertion_point)
arg.replace_by_if(
eclass_op.results[0], lambda u: not isinstance(u.operation, eqsat.EClassOp)
)


class InsertEclassOps(RewritePattern):
"""
Inserts a `eqsat.eclass` after each operation except module op and function op.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
insert_eclass_ops(op.body.block)


class EqsatCreateEclasses(ModulePass):
"""
Create initial eclasses from an MLIR program.
Input example:
```mlir
func.func @test(%a : index, %b : index) -> (index) {
%c = arith.addi %a, %b : index
func.return %c : index
}
```
Output example:
```mlir
func.func @test(%a : index, %b : index) -> (index) {
%a_eq = eqsat.eclass %a : index
%b_eq = eqsat.eclass %b : index
%c = arith.addi %a_eq, %b_eq : index
%c_eq = eqsat.eclass %c : index
func.return %c_eq : index
}
```
"""

name = "eqsat-create-eclasses"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier([InsertEclassOps()]),
apply_recursively=False,
).rewrite_module(op)

0 comments on commit 7fb573d

Please sign in to comment.