Skip to content

Commit 4959834

Browse files
Kritolanza
Krito
authored andcommitted
[CIR][Transforms] Move RemoveRedundantBranches logic into BrOp::fold method (llvm#663)
This pr is a part of llvm#593 . Move RemoveRedundantBranches logic into BrOp::fold method and modify tests.
1 parent 82267a4 commit 4959834

23 files changed

+287
-442
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+2
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,8 @@ def BrOp : CIR_Op<"br",
15991599
let assemblyFormat = [{
16001600
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
16011601
}];
1602+
1603+
let hasFolder = 1;
16021604
}
16031605

16041606
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRPasses.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ namespace mlir {
9191
void populateCIRPreLoweringPasses(OpPassManager &pm) {
9292
pm.addPass(createFlattenCFGPass());
9393
pm.addPass(createGotoSolverPass());
94+
pm.addPass(createMergeCleanupsPass());
9495
}
9596

9697
} // namespace mlir

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,42 @@ mlir::SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
944944

945945
Block *BrOp::getSuccessorForOperands(ArrayRef<Attribute>) { return getDest(); }
946946

947+
/// Removes branches between two blocks if it is the only branch.
948+
///
949+
/// From:
950+
/// ^bb0:
951+
/// cir.br ^bb1
952+
/// ^bb1: // pred: ^bb0
953+
/// cir.return
954+
///
955+
/// To:
956+
/// ^bb0:
957+
/// cir.return
958+
LogicalResult BrOp::fold(FoldAdaptor adaptor,
959+
SmallVectorImpl<OpFoldResult> &results) {
960+
Block *block = getOperation()->getBlock();
961+
Block *dest = getDest();
962+
963+
if (isa<mlir::cir::LabelOp>(dest->front())) {
964+
return failure();
965+
}
966+
967+
if (block->getNumSuccessors() == 1 && dest->getSinglePredecessor() == block) {
968+
getOperation()->erase();
969+
block->getOperations().splice(block->end(), dest->getOperations());
970+
auto eraseBlock = [](Block *block) {
971+
for (auto &op : llvm::make_early_inc_range(*block))
972+
op.erase();
973+
block->erase();
974+
};
975+
eraseBlock(dest);
976+
977+
return success();
978+
}
979+
980+
return failure();
981+
}
982+
947983
//===----------------------------------------------------------------------===//
948984
// BrCondOp
949985
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp

-35
Original file line numberDiff line numberDiff line change
@@ -23,40 +23,6 @@ using namespace cir;
2323

2424
namespace {
2525

26-
/// Removes branches between two blocks if it is the only branch.
27-
///
28-
/// From:
29-
/// ^bb0:
30-
/// cir.br ^bb1
31-
/// ^bb1: // pred: ^bb0
32-
/// cir.return
33-
///
34-
/// To:
35-
/// ^bb0:
36-
/// cir.return
37-
struct RemoveRedundantBranches : public OpRewritePattern<BrOp> {
38-
using OpRewritePattern<BrOp>::OpRewritePattern;
39-
40-
LogicalResult matchAndRewrite(BrOp op,
41-
PatternRewriter &rewriter) const final {
42-
Block *block = op.getOperation()->getBlock();
43-
Block *dest = op.getDest();
44-
45-
if (isa<mlir::cir::LabelOp>(dest->front()))
46-
return failure();
47-
48-
// Single edge between blocks: merge it.
49-
if (block->getNumSuccessors() == 1 &&
50-
dest->getSinglePredecessor() == block) {
51-
rewriter.eraseOp(op);
52-
rewriter.mergeBlocks(dest, block);
53-
return success();
54-
}
55-
56-
return failure();
57-
}
58-
};
59-
6026
struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
6127
using OpRewritePattern<ScopeOp>::OpRewritePattern;
6228

@@ -104,7 +70,6 @@ struct MergeCleanupsPass : public MergeCleanupsBase<MergeCleanupsPass> {
10470
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
10571
// clang-format off
10672
patterns.add<
107-
RemoveRedundantBranches,
10873
RemoveEmptyScope,
10974
RemoveEmptySwitch
11075
>(patterns.getContext());

clang/test/CIR/CodeGen/goto.cpp

+38-68
Original file line numberDiff line numberDiff line change
@@ -159,25 +159,19 @@ int jumpIntoLoop(int* ar) {
159159
// CHECK: cir.func @_Z12jumpIntoLoopPi
160160
// CHECK: cir.brcond {{.*}} ^bb[[#BLK2:]], ^bb[[#BLK3:]]
161161
// CHECK: ^bb[[#BLK2]]:
162-
// CHECK: cir.br ^bb[[#BODY:]]
162+
// CHECK: cir.br ^bb[[#BLK7:]]
163163
// CHECK: ^bb[[#BLK3]]:
164164
// CHECK: cir.br ^bb[[#BLK4:]]
165165
// CHECK: ^bb[[#BLK4]]:
166-
// CHECK: cir.br ^bb[[#RETURN:]]
167-
// CHECK: ^bb[[#RETURN]]:
168166
// CHECK: cir.return
169167
// CHECK: ^bb[[#BLK5:]]:
170168
// CHECK: cir.br ^bb[[#BLK6:]]
171-
// CHECK: ^bb[[#BLK6]]:
172-
// CHECK: cir.br ^bb[[#COND:]]
173-
// CHECK: ^bb[[#COND]]:
174-
// CHECK: cir.brcond {{.*}} ^bb[[#BODY]], ^bb[[#EXIT:]]
175-
// CHECK: ^bb[[#BODY]]:
176-
// CHECK: cir.br ^bb[[#COND]]
177-
// CHECK: ^bb[[#EXIT]]:
178-
// CHECK: cir.br ^bb[[#BLK7:]]
169+
// CHECK: ^bb[[#BLK6]]:
170+
// CHECK: cir.brcond {{.*}} ^bb[[#BLK7:]], ^bb[[#BLK8:]]
179171
// CHECK: ^bb[[#BLK7]]:
180-
// CHECK: cir.br ^bb[[#RETURN]]
172+
// CHECK: cir.br ^bb[[#BLK6]]
173+
// CHECK: ^bb[[#BLK8]]:
174+
// CHECK: cir.br ^bb[[#BLK4]]
181175

182176

183177

@@ -197,31 +191,21 @@ int jumpFromLoop(int* ar) {
197191
return 0;
198192
}
199193
// CHECK: cir.func @_Z12jumpFromLoopPi
200-
// CHECK: cir.brcond {{.*}} ^bb[[#RETURN1:]], ^bb[[#BLK3:]]
201-
// CHECK: ^bb[[#RETURN1]]:
202-
// CHECK: cir.return
203-
// CHECK: ^bb[[#BLK3]]:
204-
// CHECK: cir.br ^bb[[#BLK4:]]
205-
// CHECK: ^bb[[#BLK4]]:
206-
// CHECK: cir.br ^bb[[#BLK5:]]
207-
// CHECK: ^bb[[#BLK5]]:
208-
// CHECK: cir.br ^bb[[#COND:]]
209-
// CHECK: ^bb[[#COND]]:
210-
// CHECK: cir.brcond {{.*}} ^bb[[#BODY:]], ^bb[[#EXIT:]]
211-
// CHECK: ^bb[[#BODY]]:
212-
// CHECK: cir.br ^bb[[#IF42:]]
213-
// CHECK: ^bb[[#IF42]]:
214-
// CHECK: cir.brcond {{.*}} ^bb[[#IF42TRUE:]], ^bb[[#IF42FALSE:]]
215-
// CHECK: ^bb[[#IF42TRUE]]:
216-
// CHECK: cir.br ^bb[[#RETURN1]]
217-
// CHECK: ^bb[[#IF42FALSE]]:
218-
// CHECK: cir.br ^bb[[#BLK11:]]
219-
// CHECK: ^bb[[#BLK11]]:
220-
// CHECK: cir.br ^bb[[#COND]]
221-
// CHECK: ^bb[[#EXIT]]:
222-
// CHECK: cir.br ^bb[[#RETURN2:]]
223-
// CHECK: ^bb[[#RETURN2]]:
224-
// CHECK: cir.return
194+
// CHECK: cir.brcond {{.*}} ^bb[[#BLK1:]], ^bb[[#BLK2:]]
195+
// CHECK: ^bb[[#BLK1:]]:
196+
// CHECK: cir.return {{.*}}
197+
// CHECK: ^bb[[#BLK2:]]:
198+
// CHECK: cir.br ^bb[[#BLK3:]]
199+
// CHECK: ^bb[[#BLK3:]]:
200+
// CHECK: cir.brcond {{.*}} ^bb[[#BLK4:]], ^bb[[#BLK7:]]
201+
// CHECK: ^bb[[#BLK4:]]:
202+
// CHECK: cir.brcond {{.*}} ^bb[[#BLK5:]], ^bb[[#BLK6:]]
203+
// CHECK: ^bb[[#BLK5:]]:
204+
// CHECK: cir.br ^bb[[#BLK1:]]
205+
// CHECK: ^bb[[#BLK6:]]:
206+
// CHECK: cir.br ^bb[[#BLK3:]]
207+
// CHECK: ^bb[[#BLK7:]]:
208+
// CHECK: cir.return {{.*}}
225209

226210

227211
void flatLoopWithNoTerminatorInFront(int* ptr) {
@@ -240,35 +224,21 @@ void flatLoopWithNoTerminatorInFront(int* ptr) {
240224
;
241225
}
242226

243-
// CHECK: cir.func @_Z31flatLoopWithNoTerminatorInFrontPi
244-
// CHECK: cir.brcond {{.*}} ^bb[[#BLK2:]], ^bb[[#BLK3:]]
245-
// CHECK: ^bb[[#BLK2]]:
246-
// CHECK: cir.br ^bb[[#LABEL_LOOP:]]
247-
// CHECK: ^bb[[#BLK3]]:
248-
// CHECK: cir.br ^bb[[#BLK4:]]
249-
// CHECK: ^bb[[#BLK4]]:
250-
// CHECK: cir.br ^bb[[#BLK5:]]
251-
// CHECK: ^bb[[#BLK5]]:
252-
// CHECK: cir.br ^bb[[#BODY:]]
253-
// CHECK: ^bb[[#COND]]:
254-
// CHECK: cir.brcond {{.*}} ^bb[[#BODY]], ^bb[[#EXIT:]]
255-
// CHECK: ^bb[[#BODY]]:
227+
// CHECK-LABEL: cir.func @_Z31flatLoopWithNoTerminatorInFrontPi
228+
// CHECK: cir.brcond {{.*}} ^bb[[#BLK1:]], ^bb[[#BLK2:]]
229+
// CHECK: ^bb[[#BLK1:]]:
230+
// CHECK: cir.br ^bb[[#BLK6:]]
231+
// CHECK: ^bb[[#BLK2:]]:
232+
// CHECK: cir.br ^bb[[#BLK3:]]
233+
// CHECK: ^bb[[#BLK3:]]: // 2 preds: ^bb[[#BLK2:]], ^bb[[#BLK6:]]
234+
// CHECK: cir.brcond {{.*}} ^bb[[#BLK4:]], ^bb[[#BLK5:]]
235+
// CHECK: ^bb[[#BLK4:]]:
256236
// CHECK: cir.br ^bb[[#BLK8:]]
257-
// CHECK: ^bb[[#BLK8]]:
258-
// CHECK: cir.brcond {{.*}} ^bb[[#BLK9:]], ^bb[[#BLK10:]]
259-
// CHECK: ^bb[[#BLK9]]:
260-
// CHECK: cir.br ^bb[[#RETURN:]]
261-
// CHECK: ^bb[[#BLK10]]:
262-
// CHECK: cir.br ^bb[[#BLK11:]]
263-
// CHECK: ^bb[[#BLK11]]:
264-
// CHECK: cir.br ^bb[[#LABEL_LOOP]]
265-
// CHECK: ^bb[[#LABEL_LOOP]]:
266-
// CHECK: cir.br ^bb[[#COND]]
267-
// CHECK: ^bb[[#EXIT]]:
268-
// CHECK: cir.br ^bb[[#BLK14:]]
269-
// CHECK: ^bb[[#BLK14]]:
270-
// CHECK: cir.br ^bb[[#RETURN]]
271-
// CHECK: ^bb[[#RETURN]]:
272-
// CHECK: cir.return
273-
// CHECK: }
274-
// CHECK:}
237+
// CHECK: ^bb[[#BLK5:]]:
238+
// CHECK: cir.br ^bb[[#BLK6:]]
239+
// CHECK: ^bb[[#BLK6:]]: // 2 preds: ^bb[[#BLK1:]], ^bb[[#BLK5:]]
240+
// CHECK: cir.brcond {{.*}} ^bb[[#BLK3:]], ^bb[[#BLK7:]]
241+
// CHECK: ^bb[[#BLK7:]]:
242+
// CHECK: cir.br ^bb[[#BLK8:]]
243+
// CHECK: ^bb[[#BLK8:]]: // 2 preds: ^bb[[#BLK4:]], ^bb[[#BLK7:]]
244+
// CHECK: cir.return

clang/test/CIR/CodeGen/switch-gnurange.cpp

-10
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ void sw3(enum letter c) {
172172
// LLVM: store i32 4, ptr %[[X]]
173173
// LLVM: br label %[[EPILOG]]
174174
// LLVM: [[EPILOG]]:
175-
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
176-
// LLVM: [[EPILOG_END]]:
177175
// LLVM-NEXT: ret void
178176

179177
void sw4(int x) {
@@ -213,8 +211,6 @@ void sw4(int x) {
213211
// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], 167
214212
// LLVM: br i1 %[[DIFF_CMP]], label %[[CASE_66_233]], label %[[EPILOG]]
215213
// LLVM: [[EPILOG]]:
216-
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
217-
// LLVM: [[EPILOG_END]]:
218214
// LLVM-NEXT: ret void
219215

220216
void sw5(int x) {
@@ -241,8 +237,6 @@ void sw5(int x) {
241237
// LLVM-NEXT: store i32 1, ptr %[[Y:[0-9]+]]
242238
// LLVM-NEXT: br label %[[EPILOG]]
243239
// LLVM: [[EPILOG]]:
244-
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
245-
// LLVM: [[EPILOG_END]]:
246240
// LLVM-NEXT: ret void
247241

248242
void sw6(int x) {
@@ -273,8 +267,6 @@ void sw6(int x) {
273267
// LLVM-NEXT: %[[DIFF_CMP:[0-9]+]] = icmp ule i32 %[[DIFF]], -1
274268
// LLVM-NEXT: br i1 %[[DIFF_CMP]], label %[[CASE_MIN_MAX]], label %[[EPILOG]]
275269
// LLVM: [[EPILOG]]:
276-
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
277-
// LLVM: [[EPILOG_END]]:
278270
// LLVM-NEXT: ret void
279271

280272
void sw7(int x) {
@@ -346,7 +338,5 @@ void sw7(int x) {
346338
// LLVM: [[CASE_500_600]]:
347339
// LLVM-NEXT: br label %[[EPILOG]]
348340
// LLVM: [[EPILOG]]:
349-
// LLVM-NEXT: br label %[[EPILOG_END:[0-9]+]]
350-
// LLVM: [[EPILOG_END]]:
351341
// LLVM-NEXT: ret void
352342

clang/test/CIR/CodeGen/var-arg-scope.c

+1-7
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ void f1(__builtin_va_list c) {
6868
// LLVM: %struct.__va_list = type { ptr, ptr, ptr, i32, i32 }
6969
// LLVM: define void @f1(%struct.__va_list %0)
7070
// LLVM: [[VARLIST:%.*]] = alloca %struct.__va_list, i64 1, align 8,
71-
// LLVM: br label %[[SCOPE_FRONT:.*]],
72-
73-
// LLVM: [[SCOPE_FRONT]]: ; preds = %1
7471
// LLVM: [[GR_OFFS_P:%.*]] = getelementptr %struct.__va_list, ptr [[VARLIST]], i32 0, i32 3
7572
// LLVM: [[GR_OFFS:%.*]] = load i32, ptr [[GR_OFFS_P]], align 4,
7673
// LLVM-NEXT: [[CMP0:%.*]] = icmp sge i32 [[GR_OFFS]], 0,
@@ -99,7 +96,4 @@ void f1(__builtin_va_list c) {
9996
// LLVM: [[BB_END]]: ; preds = %[[BB_ON_STACK]], %[[BB_IN_REG]]
10097
// LLVM-NEXT: [[PHIP:%.*]] = phi ptr [ [[IN_REG_OUTPUT]], %[[BB_IN_REG]] ], [ [[STACK_V]], %[[BB_ON_STACK]] ]
10198
// LLVM-NEXT: [[PHIV:%.*]] = load ptr, ptr [[PHIP]], align 8,
102-
// LLVM-NEXT: br label %[[OUT_SCOPE:.*]],
103-
104-
// LLVM: [[OUT_SCOPE]]: ; preds = %[[BB_END]]
105-
// LLVM-NEXT: ret void,
99+
// LLVM: ret void,

clang/test/CIR/Lowering/ThroughMLIR/goto.cir

+10-4
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@ module {
77
%0 = cir.alloca !u32i, !cir.ptr<!u32i>, ["b", init] {alignment = 4 : i64}
88
%1 = cir.const #cir.int<1> : !u32i
99
cir.store %1, %0 : !u32i, !cir.ptr<!u32i>
10-
cir.br ^bb2
11-
^bb1: // no predecessors
10+
%c = cir.const #cir.int<0> : !u32i
11+
%cond = cir.cast(int_to_bool, %c : !u32i), !cir.bool
12+
cir.brcond %cond ^bb1, ^bb2
13+
14+
^bb1:
1215
%2 = cir.load %0 : !cir.ptr<!u32i>, !u32i
1316
%3 = cir.const #cir.int<1> : !u32i
1417
%4 = cir.binop(add, %2, %3) : !u32i
1518
cir.store %4, %0 : !u32i, !cir.ptr<!u32i>
1619
cir.br ^bb2
20+
1721
^bb2: // 2 preds: ^bb0, ^bb1
1822
%5 = cir.load %0 : !cir.ptr<!u32i>, !u32i
1923
%6 = cir.const #cir.int<2> : !u32i
@@ -25,8 +29,10 @@ module {
2529

2630
// MLIR: module {
2731
// MLIR-NEXT: func @foo
28-
// MLIR: cf.br ^bb1
29-
// MLIR: ^bb1:
32+
// MLIR: cf.cond_br %{{.+}}, ^bb[[#BLK1:]], ^bb[[#BLK2:]]
33+
// MLIR: ^bb[[#BLK1:]]:
34+
// MLIR: cf.br ^bb[[#BLK2:]]
35+
// MLIR: ^bb[[#BLK2:]]:
3036
// MLIR: return
3137

3238
// LLVM: br label %[[Value:[0-9]+]]

0 commit comments

Comments
 (0)