diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 4c870693b648..8c020922f3b1 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -282,8 +282,8 @@ class CIRIfLowering : public mlir::OpConversionPattern { matchAndRewrite(mlir::cir::IfOp ifOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::OpBuilder::InsertionGuard guard(rewriter); - auto loc = ifOp.getLoc(); + auto emptyElse = ifOp.getElseRegion().empty(); auto *currentBlock = rewriter.getInsertionBlock(); auto *remainingOpsBlock = @@ -310,10 +310,16 @@ class CIRIfLowering : public mlir::OpConversionPattern { rewriter.setInsertionPointToEnd(continueBlock); - // Inline then region - auto *elseBeforeBody = &ifOp.getElseRegion().front(); - auto *elseAfterBody = &ifOp.getElseRegion().back(); - rewriter.inlineRegionBefore(ifOp.getElseRegion(), thenAfterBody); + // Has else region: inline it. + mlir::Block *elseBeforeBody = nullptr; + mlir::Block *elseAfterBody = nullptr; + if (!emptyElse) { + elseBeforeBody = &ifOp.getElseRegion().front(); + elseAfterBody = &ifOp.getElseRegion().back(); + rewriter.inlineRegionBefore(ifOp.getElseRegion(), thenAfterBody); + } else { + elseBeforeBody = elseAfterBody = continueBlock; + } rewriter.setInsertionPointToEnd(currentBlock); auto trunc = rewriter.create(loc, rewriter.getI1Type(), @@ -321,13 +327,15 @@ class CIRIfLowering : public mlir::OpConversionPattern { rewriter.create(loc, trunc.getRes(), thenBeforeBody, elseBeforeBody); - rewriter.setInsertionPointToEnd(elseAfterBody); - if (auto elseYieldOp = - dyn_cast(elseAfterBody->getTerminator())) { - rewriter.replaceOpWithNewOp( - elseYieldOp, elseYieldOp.getArgs(), continueBlock); - } else if (!dyn_cast(elseAfterBody->getTerminator())) { - llvm_unreachable("what are we terminating with?"); + if (!emptyElse) { + rewriter.setInsertionPointToEnd(elseAfterBody); + if (auto elseYieldOp = + dyn_cast(elseAfterBody->getTerminator())) { + rewriter.replaceOpWithNewOp( + elseYieldOp, elseYieldOp.getArgs(), continueBlock); + } else if (!dyn_cast(elseAfterBody->getTerminator())) { + llvm_unreachable("what are we terminating with?"); + } } rewriter.replaceOp(ifOp, continueBlock->getArguments()); diff --git a/clang/test/CIR/Lowering/if.cir b/clang/test/CIR/Lowering/if.cir index c7ed945d0892..f70460347a5e 100644 --- a/clang/test/CIR/Lowering/if.cir +++ b/clang/test/CIR/Lowering/if.cir @@ -14,10 +14,8 @@ module { } cir.return %arg0 : !s32i } -} -// MLIR: module { -// MLIR-NEXT: llvm.func @foo(%arg0: i32) -> i32 { +// MLIR: llvm.func @foo(%arg0: i32) -> i32 { // MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 // MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 // MLIR-NEXT: %2 = llvm.zext %1 : i1 to i8 @@ -32,7 +30,6 @@ module { // MLIR-NEXT: ^bb3: // no predecessors // MLIR-NEXT: llvm.return %arg0 : i32 // MLIR-NEXT: } -// MLIR-NEXT: } // LLVM: define i32 @foo(i32 %0) { // LLVM-NEXT: %2 = icmp ne i32 %0, 0 @@ -49,3 +46,26 @@ module { // LLVM-NEXT: 7: // LLVM-NEXT: ret i32 %0 // LLVM-NEXT: } + + cir.func @onlyIf(%arg0: !s32i) -> !s32i { + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + cir.if %4 { + %5 = cir.const(#cir.int<1> : !s32i) : !s32i + cir.return %5 : !s32i + } + cir.return %arg0 : !s32i + } + + // MLIR: llvm.func @onlyIf(%arg0: i32) -> i32 { + // MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 + // MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 + // MLIR-NEXT: %2 = llvm.zext %1 : i1 to i8 + // MLIR-NEXT: %3 = llvm.trunc %2 : i8 to i1 + // MLIR-NEXT: llvm.cond_br %3, ^bb1, ^bb2 + // MLIR-NEXT: ^bb1: // pred: ^bb0 + // MLIR-NEXT: %4 = llvm.mlir.constant(1 : i32) : i32 + // MLIR-NEXT: llvm.return %4 : i32 + // MLIR-NEXT: ^bb2: // pred: ^bb0 + // MLIR-NEXT: llvm.return %arg0 : i32 + // MLIR-NEXT: } +}