Skip to content

Commit

Permalink
[CIR][Transforms] Move RemoveRedundantBranches logic into BrOp::fold …
Browse files Browse the repository at this point in the history
…method (llvm#663)

This pr is a part of llvm#593 .
Move RemoveRedundantBranches logic into BrOp::fold method and modify
tests.
  • Loading branch information
Krito authored Jun 13, 2024
1 parent f78f9a5 commit 3b9f698
Show file tree
Hide file tree
Showing 23 changed files with 287 additions and 442 deletions.
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,8 @@ def BrOp : CIR_Op<"br",
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CIRPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ namespace mlir {
void populateCIRPreLoweringPasses(OpPassManager &pm) {
pm.addPass(createFlattenCFGPass());
pm.addPass(createGotoSolverPass());
pm.addPass(createMergeCleanupsPass());
}

} // namespace mlir
36 changes: 36 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,42 @@ mlir::SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {

Block *BrOp::getSuccessorForOperands(ArrayRef<Attribute>) { return getDest(); }

/// Removes branches between two blocks if it is the only branch.
///
/// From:
/// ^bb0:
/// cir.br ^bb1
/// ^bb1: // pred: ^bb0
/// cir.return
///
/// To:
/// ^bb0:
/// cir.return
LogicalResult BrOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
Block *block = getOperation()->getBlock();
Block *dest = getDest();

if (isa<mlir::cir::LabelOp>(dest->front())) {
return failure();
}

if (block->getNumSuccessors() == 1 && dest->getSinglePredecessor() == block) {
getOperation()->erase();
block->getOperations().splice(block->end(), dest->getOperations());
auto eraseBlock = [](Block *block) {
for (auto &op : llvm::make_early_inc_range(*block))
op.erase();
block->erase();
};
eraseBlock(dest);

return success();
}

return failure();
}

//===----------------------------------------------------------------------===//
// BrCondOp
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 0 additions & 35 deletions clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,6 @@ using namespace cir;

namespace {

/// Removes branches between two blocks if it is the only branch.
///
/// From:
/// ^bb0:
/// cir.br ^bb1
/// ^bb1: // pred: ^bb0
/// cir.return
///
/// To:
/// ^bb0:
/// cir.return
struct RemoveRedundantBranches : public OpRewritePattern<BrOp> {
using OpRewritePattern<BrOp>::OpRewritePattern;

LogicalResult matchAndRewrite(BrOp op,
PatternRewriter &rewriter) const final {
Block *block = op.getOperation()->getBlock();
Block *dest = op.getDest();

if (isa<mlir::cir::LabelOp>(dest->front()))
return failure();

// Single edge between blocks: merge it.
if (block->getNumSuccessors() == 1 &&
dest->getSinglePredecessor() == block) {
rewriter.eraseOp(op);
rewriter.mergeBlocks(dest, block);
return success();
}

return failure();
}
};

struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
using OpRewritePattern<ScopeOp>::OpRewritePattern;

Expand Down Expand Up @@ -104,7 +70,6 @@ struct MergeCleanupsPass : public MergeCleanupsBase<MergeCleanupsPass> {
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
RemoveRedundantBranches,
RemoveEmptyScope,
RemoveEmptySwitch
>(patterns.getContext());
Expand Down
106 changes: 38 additions & 68 deletions clang/test/CIR/CodeGen/goto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,25 +159,19 @@ int jumpIntoLoop(int* ar) {
// CHECK: cir.func @_Z12jumpIntoLoopPi
// CHECK: cir.brcond {{.*}} ^bb[[#BLK2:]], ^bb[[#BLK3:]]
// CHECK: ^bb[[#BLK2]]:
// CHECK: cir.br ^bb[[#BODY:]]
// CHECK: cir.br ^bb[[#BLK7:]]
// CHECK: ^bb[[#BLK3]]:
// CHECK: cir.br ^bb[[#BLK4:]]
// CHECK: ^bb[[#BLK4]]:
// CHECK: cir.br ^bb[[#RETURN:]]
// CHECK: ^bb[[#RETURN]]:
// CHECK: cir.return
// CHECK: ^bb[[#BLK5:]]:
// CHECK: cir.br ^bb[[#BLK6:]]
// CHECK: ^bb[[#BLK6]]:
// CHECK: cir.br ^bb[[#COND:]]
// CHECK: ^bb[[#COND]]:
// CHECK: cir.brcond {{.*}} ^bb[[#BODY]], ^bb[[#EXIT:]]
// CHECK: ^bb[[#BODY]]:
// CHECK: cir.br ^bb[[#COND]]
// CHECK: ^bb[[#EXIT]]:
// CHECK: cir.br ^bb[[#BLK7:]]
// CHECK: ^bb[[#BLK6]]:
// CHECK: cir.brcond {{.*}} ^bb[[#BLK7:]], ^bb[[#BLK8:]]
// CHECK: ^bb[[#BLK7]]:
// CHECK: cir.br ^bb[[#RETURN]]
// CHECK: cir.br ^bb[[#BLK6]]
// CHECK: ^bb[[#BLK8]]:
// CHECK: cir.br ^bb[[#BLK4]]



Expand All @@ -197,31 +191,21 @@ int jumpFromLoop(int* ar) {
return 0;
}
// CHECK: cir.func @_Z12jumpFromLoopPi
// CHECK: cir.brcond {{.*}} ^bb[[#RETURN1:]], ^bb[[#BLK3:]]
// CHECK: ^bb[[#RETURN1]]:
// CHECK: cir.return
// CHECK: ^bb[[#BLK3]]:
// CHECK: cir.br ^bb[[#BLK4:]]
// CHECK: ^bb[[#BLK4]]:
// CHECK: cir.br ^bb[[#BLK5:]]
// CHECK: ^bb[[#BLK5]]:
// CHECK: cir.br ^bb[[#COND:]]
// CHECK: ^bb[[#COND]]:
// CHECK: cir.brcond {{.*}} ^bb[[#BODY:]], ^bb[[#EXIT:]]
// CHECK: ^bb[[#BODY]]:
// CHECK: cir.br ^bb[[#IF42:]]
// CHECK: ^bb[[#IF42]]:
// CHECK: cir.brcond {{.*}} ^bb[[#IF42TRUE:]], ^bb[[#IF42FALSE:]]
// CHECK: ^bb[[#IF42TRUE]]:
// CHECK: cir.br ^bb[[#RETURN1]]
// CHECK: ^bb[[#IF42FALSE]]:
// CHECK: cir.br ^bb[[#BLK11:]]
// CHECK: ^bb[[#BLK11]]:
// CHECK: cir.br ^bb[[#COND]]
// CHECK: ^bb[[#EXIT]]:
// CHECK: cir.br ^bb[[#RETURN2:]]
// CHECK: ^bb[[#RETURN2]]:
// CHECK: cir.return
// CHECK: cir.brcond {{.*}} ^bb[[#BLK1:]], ^bb[[#BLK2:]]
// CHECK: ^bb[[#BLK1:]]:
// CHECK: cir.return {{.*}}
// CHECK: ^bb[[#BLK2:]]:
// CHECK: cir.br ^bb[[#BLK3:]]
// CHECK: ^bb[[#BLK3:]]:
// CHECK: cir.brcond {{.*}} ^bb[[#BLK4:]], ^bb[[#BLK7:]]
// CHECK: ^bb[[#BLK4:]]:
// CHECK: cir.brcond {{.*}} ^bb[[#BLK5:]], ^bb[[#BLK6:]]
// CHECK: ^bb[[#BLK5:]]:
// CHECK: cir.br ^bb[[#BLK1:]]
// CHECK: ^bb[[#BLK6:]]:
// CHECK: cir.br ^bb[[#BLK3:]]
// CHECK: ^bb[[#BLK7:]]:
// CHECK: cir.return {{.*}}


void flatLoopWithNoTerminatorInFront(int* ptr) {
Expand All @@ -240,35 +224,21 @@ void flatLoopWithNoTerminatorInFront(int* ptr) {
;
}

// CHECK: cir.func @_Z31flatLoopWithNoTerminatorInFrontPi
// CHECK: cir.brcond {{.*}} ^bb[[#BLK2:]], ^bb[[#BLK3:]]
// CHECK: ^bb[[#BLK2]]:
// CHECK: cir.br ^bb[[#LABEL_LOOP:]]
// CHECK: ^bb[[#BLK3]]:
// CHECK: cir.br ^bb[[#BLK4:]]
// CHECK: ^bb[[#BLK4]]:
// CHECK: cir.br ^bb[[#BLK5:]]
// CHECK: ^bb[[#BLK5]]:
// CHECK: cir.br ^bb[[#BODY:]]
// CHECK: ^bb[[#COND]]:
// CHECK: cir.brcond {{.*}} ^bb[[#BODY]], ^bb[[#EXIT:]]
// CHECK: ^bb[[#BODY]]:
// CHECK-LABEL: cir.func @_Z31flatLoopWithNoTerminatorInFrontPi
// CHECK: cir.brcond {{.*}} ^bb[[#BLK1:]], ^bb[[#BLK2:]]
// CHECK: ^bb[[#BLK1:]]:
// CHECK: cir.br ^bb[[#BLK6:]]
// CHECK: ^bb[[#BLK2:]]:
// CHECK: cir.br ^bb[[#BLK3:]]
// CHECK: ^bb[[#BLK3:]]: // 2 preds: ^bb[[#BLK2:]], ^bb[[#BLK6:]]
// CHECK: cir.brcond {{.*}} ^bb[[#BLK4:]], ^bb[[#BLK5:]]
// CHECK: ^bb[[#BLK4:]]:
// CHECK: cir.br ^bb[[#BLK8:]]
// CHECK: ^bb[[#BLK8]]:
// CHECK: cir.brcond {{.*}} ^bb[[#BLK9:]], ^bb[[#BLK10:]]
// CHECK: ^bb[[#BLK9]]:
// CHECK: cir.br ^bb[[#RETURN:]]
// CHECK: ^bb[[#BLK10]]:
// CHECK: cir.br ^bb[[#BLK11:]]
// CHECK: ^bb[[#BLK11]]:
// CHECK: cir.br ^bb[[#LABEL_LOOP]]
// CHECK: ^bb[[#LABEL_LOOP]]:
// CHECK: cir.br ^bb[[#COND]]
// CHECK: ^bb[[#EXIT]]:
// CHECK: cir.br ^bb[[#BLK14:]]
// CHECK: ^bb[[#BLK14]]:
// CHECK: cir.br ^bb[[#RETURN]]
// CHECK: ^bb[[#RETURN]]:
// CHECK: cir.return
// CHECK: }
// CHECK:}
// CHECK: ^bb[[#BLK5:]]:
// CHECK: cir.br ^bb[[#BLK6:]]
// CHECK: ^bb[[#BLK6:]]: // 2 preds: ^bb[[#BLK1:]], ^bb[[#BLK5:]]
// CHECK: cir.brcond {{.*}} ^bb[[#BLK3:]], ^bb[[#BLK7:]]
// CHECK: ^bb[[#BLK7:]]:
// CHECK: cir.br ^bb[[#BLK8:]]
// CHECK: ^bb[[#BLK8:]]: // 2 preds: ^bb[[#BLK4:]], ^bb[[#BLK7:]]
// CHECK: cir.return
10 changes: 0 additions & 10 deletions clang/test/CIR/CodeGen/switch-gnurange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,6 @@ void sw3(enum letter c) {
// LLVM: store i32 4, ptr %[[X]]
// LLVM: br label %[[EPILOG]]
// LLVM: [[EPILOG]]:
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
// LLVM: [[EPILOG_END]]:
// LLVM-NEXT: ret void

void sw4(int x) {
Expand Down Expand Up @@ -213,8 +211,6 @@ void sw4(int x) {
// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], 167
// LLVM: br i1 %[[DIFF_CMP]], label %[[CASE_66_233]], label %[[EPILOG]]
// LLVM: [[EPILOG]]:
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
// LLVM: [[EPILOG_END]]:
// LLVM-NEXT: ret void

void sw5(int x) {
Expand All @@ -241,8 +237,6 @@ void sw5(int x) {
// LLVM-NEXT: store i32 1, ptr %[[Y:[0-9]+]]
// LLVM-NEXT: br label %[[EPILOG]]
// LLVM: [[EPILOG]]:
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
// LLVM: [[EPILOG_END]]:
// LLVM-NEXT: ret void

void sw6(int x) {
Expand Down Expand Up @@ -273,8 +267,6 @@ void sw6(int x) {
// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], -1
// LLVM-NEXT: br i1 %[[DIFF_CMP]], label %[[CASE_MIN_MAX]], label %[[EPILOG]]
// LLVM: [[EPILOG]]:
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
// LLVM: [[EPILOG_END]]:
// LLVM-NEXT: ret void

void sw7(int x) {
Expand Down Expand Up @@ -346,7 +338,5 @@ void sw7(int x) {
// LLVM: [[CASE_500_600]]:
// LLVM-NEXT: br label %[[EPILOG]]
// LLVM: [[EPILOG]]:
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
// LLVM: [[EPILOG_END]]:
// LLVM-NEXT: ret void

8 changes: 1 addition & 7 deletions clang/test/CIR/CodeGen/var-arg-scope.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ void f1(__builtin_va_list c) {
// LLVM: %struct.__va_list = type { ptr, ptr, ptr, i32, i32 }
// LLVM: define void @f1(%struct.__va_list %0)
// LLVM: [[VARLIST:%.*]] = alloca %struct.__va_list, i64 1, align 8,
// LLVM: br label %[[SCOPE_FRONT:.*]],

// LLVM: [[SCOPE_FRONT]]: ; preds = %1
// LLVM: [[GR_OFFS_P:%.*]] = getelementptr %struct.__va_list, ptr [[VARLIST]], i32 0, i32 3
// LLVM: [[GR_OFFS:%.*]] = load i32, ptr [[GR_OFFS_P]], align 4,
// LLVM-NEXT: [[CMP0:%.*]] = icmp sge i32 [[GR_OFFS]], 0,
Expand Down Expand Up @@ -99,7 +96,4 @@ void f1(__builtin_va_list c) {
// LLVM: [[BB_END]]: ; preds = %[[BB_ON_STACK]], %[[BB_IN_REG]]
// LLVM-NEXT: [[PHIP:%.*]] = phi ptr [ [[IN_REG_OUTPUT]], %[[BB_IN_REG]] ], [ [[STACK_V]], %[[BB_ON_STACK]] ]
// LLVM-NEXT: [[PHIV:%.*]] = load ptr, ptr [[PHIP]], align 8,
// LLVM-NEXT: br label %[[OUT_SCOPE:.*]],

// LLVM: [[OUT_SCOPE]]: ; preds = %[[BB_END]]
// LLVM-NEXT: ret void,
// LLVM: ret void,
14 changes: 10 additions & 4 deletions clang/test/CIR/Lowering/ThroughMLIR/goto.cir
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ module {
%0 = cir.alloca !u32i, !cir.ptr<!u32i>, ["b", init] {alignment = 4 : i64}
%1 = cir.const #cir.int<1> : !u32i
cir.store %1, %0 : !u32i, !cir.ptr<!u32i>
cir.br ^bb2
^bb1: // no predecessors
%c = cir.const #cir.int<0> : !u32i
%cond = cir.cast(int_to_bool, %c : !u32i), !cir.bool
cir.brcond %cond ^bb1, ^bb2

^bb1:
%2 = cir.load %0 : !cir.ptr<!u32i>, !u32i
%3 = cir.const #cir.int<1> : !u32i
%4 = cir.binop(add, %2, %3) : !u32i
cir.store %4, %0 : !u32i, !cir.ptr<!u32i>
cir.br ^bb2

^bb2: // 2 preds: ^bb0, ^bb1
%5 = cir.load %0 : !cir.ptr<!u32i>, !u32i
%6 = cir.const #cir.int<2> : !u32i
Expand All @@ -25,8 +29,10 @@ module {

// MLIR: module {
// MLIR-NEXT: func @foo
// MLIR: cf.br ^bb1
// MLIR: ^bb1:
// MLIR: cf.cond_br %{{.+}}, ^bb[[#BLK1:]], ^bb[[#BLK2:]]
// MLIR: ^bb[[#BLK1:]]:
// MLIR: cf.br ^bb[[#BLK2:]]
// MLIR: ^bb[[#BLK2:]]:
// MLIR: return

// LLVM: br label %[[Value:[0-9]+]]
Expand Down
Loading

0 comments on commit 3b9f698

Please sign in to comment.