Skip to content

Commit 1de696b

Browse files
authored
[CIR] Fix multiple returns in switch statements (#164468)
Add support for multiple return statements in switch statements. Cases in switch statements don't have their own scopes but are distinct regions nonetheless. Insert a separate return block for each case and handle all of them in the cleanup code.
1 parent a831c3f commit 1de696b

File tree

3 files changed

+164
-40
lines changed

3 files changed

+164
-40
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,19 @@ void CIRGenFunction::LexicalScope::cleanup() {
242242
}
243243
};
244244

245-
if (returnBlock != nullptr) {
246-
// Write out the return block, which loads the value from `__retval` and
247-
// issues the `cir.return`.
245+
// Cleanups are done right before codegen resumes a scope. This is where
246+
// objects are destroyed. Process all return blocks.
247+
// TODO(cir): Handle returning from a switch statement through a cleanup
248+
// block. We can't simply jump to the cleanup block, because the cleanup block
249+
// is not part of the case region. Either reemit all cleanups in the return
250+
// block or wait for MLIR structured control flow to support early exits.
251+
llvm::SmallVector<mlir::Block *> retBlocks;
252+
for (mlir::Block *retBlock : localScope->getRetBlocks()) {
248253
mlir::OpBuilder::InsertionGuard guard(builder);
249-
builder.setInsertionPointToEnd(returnBlock);
250-
(void)emitReturn(*returnLoc);
254+
builder.setInsertionPointToEnd(retBlock);
255+
retBlocks.push_back(retBlock);
256+
mlir::Location retLoc = localScope->getRetLoc(retBlock);
257+
emitReturn(retLoc);
251258
}
252259

253260
auto insertCleanupAndLeave = [&](mlir::Block *insPt) {
@@ -274,19 +281,22 @@ void CIRGenFunction::LexicalScope::cleanup() {
274281

275282
if (localScope->depth == 0) {
276283
// Reached the end of the function.
277-
if (returnBlock != nullptr) {
278-
if (returnBlock->getUses().empty()) {
279-
returnBlock->erase();
284+
// Special handling only for single return block case
285+
if (localScope->getRetBlocks().size() == 1) {
286+
mlir::Block *retBlock = localScope->getRetBlocks()[0];
287+
mlir::Location retLoc = localScope->getRetLoc(retBlock);
288+
if (retBlock->getUses().empty()) {
289+
retBlock->erase();
280290
} else {
281291
// Thread return block via cleanup block.
282292
if (cleanupBlock) {
283-
for (mlir::BlockOperand &blockUse : returnBlock->getUses()) {
293+
for (mlir::BlockOperand &blockUse : retBlock->getUses()) {
284294
cir::BrOp brOp = mlir::cast<cir::BrOp>(blockUse.getOwner());
285295
brOp.setSuccessor(cleanupBlock);
286296
}
287297
}
288298

289-
cir::BrOp::create(builder, *returnLoc, returnBlock);
299+
cir::BrOp::create(builder, retLoc, retBlock);
290300
return;
291301
}
292302
}
@@ -324,8 +334,10 @@ void CIRGenFunction::LexicalScope::cleanup() {
324334
bool entryBlock = builder.getInsertionBlock()->isEntryBlock();
325335
if (!entryBlock && curBlock->empty()) {
326336
curBlock->erase();
327-
if (returnBlock != nullptr && returnBlock->getUses().empty())
328-
returnBlock->erase();
337+
for (mlir::Block *retBlock : retBlocks) {
338+
if (retBlock->getUses().empty())
339+
retBlock->erase();
340+
}
329341
return;
330342
}
331343

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,44 +1103,69 @@ class CIRGenFunction : public CIRGenTypeCache {
11031103
// ---
11041104

11051105
private:
1106-
// `returnBlock`, `returnLoc`, and all the functions that deal with them
1107-
// will change and become more complicated when `switch` statements are
1108-
// upstreamed. `case` statements within the `switch` are in the same scope
1109-
// but have their own regions. Therefore the LexicalScope will need to
1110-
// keep track of multiple return blocks.
1111-
mlir::Block *returnBlock = nullptr;
1112-
std::optional<mlir::Location> returnLoc;
1113-
1114-
// See the comment on `getOrCreateRetBlock`.
1106+
// On switches we need one return block per region, since cases don't
1107+
// have their own scopes but are distinct regions nonetheless.
1108+
1109+
// TODO: This implementation should change once we have support for early
1110+
// exits in MLIR structured control flow (llvm-project#161575)
1111+
llvm::SmallVector<mlir::Block *> retBlocks;
1112+
llvm::DenseMap<mlir::Block *, mlir::Location> retLocs;
1113+
llvm::DenseMap<cir::CaseOp, unsigned> retBlockInCaseIndex;
1114+
std::optional<unsigned> normalRetBlockIndex;
1115+
1116+
// There's usually only one ret block per scope, but this needs to be
1117+
// get-or-create because of potential unreachable return statements. Note
1118+
// that for those, all source locations map to the first one found.
11151119
mlir::Block *createRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
1116-
assert(returnBlock == nullptr && "only one return block per scope");
1117-
// Create the cleanup block but don't hook it up just yet.
1120+
assert((isa_and_nonnull<cir::CaseOp>(
1121+
cgf.builder.getBlock()->getParentOp()) ||
1122+
retBlocks.size() == 0) &&
1123+
"only switches can hold more than one ret block");
1124+
1125+
// Create the return block but don't hook it up just yet.
11181126
mlir::OpBuilder::InsertionGuard guard(cgf.builder);
1119-
returnBlock =
1120-
cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
1121-
updateRetLoc(returnBlock, loc);
1122-
return returnBlock;
1127+
auto *b = cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
1128+
retBlocks.push_back(b);
1129+
updateRetLoc(b, loc);
1130+
return b;
11231131
}
11241132

11251133
cir::ReturnOp emitReturn(mlir::Location loc);
11261134
void emitImplicitReturn();
11271135

11281136
public:
1129-
mlir::Block *getRetBlock() { return returnBlock; }
1130-
mlir::Location getRetLoc(mlir::Block *b) { return *returnLoc; }
1131-
void updateRetLoc(mlir::Block *b, mlir::Location loc) { returnLoc = loc; }
1132-
1133-
// Create the return block for this scope, or return the existing one.
1134-
// This get-or-create logic is necessary to handle multiple return
1135-
// statements within the same scope, which can happen if some of them are
1136-
// dead code or if there is a `goto` into the middle of the scope.
1137+
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return retBlocks; }
1138+
mlir::Location getRetLoc(mlir::Block *b) { return retLocs.at(b); }
1139+
void updateRetLoc(mlir::Block *b, mlir::Location loc) {
1140+
retLocs.insert_or_assign(b, loc);
1141+
}
1142+
11371143
mlir::Block *getOrCreateRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
1138-
if (returnBlock == nullptr) {
1139-
returnBlock = createRetBlock(cgf, loc);
1140-
return returnBlock;
1144+
// Check if we're inside a case region
1145+
if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>(
1146+
cgf.builder.getBlock()->getParentOp())) {
1147+
auto iter = retBlockInCaseIndex.find(caseOp);
1148+
if (iter != retBlockInCaseIndex.end()) {
1149+
// Reuse existing return block
1150+
mlir::Block *ret = retBlocks[iter->second];
1151+
updateRetLoc(ret, loc);
1152+
return ret;
1153+
}
1154+
// Create new return block
1155+
mlir::Block *ret = createRetBlock(cgf, loc);
1156+
retBlockInCaseIndex[caseOp] = retBlocks.size() - 1;
1157+
return ret;
11411158
}
1142-
updateRetLoc(returnBlock, loc);
1143-
return returnBlock;
1159+
1160+
if (normalRetBlockIndex) {
1161+
mlir::Block *ret = retBlocks[*normalRetBlockIndex];
1162+
updateRetLoc(ret, loc);
1163+
return ret;
1164+
}
1165+
1166+
mlir::Block *ret = createRetBlock(cgf, loc);
1167+
normalRetBlockIndex = retBlocks.size() - 1;
1168+
return ret;
11441169
}
11451170

11461171
mlir::Block *getEntryBlock() { return entryBlock; }

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,3 +1183,90 @@ int nested_switch(int a) {
11831183
// OGCG: [[IFEND10]]:
11841184
// OGCG: br label %[[EPILOG]]
11851185
// OGCG: [[EPILOG]]:
1186+
1187+
int sw_return_multi_cases(int x) {
1188+
switch (x) {
1189+
case 0:
1190+
return 0;
1191+
case 1:
1192+
return 1;
1193+
case 2:
1194+
return 2;
1195+
default:
1196+
return -1;
1197+
}
1198+
}
1199+
1200+
// CIR-LABEL: cir.func{{.*}} @_Z21sw_return_multi_casesi
1201+
// CIR: cir.switch (%{{.*}} : !s32i) {
1202+
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
1203+
// CIR: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
1204+
// CIR: cir.store{{.*}} %[[ZERO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1205+
// CIR: %[[RET0:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1206+
// CIR-NEXT: cir.return %[[RET0]] : !s32i
1207+
// CIR-NEXT: }
1208+
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
1209+
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
1210+
// CIR: cir.store{{.*}} %[[ONE]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1211+
// CIR: %[[RET1:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1212+
// CIR-NEXT: cir.return %[[RET1]] : !s32i
1213+
// CIR-NEXT: }
1214+
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
1215+
// CIR: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
1216+
// CIR: cir.store{{.*}} %[[TWO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1217+
// CIR: %[[RET2:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1218+
// CIR-NEXT: cir.return %[[RET2]] : !s32i
1219+
// CIR-NEXT: }
1220+
// CIR-NEXT: cir.case(default, []) {
1221+
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
1222+
// CIR: %[[NEG:.*]] = cir.unary(minus, %[[ONE]]) {{.*}} : !s32i, !s32i
1223+
// CIR: cir.store{{.*}} %[[NEG]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1224+
// CIR: %[[RETDEF:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1225+
// CIR-NEXT: cir.return %[[RETDEF]] : !s32i
1226+
// CIR-NEXT: }
1227+
// CIR-NEXT: cir.yield
1228+
1229+
// LLVM-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
1230+
// LLVM: switch i32 %{{.*}}, label %[[DEFAULT:.*]] [
1231+
// LLVM-DAG: i32 0, label %[[CASE0:.*]]
1232+
// LLVM-DAG: i32 1, label %[[CASE1:.*]]
1233+
// LLVM-DAG: i32 2, label %[[CASE2:.*]]
1234+
// LLVM: ]
1235+
// LLVM: [[CASE0]]:
1236+
// LLVM: store i32 0, ptr %{{.*}}, align 4
1237+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1238+
// LLVM: ret i32 %{{.*}}
1239+
// LLVM: [[CASE1]]:
1240+
// LLVM: store i32 1, ptr %{{.*}}, align 4
1241+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1242+
// LLVM: ret i32 %{{.*}}
1243+
// LLVM: [[CASE2]]:
1244+
// LLVM: store i32 2, ptr %{{.*}}, align 4
1245+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1246+
// LLVM: ret i32 %{{.*}}
1247+
// LLVM: [[DEFAULT]]:
1248+
// LLVM: store i32 -1, ptr %{{.*}}, align 4
1249+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1250+
// LLVM: ret i32 %{{.*}}
1251+
1252+
// OGCG-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
1253+
// OGCG: entry:
1254+
// OGCG: %[[RETVAL:.*]] = alloca i32, align 4
1255+
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
1256+
// OGCG: %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4
1257+
// OGCG: switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [
1258+
// OGCG-DAG: i32 0, label %[[SW0:.*]]
1259+
// OGCG-DAG: i32 1, label %[[SW1:.*]]
1260+
// OGCG-DAG: i32 2, label %[[SW2:.*]]
1261+
// OGCG: ]
1262+
// OGCG: [[SW0]]:
1263+
// OGCG: br label %[[RETURN:.*]]
1264+
// OGCG: [[SW1]]:
1265+
// OGCG: br label %[[RETURN]]
1266+
// OGCG: [[SW2]]:
1267+
// OGCG: br label %[[RETURN]]
1268+
// OGCG: [[DEFAULT]]:
1269+
// OGCG: br label %[[RETURN]]
1270+
// OGCG: [[RETURN]]:
1271+
// OGCG: %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4
1272+
// OGCG: ret i32 %[[RETVAL_LOAD]]

0 commit comments

Comments
 (0)