From dcb514663c48d055e8c0ef94e358d1df6c4189f0 Mon Sep 17 00:00:00 2001 From: Vinicius Couto Espindola Date: Tue, 6 Jun 2023 15:55:14 -0300 Subject: [PATCH] [CIR][Lowering] Patch If without Else lowering Lowering if operations without an else block would crash during. Some steps in the IfOp lowering were updated to only be applied if the else block is not empty. [ghstack-poisoned] --- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 32 ++++++++++++------- clang/test/CIR/Lowering/if.cir | 28 +++++++++++++--- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 9250218918ca..a2739915ce5a 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -282,8 +282,8 @@ class CIRIfLowering : public mlir::OpConversionPattern { matchAndRewrite(mlir::cir::IfOp ifOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::OpBuilder::InsertionGuard guard(rewriter); - auto loc = ifOp.getLoc(); + auto emptyElse = ifOp.getElseRegion().empty(); auto *currentBlock = rewriter.getInsertionBlock(); auto *remainingOpsBlock = @@ -310,10 +310,16 @@ class CIRIfLowering : public mlir::OpConversionPattern { rewriter.setInsertionPointToEnd(continueBlock); - // Inline then region - auto *elseBeforeBody = &ifOp.getElseRegion().front(); - auto *elseAfterBody = &ifOp.getElseRegion().back(); - rewriter.inlineRegionBefore(ifOp.getElseRegion(), thenAfterBody); + // Has else region: inline it. + mlir::Block *elseBeforeBody = nullptr; + mlir::Block *elseAfterBody = nullptr; + if (!emptyElse) { + elseBeforeBody = &ifOp.getElseRegion().front(); + elseAfterBody = &ifOp.getElseRegion().back(); + rewriter.inlineRegionBefore(ifOp.getElseRegion(), thenAfterBody); + } else { + elseBeforeBody = elseAfterBody = continueBlock; + } rewriter.setInsertionPointToEnd(currentBlock); auto trunc = rewriter.create(loc, rewriter.getI1Type(), @@ -321,13 +327,15 @@ class CIRIfLowering : public mlir::OpConversionPattern { rewriter.create(loc, trunc.getRes(), thenBeforeBody, elseBeforeBody); - rewriter.setInsertionPointToEnd(elseAfterBody); - if (auto elseYieldOp = - dyn_cast(elseAfterBody->getTerminator())) { - rewriter.replaceOpWithNewOp( - elseYieldOp, elseYieldOp.getArgs(), continueBlock); - } else if (!dyn_cast(elseAfterBody->getTerminator())) { - llvm_unreachable("what are we terminating with?"); + if (!emptyElse) { + rewriter.setInsertionPointToEnd(elseAfterBody); + if (auto elseYieldOp = + dyn_cast(elseAfterBody->getTerminator())) { + rewriter.replaceOpWithNewOp( + elseYieldOp, elseYieldOp.getArgs(), continueBlock); + } else if (!dyn_cast(elseAfterBody->getTerminator())) { + llvm_unreachable("what are we terminating with?"); + } } rewriter.replaceOp(ifOp, continueBlock->getArguments()); diff --git a/clang/test/CIR/Lowering/if.cir b/clang/test/CIR/Lowering/if.cir index c7ed945d0892..f70460347a5e 100644 --- a/clang/test/CIR/Lowering/if.cir +++ b/clang/test/CIR/Lowering/if.cir @@ -14,10 +14,8 @@ module { } cir.return %arg0 : !s32i } -} -// MLIR: module { -// MLIR-NEXT: llvm.func @foo(%arg0: i32) -> i32 { +// MLIR: llvm.func @foo(%arg0: i32) -> i32 { // MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 // MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 // MLIR-NEXT: %2 = llvm.zext %1 : i1 to i8 @@ -32,7 +30,6 @@ module { // MLIR-NEXT: ^bb3: // no predecessors // MLIR-NEXT: llvm.return %arg0 : i32 // MLIR-NEXT: } -// MLIR-NEXT: } // LLVM: define i32 @foo(i32 %0) { // LLVM-NEXT: %2 = icmp ne i32 %0, 0 @@ -49,3 +46,26 @@ module { // LLVM-NEXT: 7: // LLVM-NEXT: ret i32 %0 // LLVM-NEXT: } + + cir.func @onlyIf(%arg0: !s32i) -> !s32i { + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + cir.if %4 { + %5 = cir.const(#cir.int<1> : !s32i) : !s32i + cir.return %5 : !s32i + } + cir.return %arg0 : !s32i + } + + // MLIR: llvm.func @onlyIf(%arg0: i32) -> i32 { + // MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 + // MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 + // MLIR-NEXT: %2 = llvm.zext %1 : i1 to i8 + // MLIR-NEXT: %3 = llvm.trunc %2 : i8 to i1 + // MLIR-NEXT: llvm.cond_br %3, ^bb1, ^bb2 + // MLIR-NEXT: ^bb1: // pred: ^bb0 + // MLIR-NEXT: %4 = llvm.mlir.constant(1 : i32) : i32 + // MLIR-NEXT: llvm.return %4 : i32 + // MLIR-NEXT: ^bb2: // pred: ^bb0 + // MLIR-NEXT: llvm.return %arg0 : i32 + // MLIR-NEXT: } +}