Skip to content

Commit

Permalink
dialects: (cf) branch canonicalization. (#3234)
Browse files Browse the repository at this point in the history
Adds one of the canonicalization patterns for `cf.branch`. The other
will need predecessor information for blocks and so is not part of this
PR.

---------

Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
alexarice and superlopuh authored Oct 2, 2024
1 parent 07e8eba commit 7e5543b
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 4 deletions.
29 changes: 26 additions & 3 deletions tests/filecheck/dialects/cf/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,30 @@
// CHECK-NEXT: }

func.func @assert_true() -> i1 {
%0 = arith.constant true
cf.assert %0 , "assert true"
func.return %0 : i1
%0 = arith.constant true
cf.assert %0 , "assert true"
func.return %0 : i1
}

/// Test that pass-through successors of BranchOp get folded.

// CHECK: func.func @br_passthrough(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
// CHECK-NEXT: "test.termop"() [^[[#b0:]], ^[[#b1:]], ^[[#b2:]]] : () -> ()
// CHECK-NEXT: ^[[#b0]]:
// CHECK-NEXT: cf.br ^[[#b2]](%arg2, %arg1 : i32, i32)
// CHECK-NEXT: ^[[#b1]](%arg2 : i32):
// CHECK-NEXT: cf.br ^[[#b2]](%arg2, %arg1 : i32, i32)
// CHECK-NEXT: ^[[#b2]](%arg4 : i32, %arg5 : i32):
// CHECK-NEXT: func.return %arg4, %arg5 : i32, i32
// CHECK-NEXT: }
func.func @br_passthrough(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
"test.termop"() [^0, ^1, ^2] : () -> ()
^0:
cf.br ^1(%arg0 : i32)

^1(%arg2 : i32):
cf.br ^2(%arg2, %arg1 : i32, i32)

^2(%arg4 : i32, %arg5 : i32):
return %arg4, %arg5 : i32, i32
}
10 changes: 9 additions & 1 deletion xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def __init__(self, arg: Operation | SSAValue, msg: str | StringAttr):
assembly_format = "$arg `,` $msg attr-dict"


class BranchHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.cf import SimplifyPassThroughBr

return (SimplifyPassThroughBr(),)


@irdl_op_definition
class Branch(IRDLOperation):
"""Branch operation"""
Expand All @@ -76,7 +84,7 @@ class Branch(IRDLOperation):
arguments = var_operand_def()
successor = successor_def()

traits = frozenset([IsTerminator()])
traits = frozenset((IsTerminator(), BranchHasCanonicalizationPatterns()))

def __init__(self, dest: Block, *ops: Operation | SSAValue):
super().__init__(operands=[[op for op in ops]], successors=[dest])
Expand Down
72 changes: 72 additions & 0 deletions xdsl/transforms/canonicalization_patterns/cf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from collections.abc import Sequence

from xdsl.dialects import arith, cf
from xdsl.dialects.builtin import IntegerAttr
from xdsl.ir import Block, BlockArgument, SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
Expand All @@ -26,3 +29,72 @@ def match_and_rewrite(self, op: cf.Assert, rewriter: PatternRewriter):
return

rewriter.replace_matched_op([])


def collapse_branch(
successor: Block, successor_operands: Sequence[SSAValue]
) -> tuple[Block, Sequence[SSAValue]] | None:
"""
Given a successor, try to collapse it to a new destination if it only
contains a passthrough unconditional branch. If the successor is
collapsable, `successor` and `successorOperands` are updated to reference
the new destination and values. `argStorage` is used as storage if operands
to the collapsed successor need to be remapped. It must outlive uses of
successorOperands.
"""

# Check that successor only contains branch
if len(successor.ops) != 1:
return

branch = successor.ops.first
# Check that the terminator is an unconditional branch
if not isinstance(branch, cf.Branch):
return

# Check that the arguments are only used within the terminator
for argument in successor.args:
for user in argument.uses:
if user.operation != branch:
return

# Don't try to collapse branches to infinite loops.
if branch.successor == successor:
return

# Remap operands
operands = branch.operands

new_operands = tuple(
successor_operands[op_owner.index]
if isinstance(op_owner := operand.owner, BlockArgument)
and op_owner.block is successor
else operand
for operand in operands
)

return (branch.successor, new_operands)


class SimplifyPassThroughBr(RewritePattern):
"""
br ^bb1
^bb1
br ^bbN(...)
-> br ^bbN(...)
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.Branch, rewriter: PatternRewriter):
# Check the successor doesn't point back to the current block
parent = op.parent_block()
if parent is None or op.successor == parent:
return

ret = collapse_branch(op.successor, op.arguments)
if ret is None:
return
(block, args) = ret

rewriter.replace_matched_op(cf.Branch(block, *args))

0 comments on commit 7e5543b

Please sign in to comment.