Skip to content

Commit

Permalink
transforms: Allow to pass a pattern rewriter in CSE
Browse files Browse the repository at this point in the history
Without passing the pattern rewriter, CSE couldn't be called inside
a pattern rewriter walker, as it would not notify the operations that
were deleted or replaced.
  • Loading branch information
math-fehr committed Nov 29, 2024
1 parent f9c786f commit 902ccce
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion xdsl/transforms/canonicalization_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N
continue
a.replace_by(bbargs[rbargs[i]])

cse(op.region.block)
cse(op.region.block, rewriter)


class ApplyUnusedOperands(RewritePattern):
Expand Down
18 changes: 9 additions & 9 deletions xdsl/transforms/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xdsl.dialects.builtin import ModuleOp, UnregisteredOp
from xdsl.ir import Block, Operation, Region, Use
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import PatternRewriter
from xdsl.rewriter import Rewriter
from xdsl.traits import (
IsolatedFromAbove,
Expand Down Expand Up @@ -115,19 +116,15 @@ def has_other_side_effecting_op_in_between(
return False


@dataclass
class CSEDriver:
"""
Boilerplate class to handle and carry the state for CSE.
"""

_rewriter: Rewriter
_rewriter: Rewriter | PatternRewriter = field(default_factory=Rewriter)
_to_erase: set[Operation] = field(default_factory=set)
_known_ops: KnownOps = KnownOps()

def __init__(self):
self._rewriter = Rewriter()
self._to_erase = set()
self._known_ops = KnownOps()
_known_ops: KnownOps = field(default_factory=KnownOps)

def _mark_erasure(self, op: Operation):
self._to_erase.add(op)
Expand Down Expand Up @@ -250,8 +247,11 @@ def simplify(self, thing: Operation | Block | Region):
self._commit_erasures()


def cse(thing: Operation | Block | Region):
CSEDriver().simplify(thing)
def cse(thing: Operation | Block | Region, rewriter: PatternRewriter | None = None):
if rewriter is not None:
CSEDriver(_rewriter=rewriter).simplify(thing)
else:
CSEDriver().simplify(thing)


class CommonSubexpressionElimination(ModulePass):
Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/control_flow_hoist.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def match_and_rewrite(self, op: affine.IfOp, rewriter: PatternRewriter):
return
block = op.parent
if block:
cse(block)
cse(block, rewriter)


class SCFIfHoistPattern(RewritePattern):
Expand All @@ -84,7 +84,7 @@ def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter):
block = op.parent
if block:
# If we hoisted some ops, run CSE on that block to not keep pushing duplicates upward.
cse(block)
cse(block, rewriter)


class ControlFlowHoistPass(ModulePass):
Expand Down

0 comments on commit 902ccce

Please sign in to comment.