diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp index 52e9adb75ba0..29df0d5cea95 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp @@ -24,6 +24,7 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/LowerToMLIR.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" using namespace cir; using namespace llvm; @@ -570,15 +571,82 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern { } } + void optimizeOnCertainBreak(mlir::scf::WhileOp whileOp, + mlir::ConversionPatternRewriter &rewriter) const { + // Collect all BreakOp inside this while. + llvm::SmallVector breaks; + whileOp->walk([&](mlir::Operation *op) { + if (auto breakOp = dyn_cast(op)) + breaks.push_back(breakOp); + }); + if (breaks.empty()) + return; + auto *pp = whileOp->getParentOp(); + pp->dump(); + for (auto breakOp : breaks) { + // When there is another loop between this WhileOp and the BreakOp, + // we should change that loop instead. + if (breakOp->getParentOfType() != whileOp) + continue; + // Similar to the case of ContinueOp, when there is an `IfOp`, + // we need to take special care. + for (mlir::Operation *parent = breakOp->getParentOp(); parent != whileOp; + parent = parent->getParentOp()) { + if (auto ifOp = dyn_cast(parent)) + llvm_unreachable("NYI"); + } + // Operations after this BreakOp has to be removed. + for (mlir::Operation *runner = breakOp->getNextNode(); runner;) { + mlir::Operation *next = runner->getNextNode(); + runner->erase(); + runner = next; + } + + // Blocks after this BreakOp also has to be removed. + for (mlir::Block *block = breakOp->getBlock()->getNextNode(); block;) { + mlir::Block *next = block->getNextNode(); + block->erase(); + block = next; + } + + // We know this BreakOp isn't nested in any IfOp. + // Therefore, the loop is executed only once. + // We pull everything out of the loop. + auto &beforeOps = whileOp.getBeforeBody()->getOperations(); + for (mlir::Operation *op = &*beforeOps.begin(); op;) { + if (isa(op)) + break; + auto *next = op->getNextNode(); + op->moveBefore(whileOp); + op = next; + } + + auto &afterOps = whileOp.getAfterBody()->getOperations(); + for (mlir::Operation *op = &*afterOps.begin(); op;) { + if (isa(op)) + break; + auto *next = op->getNextNode(); + op->moveBefore(whileOp); + op = next; + } + } + + rewriter.eraseOp(whileOp); + pp->dump(); + } + public: using OpConversionPattern::OpConversionPattern; + /// This rewrite will do some optimizations at the same time. + /// Unreachable code and unnecessary loops will be eliminated. mlir::LogicalResult matchAndRewrite(cir::WhileOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { SCFWhileLoop loop(op, adaptor, &rewriter); auto whileOp = loop.transferToSCFWhileOp(); rewriteContinue(whileOp, rewriter); + optimizeOnCertainBreak(whileOp, rewriter); rewriter.eraseOp(op); return mlir::success(); } diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 5d2b4180571a..7668fc38ce5c 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -48,19 +48,17 @@ #include "mlir/Transforms/DialectConversion.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/IR/CIRTypes.h" +#include "clang/CIR/Interfaces/CIRLoopOpInterface.h" #include "clang/CIR/LowerToLLVM.h" #include "clang/CIR/LowerToMLIR.h" #include "clang/CIR/LoweringHelpers.h" #include "clang/CIR/Passes.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/ErrorHandling.h" -#include "clang/CIR/Interfaces/CIRLoopOpInterface.h" -#include "clang/CIR/LowerToLLVM.h" -#include "clang/CIR/Passes.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Value.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/TimeProfiler.h" using namespace cir; @@ -946,8 +944,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern { } else { // For scopes with results, use scf.execute_region SmallVector types; - if (mlir::failed( - getTypeConverter()->convertTypes(scopeOp->getResultTypes(), types))) + if (mlir::failed(getTypeConverter()->convertTypes( + scopeOp->getResultTypes(), types))) return mlir::failure(); auto exec = rewriter.create(scopeOp.getLoc(), types); @@ -1023,6 +1021,28 @@ class CIRYieldOpLowering : public mlir::OpConversionPattern { } }; +class CIRBreakOpLowering : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + mlir::LogicalResult + matchAndRewrite(cir::BreakOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto *parentOp = op->getParentOp(); + return llvm::TypeSwitch(parentOp) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands()); + return mlir::success(); + }) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands()); + return mlir::success(); + }) + .Default([](auto) { return mlir::failure(); }); + } +}; + class CIRIfOpLowering : public mlir::OpConversionPattern { public: using mlir::OpConversionPattern::OpConversionPattern; @@ -1519,24 +1539,23 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); - patterns - .add(converter, patterns.getContext()); + patterns.add< + CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering, + CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering, + CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering, + CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering, + CIRYieldOpLowering, CIRBreakOpLowering, CIRCosOpLowering, + CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering, + CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering, + CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering, + CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering, + CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering, + CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering, + CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering, + CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering, + CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering, + CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering, + CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext()); } static mlir::TypeConverter prepareTypeConverter() { @@ -1610,7 +1629,7 @@ void ConvertCIRToMLIRPass::runOnOperation() { mlir::ModuleOp theModule = getOperation(); auto converter = prepareTypeConverter(); - + mlir::RewritePatternSet patterns(&getContext()); populateCIRLoopToSCFConversionPatterns(patterns, converter); @@ -1628,10 +1647,11 @@ void ConvertCIRToMLIRPass::runOnOperation() { // cir dialect, for example the `cir.continue`. If we marked cir as illegal // here, then MLIR would think any remaining `cir.continue` indicates a // failure, which is not what we want. - - patterns.add(converter, context); - if (mlir::failed(mlir::applyPartialConversion(theModule, target, + patterns.add(converter, context); + + if (mlir::failed(mlir::applyPartialConversion(theModule, target, std::move(patterns)))) { signalPassFailure(); } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/for_with_continue.cpp b/clang/test/CIR/Lowering/ThroughMLIR/for-with-continue.cpp similarity index 100% rename from clang/test/CIR/Lowering/ThroughMLIR/for_with_continue.cpp rename to clang/test/CIR/Lowering/ThroughMLIR/for-with-continue.cpp diff --git a/clang/test/CIR/Lowering/ThroughMLIR/while-with-break.cpp b/clang/test/CIR/Lowering/ThroughMLIR/while-with-break.cpp new file mode 100644 index 000000000000..be1022961c39 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/while-with-break.cpp @@ -0,0 +1,25 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +void while_break() { + int i = 0; + while (i < 100) { + i++; + break; + i++; + } + // This should be compiled into the condition `i < 100` and a single `i++`, + // without the while-loop. + + // CHECK: memref.alloca_scope { + // CHECK: %[[IV:.+]] = memref.load %alloca[] + // CHECK: %[[HUNDRED:.+]] = arith.constant 100 + // CHECK: %[[_:.+]] = arith.cmpi slt, %[[IV]], %[[HUNDRED]] + // CHECK: memref.alloca_scope { + // CHECK: %[[IV2:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[INCR:.+]] = arith.addi %[[IV2]], %[[ONE]] + // CHECK: memref.store %[[INCR]], %alloca[] + // CHECK: } + // CHECK: } +}