From 674b70dba23a21fe3f5bc02a97c33d9682867b1e Mon Sep 17 00:00:00 2001 From: Yue Huang Date: Sun, 25 May 2025 22:28:21 +0100 Subject: [PATCH] [CIR][ThroughMLIR] Lower uncanonicalized fors to whiles --- .../ThroughMLIR/LowerCIRLoopToSCF.cpp | 84 +++++++++++++------ .../CIR/Lowering/ThroughMLIR/for-reject-1.cpp | 23 ++++- .../CIR/Lowering/ThroughMLIR/for-reject-2.cpp | 24 +++++- 3 files changed, 97 insertions(+), 34 deletions(-) diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp index 69d58fea2703..8f90dd01f65e 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp @@ -13,20 +13,15 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/LowerToMLIR.h" -#include "clang/CIR/Passes.h" #include "llvm/ADT/TypeSwitch.h" using namespace cir; @@ -52,6 +47,7 @@ class SCFLoop { mlir::Value plusConstant(mlir::Value v, mlir::Location loc, int addend); void transferToSCFForOp(); + void transformToSCFWhileOp(); private: cir::ForOp forOp; @@ -209,21 +205,21 @@ cir::CmpOp SCFLoop::findCmpOp() { } } if (!cmpOp) - llvm_unreachable("Can't find loop CmpOp"); + return nullptr; auto type = cmpOp.getLhs().getType(); if (!mlir::isa(type)) - llvm_unreachable("Non-integer type IV is not supported"); + return nullptr; auto *lhsDefOp = cmpOp.getLhs().getDefiningOp(); if (!lhsDefOp) - llvm_unreachable("Can't find 
IV load"); + return nullptr; if (!isIVLoad(lhsDefOp, ivAddr)) - llvm_unreachable("cmpOp LHS is not IV"); + return nullptr; if (cmpOp.getKind() != cir::CmpOpKind::le && cmpOp.getKind() != cir::CmpOpKind::lt) - llvm_unreachable("Not support lowering other than le or lt comparison"); + return nullptr; return cmpOp; } @@ -253,30 +249,40 @@ mlir::Value SCFLoop::findIVInitValue() { void SCFLoop::analysis() { canonical = mlir::succeeded(findStepAndIV()); - if (!canonical) { - mlir::emitError(forOp.getLoc(), - "cannot handle non-constant step for induction variable"); + if (!canonical) return; - } cmpOp = findCmpOp(); - auto IVInit = findIVInitValue(); + if (!cmpOp) { + canonical = false; + return; + } + + auto ivInit = findIVInitValue(); + if (!ivInit) { + canonical = false; + return; + } + // The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare. // So we could get the value by getRemappedValue. - auto IVEndBound = rewriter->getRemappedValue(cmpOp.getRhs()); - // If the loop end bound is not loop invariant and can't be hoisted. - // The following assertion will be triggerred. - assert(IVEndBound && "can't find IV end boundary"); + auto ivEndBound = rewriter->getRemappedValue(cmpOp.getRhs()); + // If the loop end bound is not loop invariant and can't be hoisted, + // then this is not a canonical loop. 
+ if (!ivEndBound) { + canonical = false; + return; + } if (step > 0) { - lowerBound = IVInit; + lowerBound = ivInit; if (cmpOp.getKind() == cir::CmpOpKind::lt) - upperBound = IVEndBound; + upperBound = ivEndBound; else if (cmpOp.getKind() == cir::CmpOpKind::le) - upperBound = plusConstant(IVEndBound, cmpOp.getLoc(), 1); + upperBound = plusConstant(ivEndBound, cmpOp.getLoc(), 1); } - assert(lowerBound && "can't find loop lower bound"); - assert(upperBound && "can't find loop upper bound"); + if (!lowerBound || !upperBound) + canonical = false; } void SCFLoop::transferToSCFForOp() { @@ -309,6 +315,28 @@ void SCFLoop::transferToSCFForOp() { }); } +void SCFLoop::transformToSCFWhileOp() { + auto scfWhileOp = rewriter->create( + forOp->getLoc(), forOp->getResultTypes(), mlir::ValueRange()); + rewriter->createBlock(&scfWhileOp.getBefore()); + rewriter->createBlock(&scfWhileOp.getAfter()); + + rewriter->inlineBlockBefore(&forOp.getCond().front(), + scfWhileOp.getBeforeBody(), + scfWhileOp.getBeforeBody()->end()); + rewriter->inlineBlockBefore(&forOp.getBody().front(), + scfWhileOp.getAfterBody(), + scfWhileOp.getAfterBody()->end()); + // There will be a yield after the `for` body. + // We should delete it. 
+ auto yield = mlir::cast(scfWhileOp.getAfterBody()->back()); + rewriter->eraseOp(yield); + + rewriter->inlineBlockBefore(&forOp.getStep().front(), + scfWhileOp.getAfterBody(), + scfWhileOp.getAfterBody()->end()); +} + void SCFWhileLoop::transferToSCFWhileOp() { auto scfWhileOp = rewriter->create( whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands()); @@ -352,9 +380,11 @@ class CIRForOpLowering : public mlir::OpConversionPattern { mlir::ConversionPatternRewriter &rewriter) const override { SCFLoop loop(op, &rewriter); loop.analysis(); - if (!loop.isCanonical()) - return mlir::emitError(op.getLoc(), - "cannot handle non-canonicalized loop"); + if (!loop.isCanonical()) { + loop.transformToSCFWhileOp(); + rewriter.eraseOp(op); + return mlir::success(); + } loop.transferToSCFForOp(); rewriter.eraseOp(op); diff --git a/clang/test/CIR/Lowering/ThroughMLIR/for-reject-1.cpp b/clang/test/CIR/Lowering/ThroughMLIR/for-reject-1.cpp index 9efd71e979c1..60267bfbb953 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/for-reject-1.cpp +++ b/clang/test/CIR/Lowering/ThroughMLIR/for-reject-1.cpp @@ -1,9 +1,24 @@ -// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s -void f(); +void f() {} void reject() { for (int i = 0; i < 100; i++, f()); - // CHECK: cannot handle non-constant step for induction variable - // CHECK: cannot handle non-canonicalized loop + // CHECK: %[[ALLOCA:.+]] = memref.alloca + // CHECK: %[[ZERO:.+]] = arith.constant 0 + // CHECK: memref.store %[[ZERO]], %[[ALLOCA]] + // CHECK: %[[HUNDRED:.+]] = arith.constant 100 + // CHECK: scf.while : () -> () { + // CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]] + // CHECK: %[[TMP1:.+]] = arith.cmpi slt, %[[TMP]], %[[HUNDRED]] + // CHECK: scf.condition(%[[TMP1]]) + // CHECK: 
} do { + // CHECK: %[[TMP2:.+]] = memref.load %[[ALLOCA]] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]] + // CHECK: memref.store %[[TMP3]], %[[ALLOCA]] + // CHECK: func.call @_Z1fv() + // CHECK: scf.yield + // CHECK: } } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/for-reject-2.cpp b/clang/test/CIR/Lowering/ThroughMLIR/for-reject-2.cpp index bff2a8e74d12..c58d0675ccc6 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/for-reject-2.cpp +++ b/clang/test/CIR/Lowering/ThroughMLIR/for-reject-2.cpp @@ -1,7 +1,25 @@ -// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s void reject() { for (int i = 0; i < 100; i++, i++); - // CHECK: cannot handle non-constant step for induction variable - // CHECK: cannot handle non-canonicalized loop + // CHECK: %[[ALLOCA:.+]] = memref.alloca + // CHECK: %[[ZERO:.+]] = arith.constant 0 + // CHECK: memref.store %[[ZERO]], %[[ALLOCA]] + // CHECK: %[[HUNDRED:.+]] = arith.constant 100 + // CHECK: scf.while : () -> () { + // CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]] + // CHECK: %[[TMP2:.+]] = arith.cmpi slt, %[[TMP]], %[[HUNDRED]] + // CHECK: scf.condition(%[[TMP2]]) + // CHECK: } do { + // CHECK: %[[TMP3:.+]] = memref.load %[[ALLOCA]] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[TMP3]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %[[ALLOCA]] + // CHECK: %[[LOAD:.+]] = memref.load %[[ALLOCA]] + // CHECK: %[[ONE2:.+]] = arith.constant 1 + // CHECK: %[[ADD2:.+]] = arith.addi %[[LOAD]], %[[ONE2]] + // CHECK: memref.store %[[ADD2]], %[[ALLOCA]] + // CHECK: scf.yield + // CHECK: } }