Skip to content

Commit

Permalink
[CIR][Lowering][Bugfix] Refactor for-loop lowering
Browse files Browse the repository at this point in the history
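This refactor merges the lowering logic for all loop kinds (for, while,
and do-while) into a single function, rewriteLoop. In doing so, it fixes
the for-loop lowering, which previously branched from the condition to
the step region and only then to the body; the body now runs before the
step.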

It also removes unnecessary LIT tests that validate the lowering from the
LLVM dialect to LLVM IR, as that functionality is outside CIR's scope.

Fixes llvm#153

ghstack-source-id: ebaab859057a6d81f1978fd88701c28402712562
Pull Request resolved: llvm/clangir#156
  • Loading branch information
sitio-couto authored and lanza committed Oct 12, 2024
1 parent 702856a commit af22575
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 209 deletions.
110 changes: 22 additions & 88 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
return mlir::success();
}

mlir::LogicalResult
rewriteWhileLoop(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter,
mlir::cir::LoopOpKind kind) const {
mlir::LogicalResult rewriteLoop(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter,
mlir::cir::LoopOpKind kind) const {
auto *currentBlock = rewriter.getInsertionBlock();
auto *continueBlock =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
Expand All @@ -150,16 +149,24 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
if (fetchCondRegionYields(condRegion, yieldToBody, yieldToCont).failed())
return loopOp.emitError("failed to fetch yields in cond region");

// Fetch required info from the condition region.
// Fetch required info from the body region.
auto &bodyRegion = loopOp.getBody();
auto &bodyFrontBlock = bodyRegion.front();
auto bodyYield =
dyn_cast<mlir::cir::YieldOp>(bodyRegion.back().getTerminator());
assert(bodyYield && "unstructured while loops are NYI");

// Fetch required info from the step region.
auto &stepRegion = loopOp.getStep();
auto &stepFrontBlock = stepRegion.front();
auto stepYield =
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());

// Move loop op region contents to current CFG.
rewriter.inlineRegionBefore(condRegion, continueBlock);
rewriter.inlineRegionBefore(bodyRegion, continueBlock);
if (kind == LoopKind::For) // Ignore step if not a for-loop.
rewriter.inlineRegionBefore(stepRegion, continueBlock);

// Set loop entry point to condition or to body in do-while cases.
rewriter.setInsertionPointToEnd(currentBlock);
Expand All @@ -174,9 +181,16 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
rewriter.setInsertionPoint(yieldToBody);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToBody, &bodyFrontBlock);

// Branch from body to condition.
// Branch from body to condition or to step on for-loop cases.
rewriter.setInsertionPoint(bodyYield);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &condFrontBlock);
auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &bodyExit);

// Is a for loop: branch from step to condition.
if (kind == LoopKind::For) {
rewriter.setInsertionPoint(stepYield);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(stepYield, &condFrontBlock);
}

// Remove the loop op.
rewriter.eraseOp(loopOp);
Expand All @@ -188,91 +202,11 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
switch (loopOp.getKind()) {
case LoopKind::For:
break;
case LoopKind::While:
case LoopKind::DoWhile:
return rewriteWhileLoop(loopOp, adaptor, rewriter, loopOp.getKind());
return rewriteLoop(loopOp, adaptor, rewriter, loopOp.getKind());
}

auto loc = loopOp.getLoc();

auto *currentBlock = rewriter.getInsertionBlock();
auto *remainingOpsBlock =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
mlir::Block *continueBlock;
if (loopOp->getResults().size() == 0)
continueBlock = remainingOpsBlock;
else
llvm_unreachable("NYI");

auto &condRegion = loopOp.getCond();
auto &condFrontBlock = condRegion.front();

auto &stepRegion = loopOp.getStep();
auto &stepFrontBlock = stepRegion.front();
auto &stepBackBlock = stepRegion.back();

auto &bodyRegion = loopOp.getBody();
auto &bodyFrontBlock = bodyRegion.front();
auto &bodyBackBlock = bodyRegion.back();

bool rewroteContinue = false;
bool rewroteBreak = false;

for (auto &bb : condRegion) {
if (rewroteContinue && rewroteBreak)
break;

if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(bb.getTerminator())) {
rewriter.setInsertionPointToEnd(yieldOp->getBlock());
if (yieldOp.getKind().has_value()) {
switch (yieldOp.getKind().value()) {
case mlir::cir::YieldOpKind::Break:
case mlir::cir::YieldOpKind::Fallthrough:
case mlir::cir::YieldOpKind::NoSuspend:
llvm_unreachable("None of these should be present");
case mlir::cir::YieldOpKind::Continue:;
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
yieldOp, yieldOp.getArgs(), &stepFrontBlock);
rewroteContinue = true;
}
} else {
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
yieldOp, yieldOp.getArgs(), continueBlock);
rewroteBreak = true;
}
}
}

rewriter.inlineRegionBefore(condRegion, continueBlock);

rewriter.inlineRegionBefore(stepRegion, continueBlock);

if (auto stepYieldOp =
dyn_cast<mlir::cir::YieldOp>(stepBackBlock.getTerminator())) {
rewriter.setInsertionPointToEnd(stepYieldOp->getBlock());
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
stepYieldOp, stepYieldOp.getArgs(), &bodyFrontBlock);
} else {
llvm_unreachable("What are we terminating with?");
}

rewriter.inlineRegionBefore(bodyRegion, continueBlock);

if (auto bodyYieldOp =
dyn_cast<mlir::cir::YieldOp>(bodyBackBlock.getTerminator())) {
rewriter.setInsertionPointToEnd(bodyYieldOp->getBlock());
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
bodyYieldOp, bodyYieldOp.getArgs(), &condFrontBlock);
} else {
llvm_unreachable("What are we terminating with?");
}

rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<mlir::cir::BrOp>(loc, mlir::ValueRange(), &condFrontBlock);

rewriter.replaceOp(loopOp, continueBlock->getArguments());

return mlir::success();
}
};
Expand Down
100 changes: 18 additions & 82 deletions clang/test/CIR/Lowering/dot.cir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: cir-tool %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR
// RUN: cir-tool %s -cir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
// RUN: cir-tool %s -cir-to-llvm -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s -check-prefix=MLIR

!s32i = !cir.int<s, 32>
module {
Expand Down Expand Up @@ -95,24 +95,24 @@ module {
// MLIR-NEXT: ^bb4: // pred: ^bb2
// MLIR-NEXT: llvm.br ^bb7
// MLIR-NEXT: ^bb5: // pred: ^bb3
// MLIR-NEXT: %22 = llvm.load %12 : !llvm.ptr
// MLIR-NEXT: %23 = llvm.mlir.constant(1 : i32) : i32
// MLIR-NEXT: %24 = llvm.add %22, %23 : i32
// MLIR-NEXT: llvm.store %24, %12 : i32, !llvm.ptr
// MLIR-NEXT: %22 = llvm.load %1 : !llvm.ptr
// MLIR-NEXT: %23 = llvm.load %12 : !llvm.ptr
// MLIR-NEXT: %24 = llvm.getelementptr %22[%23] : (!llvm.ptr, i32) -> !llvm.ptr
// MLIR-NEXT: %25 = llvm.load %24 : !llvm.ptr
// MLIR-NEXT: %26 = llvm.load %3 : !llvm.ptr
// MLIR-NEXT: %27 = llvm.load %12 : !llvm.ptr
// MLIR-NEXT: %28 = llvm.getelementptr %26[%27] : (!llvm.ptr, i32) -> !llvm.ptr
// MLIR-NEXT: %29 = llvm.load %28 : !llvm.ptr
// MLIR-NEXT: %30 = llvm.fmul %25, %29 : f64
// MLIR-NEXT: %31 = llvm.load %9 : !llvm.ptr
// MLIR-NEXT: %32 = llvm.fadd %31, %30 : f64
// MLIR-NEXT: llvm.store %32, %9 : f64, !llvm.ptr
// MLIR-NEXT: llvm.br ^bb6
// MLIR-NEXT: ^bb6: // pred: ^bb5
// MLIR-NEXT: %25 = llvm.load %1 : !llvm.ptr
// MLIR-NEXT: %26 = llvm.load %12 : !llvm.ptr
// MLIR-NEXT: %27 = llvm.getelementptr %25[%26] : (!llvm.ptr, i32) -> !llvm.ptr
// MLIR-NEXT: %28 = llvm.load %27 : !llvm.ptr
// MLIR-NEXT: %29 = llvm.load %3 : !llvm.ptr
// MLIR-NEXT: %30 = llvm.load %12 : !llvm.ptr
// MLIR-NEXT: %31 = llvm.getelementptr %29[%30] : (!llvm.ptr, i32) -> !llvm.ptr
// MLIR-NEXT: %32 = llvm.load %31 : !llvm.ptr
// MLIR-NEXT: %33 = llvm.fmul %28, %32 : f64
// MLIR-NEXT: %34 = llvm.load %9 : !llvm.ptr
// MLIR-NEXT: %35 = llvm.fadd %34, %33 : f64
// MLIR-NEXT: llvm.store %35, %9 : f64, !llvm.ptr
// MLIR-NEXT: %33 = llvm.load %12 : !llvm.ptr
// MLIR-NEXT: %34 = llvm.mlir.constant(1 : i32) : i32
// MLIR-NEXT: %35 = llvm.add %33, %34 : i32
// MLIR-NEXT: llvm.store %35, %12 : i32, !llvm.ptr
// MLIR-NEXT: llvm.br ^bb2
// MLIR-NEXT: ^bb7: // pred: ^bb4
// MLIR-NEXT: llvm.br ^bb8
Expand All @@ -123,67 +123,3 @@ module {
// MLIR-NEXT: llvm.return %37 : f64
// MLIR-NEXT: }
// MLIR-NEXT: }

// LLVM: define double @dot(ptr %0, ptr %1, i32 %2) {
// LLVM-NEXT: %4 = alloca ptr, i64 1, align 8
// LLVM-NEXT: %5 = alloca ptr, i64 1, align 8
// LLVM-NEXT: %6 = alloca i32, i64 1, align 4
// LLVM-NEXT: %7 = alloca double, i64 1, align 8
// LLVM-NEXT: %8 = alloca double, i64 1, align 8
// LLVM-NEXT: store ptr %0, ptr %4, align 8
// LLVM-NEXT: store ptr %1, ptr %5, align 8
// LLVM-NEXT: store i32 %2, ptr %6, align 4
// LLVM-NEXT: store double 0.000000e+00, ptr %8, align 8
// LLVM-NEXT: br label %9
// LLVM-EMPTY:
// LLVM-NEXT: 9: ; preds = %3
// LLVM-NEXT: %10 = alloca i32, i64 1, align 4
// LLVM-NEXT: store i32 0, ptr %10, align 4
// LLVM-NEXT: br label %11
// LLVM-EMPTY:
// LLVM-NEXT: 11: ; preds = %24, %9
// LLVM-NEXT: %12 = load i32, ptr %10, align 4
// LLVM-NEXT: %13 = load i32, ptr %6, align 4
// LLVM-NEXT: %14 = icmp slt i32 %12, %13
// LLVM-NEXT: %15 = zext i1 %14 to i32
// LLVM-NEXT: %16 = icmp ne i32 %15, 0
// LLVM-NEXT: %17 = zext i1 %16 to i8
// LLVM-NEXT: %18 = trunc i8 %17 to i1
// LLVM-NEXT: br i1 %18, label %19, label %20
// LLVM-EMPTY:
// LLVM-NEXT: 19: ; preds = %11
// LLVM-NEXT: br label %21
// LLVM-EMPTY:
// LLVM-NEXT: 20: ; preds = %11
// LLVM-NEXT: br label %36
// LLVM-EMPTY:
// LLVM-NEXT: 21: ; preds = %19
// LLVM-NEXT: %22 = load i32, ptr %10, align 4
// LLVM-NEXT: %23 = add i32 %22, 1
// LLVM-NEXT: store i32 %23, ptr %10, align 4
// LLVM-NEXT: br label %24
// LLVM-EMPTY:
// LLVM-NEXT: 24: ; preds = %21
// LLVM-NEXT: %25 = load ptr, ptr %4, align 8
// LLVM-NEXT: %26 = load i32, ptr %10, align 4
// LLVM-NEXT: %27 = getelementptr double, ptr %25, i32 %26
// LLVM-NEXT: %28 = load double, ptr %27, align 8
// LLVM-NEXT: %29 = load ptr, ptr %5, align 8
// LLVM-NEXT: %30 = load i32, ptr %10, align 4
// LLVM-NEXT: %31 = getelementptr double, ptr %29, i32 %30
// LLVM-NEXT: %32 = load double, ptr %31, align 8
// LLVM-NEXT: %33 = fmul double %28, %32
// LLVM-NEXT: %34 = load double, ptr %8, align 8
// LLVM-NEXT: %35 = fadd double %34, %33
// LLVM-NEXT: store double %35, ptr %8, align 8
// LLVM-NEXT: br label %11
// LLVM-EMPTY:
// LLVM-NEXT: 36: ; preds = %20
// LLVM-NEXT: br label %37
// LLVM-EMPTY:
// LLVM-NEXT: 37: ; preds = %36
// LLVM-NEXT: %38 = load double, ptr %8, align 8
// LLVM-NEXT: store double %38, ptr %7, align 8
// LLVM-NEXT: %39 = load double, ptr %7, align 8
// LLVM-NEXT: ret double %39
// LLVM-NEXT: }
49 changes: 10 additions & 39 deletions clang/test/CIR/Lowering/loop.cir
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// RUN: cir-tool %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR
// RUN: cir-tool %s -cir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
// RUN: cir-tool %s -cir-to-llvm -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s -check-prefix=MLIR

!s32i = !cir.int<s, 32>
module {
cir.func @foo() {
cir.func @testFor() {
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["i", init] {alignment = 4 : i64}
%1 = cir.const(#cir.int<0> : !s32i) : !s32i
cir.store %1, %0 : !s32i, cir.ptr <!s32i>
Expand All @@ -29,12 +29,13 @@ module {
}

// MLIR: module {
// MLIR-NEXT: llvm.func @foo() {
// MLIR-NEXT: llvm.func @testFor() {
// MLIR-NEXT: %0 = llvm.mlir.constant(1 : index) : i64
// MLIR-NEXT: %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr
// MLIR-NEXT: %2 = llvm.mlir.constant(0 : i32) : i32
// MLIR-NEXT: llvm.store %2, %1 : i32, !llvm.ptr
// MLIR-NEXT: llvm.br ^bb1
// ============= Condition block =============
// MLIR-NEXT: ^bb1: // 2 preds: ^bb0, ^bb5
// MLIR-NEXT: %3 = llvm.load %1 : !llvm.ptr
// MLIR-NEXT: %4 = llvm.mlir.constant(10 : i32) : i32
Expand All @@ -49,51 +50,21 @@ module {
// MLIR-NEXT: llvm.br ^bb4
// MLIR-NEXT: ^bb3: // pred: ^bb1
// MLIR-NEXT: llvm.br ^bb6
// ============= Body block =============
// MLIR-NEXT: ^bb4: // pred: ^bb2
// MLIR-NEXT: llvm.br ^bb5
// ============= Step block =============
// MLIR-NEXT: ^bb5: // pred: ^bb4
// MLIR-NEXT: %11 = llvm.load %1 : !llvm.ptr
// MLIR-NEXT: %12 = llvm.mlir.constant(1 : i32) : i32
// MLIR-NEXT: %13 = llvm.add %11, %12 : i32
// MLIR-NEXT: llvm.store %13, %1 : i32, !llvm.ptr
// MLIR-NEXT: llvm.br ^bb5
// MLIR-NEXT: ^bb5: // pred: ^bb4
// MLIR-NEXT: llvm.br ^bb1
// ============= Exit block =============
// MLIR-NEXT: ^bb6: // pred: ^bb3
// MLIR-NEXT: llvm.return
// MLIR-NEXT: }

// LLVM: define void @foo() {
// LLVM-NEXT: %1 = alloca i32, i64 1, align 4
// LLVM-NEXT: store i32 0, ptr %1, align 4
// LLVM-NEXT: br label %2
// LLVM-EMPTY:
// LLVM-NEXT: 2:
// LLVM-NEXT: %3 = load i32, ptr %1, align 4
// LLVM-NEXT: %4 = icmp slt i32 %3, 10
// LLVM-NEXT: %5 = zext i1 %4 to i32
// LLVM-NEXT: %6 = icmp ne i32 %5, 0
// LLVM-NEXT: %7 = zext i1 %6 to i8
// LLVM-NEXT: %8 = trunc i8 %7 to i1
// LLVM-NEXT: br i1 %8, label %9, label %10
// LLVM-EMPTY:
// LLVM-NEXT: 9:
// LLVM-NEXT: br label %11
// LLVM-EMPTY:
// LLVM-NEXT: 10:
// LLVM-NEXT: br label %15
// LLVM-EMPTY:
// LLVM-NEXT: 11:
// LLVM-NEXT: %12 = load i32, ptr %1, align 4
// LLVM-NEXT: %13 = add i32 %12, 1
// LLVM-NEXT: store i32 %13, ptr %1, align 4
// LLVM-NEXT: br label %14
// LLVM-EMPTY:
// LLVM-NEXT: 14:
// LLVM-NEXT: br label %2
// LLVM-EMPTY:
// LLVM-NEXT: 15:
// LLVM-NEXT: ret void
// LLVM-NEXT: }

// Test while cir.loop operation lowering.
cir.func @testWhile(%arg0: !s32i) {
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["i", init] {alignment = 4 : i64}
Expand Down

0 comments on commit af22575

Please sign in to comment.