Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][CF] Split cf-to-llvm from func-to-llvm #120580

Merged
merged 1 commit into from
Dec 20, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 19, 2024

Do not run cf-to-llvm as part of func-to-llvm. This commit fixes #70982.

This commit changes the way how func.func ops are lowered to LLVM. Previously, the signature of the entire region (i.e., entry block and all other blocks in the func.func op) was converted as part of the func.func lowering pattern.

Now, only the entry block is converted. The remaining block signatures are converted together with cf.br and cf.cond_br as part of cf-to-llvm. All unstructured control flow is not converted as part of a single pass (cf-to-llvm). func-to-llvm no longer deals with unstructured control flow.

Also add more test cases for control flow dialect ops.

Note: This PR is in preparation of #120431, which adds an additional GPU-specific lowering for cf.assert. This was a problem because cf.assert used to be converted as part of func-to-llvm.

Note for LLVM integration: If you see failures, add -convert-cf-to-llvm to your pass pipeline.

@llvmbot
Copy link
Member

llvmbot commented Dec 19, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-complex
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

Do not run cf-to-llvm as part of func-to-llvm`. This commit fixes #70982.

This commit changes the way how func.func ops are lowered to LLVM. Previously, the signature of the entire region (i.e., entry block and all other blocks in the func.func op) was converted as part of the func.func lowering pattern.

Now, only the entry block is converted. The remaining block signatures are converted together with cf.br and cf.cond_br as part of cf-to-llvm. All unstructured control flow is not converted as part of a single pass (cf-to-llvm). func-to-llvm no longer deals with unstructured control flow.


Patch is 37.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120580.diff

21 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (-4)
  • (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+79-66)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+5-9)
  • (added) mlir/test/Conversion/ControlFlowToLLVM/branch.mlir (+69)
  • (removed) mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir (-42)
  • (added) mlir/test/Conversion/ControlFlowToLLVM/switch.mlir (+66)
  • (modified) mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir (+1-1)
  • (modified) mlir/test/Conversion/FuncToLLVM/func-memref.mlir (+2-2)
  • (modified) mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir (+20-2)
  • (modified) mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp (+3)
  • (modified) mlir/test/mlir-cpu-runner/async-error.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/async-group.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/async-value.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/async.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/bare-ptr-call-conv.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/copy.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/memref-reinterpret-cast.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/memref-reshape.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/unranked-memref.mlir (+1-1)
  • (modified) mlir/test/mlir-cpu-runner/utils.mlir (+4-4)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8835e0a9099fdd..58ee87cf820396 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -460,10 +460,6 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
     1 value is returned, packed into an LLVM IR struct type. Function calls and
     returns are updated accordingly. Block argument types are updated to use
     LLVM IR types.
-
-    Note that until https://github.com/llvm/llvm-project/issues/70982 is resolved,
-    this pass includes patterns that lower `arith` and `cf` to LLVM. This is legacy
-    code due to when they were all converted in the same pass.
   }];
   let dependentDialects = ["LLVM::LLVMDialect"];
   let options = [
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index e5c735e10703a7..a79d27fecf0d25 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -94,60 +94,54 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   bool abortOnFailedAssert = true;
 };
 
-/// The cf->LLVM lowerings for branching ops require that the blocks they jump
-/// to first have updated types which should be handled by a pattern operating
-/// on the parent op.
-static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
-                                          ValueRange operands,
-                                          ValueRange blockArgs, Location loc,
-                                          llvm::StringRef messagePrefix) {
-  for (const auto &idxAndTypes :
-       llvm::enumerate(llvm::zip(blockArgs, operands))) {
-    int64_t i = idxAndTypes.index();
-    Value argValue =
-        rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
-    Type operandType = std::get<1>(idxAndTypes.value()).getType();
-    // In the case of an invalid jump, the block argument will have been
-    // remapped to an UnrealizedConversionCast. In the case of a valid jump,
-    // there might still be a no-op conversion cast with both types being equal.
-    // Consider both of these details to see if the jump would be invalid.
-    if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
-            argValue.getDefiningOp())) {
-      if (op.getOperandTypes().front() != operandType) {
-        return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
-          diag << messagePrefix;
-          diag << "mismatched types from operand # " << i << " ";
-          diag << operandType;
-          diag << " not compatible with destination block argument type ";
-          diag << op.getOperandTypes().front();
-          diag << " which should be converted with the parent op.";
-        });
-      }
-    }
-  }
-  return success();
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+                                            const TypeConverter *converter,
+                                            Operation *branchOp, Block *block,
+                                            TypeRange expectedTypes) {
+  assert(converter && "expected non-null type converter");
+  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+  // There is nothing to do if the types already match.
+  if (block->getArgumentTypes() == expectedTypes)
+    return block;
+
+  // Compute the new block argument types and convert the block.
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter->convertBlockSignature(block);
+  if (!conversion)
+    return rewriter.notifyMatchFailure(branchOp,
+                                       "could not compute block signature");
+  if (expectedTypes != conversion->getConvertedTypes())
+    return rewriter.notifyMatchFailure(
+        branchOp,
+        "mismatch between adaptor operand types and computed block signature");
+  return rewriter.applySignatureConversion(block, *conversion, converter);
 }
 
-/// Ensure that all block types were updated and then create an LLVM::BrOp
+/// Convert the destination block signature (if necessary) and lower the branch
+/// op to llvm.br.
 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
-                                    op.getSuccessor()->getArguments(),
-                                    op.getLoc(),
-                                    /*messagePrefix=*/"")))
+    FailureOr<Block *> convertedBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+                          TypeRange(adaptor.getOperands()));
+    if (failed(convertedBlock))
       return failure();
-
-    rewriter.replaceOpWithNewOp<LLVM::BrOp>(
-        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
+    rewriter.replaceOpWithNewOp<LLVM::BrOp>(op, adaptor.getOperands(),
+                                            *convertedBlock);
     return success();
   }
 };
 
-/// Ensure that all block types were updated and then create an LLVM::CondBrOp
+/// Convert the destination block signatures (if necessary) and lower the
+/// branch op to llvm.cond_br.
 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
 
@@ -155,45 +149,56 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
   matchAndRewrite(cf::CondBranchOp op,
                   typename cf::CondBranchOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
-                                    op.getFalseDest()->getArguments(),
-                                    op.getLoc(), "in false case branch ")))
+    FailureOr<Block *> convertedTrueBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+                          TypeRange(adaptor.getTrueDestOperands()));
+    if (failed(convertedTrueBlock))
       return failure();
-    if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
-                                    op.getTrueDest()->getArguments(),
-                                    op.getLoc(), "in true case branch ")))
+    FailureOr<Block *> convertedFalseBlock =
+        getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+                          TypeRange(adaptor.getFalseDestOperands()));
+    if (failed(convertedFalseBlock))
       return failure();
-
     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
-        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
+        op, adaptor.getCondition(), *convertedTrueBlock,
+        adaptor.getTrueDestOperands(), *convertedFalseBlock,
+        adaptor.getFalseDestOperands());
     return success();
   }
 };
 
-/// Ensure that all block types were updated and then create an LLVM::SwitchOp
+/// Convert the destination block signatures (if necessary) and lower the
+/// switch op to llvm.switch.
 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
-                                    op.getDefaultDestination()->getArguments(),
-                                    op.getLoc(), "in switch default case ")))
+    // Get or convert default block.
+    FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+        rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+        TypeRange(adaptor.getDefaultOperands()));
+    if (failed(convertedDefaultBlock))
       return failure();
 
-    for (const auto &i : llvm::enumerate(
-             llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
-      if (failed(verifyMatchingValues(
-              rewriter, std::get<0>(i.value()),
-              std::get<1>(i.value())->getArguments(), op.getLoc(),
-              "in switch case " + std::to_string(i.index()) + " "))) {
+    // Get or convert all case blocks.
+    SmallVector<Block *> caseDestinations;
+    SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+    for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+      Block *b = it.value();
+      FailureOr<Block *> convertedBlock =
+          getConvertedBlock(rewriter, getTypeConverter(), op, b,
+                            TypeRange(caseOperands[it.index()]));
+      if (failed(convertedBlock))
         return failure();
-      }
+      caseDestinations.push_back(*convertedBlock);
     }
 
     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
-        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
+        op, adaptor.getFlag(), *convertedDefaultBlock,
+        adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+        caseDestinations, caseOperands);
     return success();
   }
 };
@@ -230,14 +235,22 @@ struct ConvertControlFlowToLLVM
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    LLVMConversionTarget target(getContext());
-    RewritePatternSet patterns(&getContext());
-
-    LowerToLLVMOptions options(&getContext());
+    MLIRContext *ctx = &getContext();
+    LLVMConversionTarget target(*ctx);
+    // This pass lowers only CF dialect ops, but it also modifies block
+    // signatures inside other ops. These ops should be treated as legal. They
+    // are lowered by other passes.
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return op->getDialect() !=
+             ctx->getLoadedDialect<cf::ControlFlowDialect>();
+    });
+
+    LowerToLLVMOptions options(ctx);
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
       options.overrideIndexBitwidth(indexBitwidth);
 
-    LLVMTypeConverter converter(&getContext(), options);
+    LLVMTypeConverter converter(ctx, options);
+    RewritePatternSet patterns(ctx);
     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 938d7cb9a20040..790e18d2fccebe 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -432,11 +432,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
 
   rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
                               newFuncOp.end());
-  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
-                                         &result))) {
-    return rewriter.notifyMatchFailure(funcOp,
-                                       "region types conversion failed");
-  }
+  // Convert just the entry block. The remaining unstructured control flow is
+  // converted by ControlFlowToLLVM.
+  if (!newFuncOp.getBody().empty())
+    rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
+                                      &converter);
 
   // Fix the type mismatch between the materialized `llvm.ptr` and the expected
   // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
@@ -785,10 +785,6 @@ struct ConvertFuncToLLVMPass
     RewritePatternSet patterns(&getContext());
     populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
 
-    // TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
-    // favor of their dedicated conversion passes.
-    cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
-
     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(m, target, std::move(patterns))))
       signalPassFailure();
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
new file mode 100644
index 00000000000000..9a0f2b77145440
--- /dev/null
+++ b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s -convert-cf-to-llvm -split-input-file | FileCheck %s
+
+// Unstructured control flow is converted, but the enclosing op is not
+// converted.
+
+// CHECK-LABEL: func.func @cf_br(
+//  CHECK-SAME:     %[[arg0:.*]]: index) -> index {
+//       CHECK:   %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : index to i64
+//       CHECK:   llvm.br ^[[bb1:.*]](%[[cast0]] : i64)
+//       CHECK: ^[[bb1]](%[[arg1:.*]]: i64):
+//       CHECK:   %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : i64 to index
+//       CHECK:   return %[[cast1]] : index
+//       CHECK: }
+func.func @cf_br(%arg0: index) -> index {
+  cf.br ^bb1(%arg0 : index)
+^bb1(%arg1: index):
+  return %arg1 : index
+}
+
+// -----
+
+// func.func and func.return types match. No unrealized_conversion_cast is
+// needed.
+
+// CHECK-LABEL: func.func @cf_br_type_match(
+//  CHECK-SAME:     %[[arg0:.*]]: i64) -> i64 {
+//       CHECK:   llvm.br ^[[bb1:.*]](%[[arg0:.*]] : i64)
+//       CHECK: ^[[bb1]](%[[arg1:.*]]: i64):
+//       CHECK:   return %[[arg1]] : i64
+//       CHECK: }
+func.func @cf_br_type_match(%arg0: i64) -> i64 {
+  cf.br ^bb1(%arg0 : i64)
+^bb1(%arg1: i64):
+  return %arg1 : i64
+}
+
+// -----
+
+// Test case for cf.cond_br.
+
+//   CHECK-LABEL: func.func @cf_cond_br
+// CHECK-COUNT-2:   unrealized_conversion_cast {{.*}} : index to i64
+//         CHECK:   llvm.cond_br %{{.*}}, ^{{.*}}(%{{.*}} : i64), ^{{.*}}(%{{.*}} : i64)
+//         CHECK: ^{{.*}}(%{{.*}}: i64):
+//         CHECK:   unrealized_conversion_cast {{.*}} : i64 to index
+//         CHECK: ^{{.*}}(%{{.*}}: i64):
+//         CHECK:   unrealized_conversion_cast {{.*}} : i64 to index
+func.func @cf_cond_br(%cond: i1, %a: index, %b: index) -> index {
+  cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index)
+^bb1(%arg1: index):
+  return %arg1 : index
+^bb2(%arg2: index):
+  return %arg2 : index
+}
+
+// -----
+
+// Unreachable block (and IR in general) is not converted during a dialect
+// conversion.
+
+// CHECK-LABEL: func.func @unreachable_block()
+//       CHECK:   return
+//       CHECK: ^[[bb1:.*]](%[[arg0:.*]]: index):
+//       CHECK:   cf.br ^[[bb1]](%[[arg0]] : index)
+func.func @unreachable_block() {
+  return
+^bb1(%arg0: index):
+  cf.br ^bb1(%arg0 : index)
+}
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir
deleted file mode 100644
index a2afa233a26e8d..00000000000000
--- a/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir
+++ /dev/null
@@ -1,42 +0,0 @@
-// RUN: mlir-opt %s -convert-cf-to-llvm | FileCheck %s
-
-func.func @name(%flag: i32, %pred: i1){
-    // Test cf.br lowering failure with type mismatch
-    // CHECK: cf.br
-    %c0 = arith.constant 0 : index
-    cf.br ^bb1(%c0 : index)
-
-  // Test cf.cond_br lowering failure with type mismatch in false_dest
-  // CHECK: cf.cond_br
-  ^bb1(%0: index):  // 2 preds: ^bb0, ^bb2
-    %c1 = arith.constant 1 : i1
-    %c2 = arith.constant 1 : index
-    cf.cond_br %pred, ^bb2(%c1: i1), ^bb3(%c2: index)
-
-  // Test cf.cond_br lowering failure with type mismatch in true_dest
-  // CHECK: cf.cond_br
-  ^bb2(%1: i1):
-    %c3 = arith.constant 1 : i1
-    %c4 = arith.constant 1 : index
-    cf.cond_br %pred, ^bb3(%c4: index), ^bb2(%c3: i1)
-
-  // Test cf.switch lowering failure with type mismatch in default case
-  // CHECK: cf.switch
-  ^bb3(%2: index):  // pred: ^bb1
-    %c5 = arith.constant 1 : i1
-    %c6 = arith.constant 1 : index
-    cf.switch %flag : i32, [
-      default: ^bb1(%c6 : index),
-      42: ^bb4(%c5 : i1)
-    ]
-
-  // Test cf.switch lowering failure with type mismatch in non-default case
-  // CHECK: cf.switch
-  ^bb4(%3: i1):  // pred: ^bb1
-    %c7 = arith.constant 1 : i1
-    %c8 = arith.constant 1 : index
-    cf.switch %flag : i32, [
-      default: ^bb2(%c7 : i1),
-      41: ^bb1(%c8 : index)
-    ]
-  }
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/switch.mlir b/mlir/test/Conversion/ControlFlowToLLVM/switch.mlir
new file mode 100644
index 00000000000000..0bf4b02e8e3d70
--- /dev/null
+++ b/mlir/test/Conversion/ControlFlowToLLVM/switch.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -convert-cf-to-llvm -split-input-file | FileCheck %s
+
+// Unstructured control flow is converted, but the enclosing op is not
+// converted.
+
+// CHECK-LABEL: func.func @single_case(
+//  CHECK-SAME:     %[[val:.*]]: i32, %[[idx:.*]]: index) -> index {
+//       CHECK:   %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[idx]] : index to i64
+//       CHECK:   llvm.switch %[[val]] : i32, ^[[bb1:.*]](%[[cast0]] : i64) [
+//       CHECK:   ]
+//       CHECK: ^[[bb1]](%[[arg0:.*]]: i64):
+//       CHECK:   %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to index
+//       CHECK:   return %[[cast1]] : index
+//       CHECK: }
+func.func @single_case(%val: i32, %idx: index) -> index {
+  cf.switch %val : i32, [
+    default: ^bb1(%idx : index)
+  ]
+^bb1(%arg0: index):
+  return %arg0 : index
+}
+
+// -----
+
+// func.func and func.return types match. No unrealized_conversion_cast is
+// needed.
+
+// CHECK-LABEL: func.func @single_case_type_match(
+//  CHECK-SAME:     %[[val:.*]]: i32, %[[i:.*]]: i64) -> i64 {
+//       CHECK:   llvm.switch %[[val]] : i32, ^[[bb1:.*]](%[[i]] : i64) [
+//       CHECK:   ]
+//       CHECK: ^[[bb1]](%[[arg0:.*]]: i64):
+//       CHECK:   return %[[arg0]] : i64
+//       CHECK: }
+func.func @single_case_type_match(%val: i32, %i: i64) -> i64 {
+  cf.switch %val : i32, [
+    default: ^bb1(%i : i64)
+  ]
+^bb1(%arg0: i64):
+  return %arg0 : i64
+}
+
+// -----
+
+//   CHECK-LABEL: func.func @multi_case
+// CHECK-COUNT-2:   unrealized_conversion_cast {{.*}} : index to i64
+//         CHECK:   llvm.switch %{{.*}} : i32, ^{{.*}}(%{{.*}} : i64) [
+//         CHECK:     12: ^{{.*}}(%{{.*}} : i64),
+//         CHECK:     13: ^{{.*}}(%{{.*}} : i64),
+//         CHECK:     14: ^{{.*}}(%{{.*}} : i64)
+//         CHECK:   ]
+func.func @multi_case(%val: i32, %idx1: index, %idx2: index, %i: i64) -> index {
+  cf.switch %val : i32, [
+    default: ^bb1(%idx1 : index),
+    12: ^bb2(%idx2 : index),
+    13: ^bb1(%idx1 : index),
+    14: ^bb3(%i : i64)
+  ]
+^bb1(%arg0: index):
+  return %arg0 : index
+^bb2(%arg1: index):
+  return %arg1 : index
+^bb3(%arg2: i64):
+  %cast = arith.index_cast %arg2 : i64 to index
+  return %cast : index
+}
diff --git a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir
index 755c4cf42689c2..ae1dc70d0686b2 100644
--- a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-func-to-llvm -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -convert-func-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts -split-input-file -verify-diagnostics %s | FileCheck %s
 
 //CHECK: llvm.func @second_order_arg(!llvm.ptr)
 func.func private @second_order_arg(%arg0 : () -> ())
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref.mlir
index d44a07bdcc9ab0..15a96543eb6b72 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -pass-pipeline="builtin.module(func...
[truncated]

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Comment on lines +97 to +103
/// Helper function for converting branch ops. This function converts the
/// signature of the given block. If the new block signature is different from
/// `expectedTypes`, returns "failure".
static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
const TypeConverter *converter,
Operation *branchOp, Block *block,
TypeRange expectedTypes) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(non-actionable side note): Feels like this could be part of dialect conversion in the future 🙂 Similar to remapValues but for successors

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is slightly different because replaceAllUsesWith for blocks is reflected immediately in the IR. (In contrast to value replacements, which are being kept track of in the ConversionValueMapping.)

The reason why getConvertedBlock is needed is because there could be multiple branch ops that jump to the same block. In that case, the block should be converted only once.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/arith_to_llvm branch 2 times, most recently from 78704d8 to 6e95065 Compare December 20, 2024 09:01
Base automatically changed from users/matthias-springer/arith_to_llvm to main December 20, 2024 09:14
@matthias-springer matthias-springer force-pushed the users/matthias-springer/split_cf_conversion branch from a86a580 to 9eed724 Compare December 20, 2024 11:06
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir:spirv labels Dec 20, 2024
Copy link

github-actions bot commented Dec 20, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/split_cf_conversion branch from 9eed724 to e55c383 Compare December 20, 2024 11:12
@matthias-springer matthias-springer force-pushed the users/matthias-springer/split_cf_conversion branch from e55c383 to bd51fa2 Compare December 20, 2024 12:31
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Dec 20, 2024
@matthias-springer matthias-springer merged commit eb6c419 into main Dec 20, 2024
9 of 10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/split_cf_conversion branch December 20, 2024 12:46
@llvm-ci
Copy link
Collaborator

llvm-ci commented Dec 20, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia running on mlir-nvidia while building flang,mlir at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/8068

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/GPU/CUDA/async.mlir' FAILED ********************
Exit Code: 2

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-kernel-outlining  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary="format=fatbin"  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -reconcile-unrealized-casts  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-cpu-runner    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_cuda_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_async_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_runner_utils.so    --entry-point-result=void -O0  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-kernel-outlining
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt '-pass-pipeline=builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary=format=fatbin
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-cpu-runner --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_cuda_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_async_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_runner_utils.so --entry-point-result=void -O0
# .---command stderr------------
# | loc("<stdin>":102:5): error: Dialect `cf' not found for custom op 'cf.assert' 
# | could not parse the input IR
# `-----------------------------
# error: command failed with exit status: 1
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# .---command stderr------------
# | FileCheck error: '<stdin>' is empty.
# | FileCheck command line:  /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# `-----------------------------
# error: command failed with exit status: 2

--

********************


matthias-springer added a commit that referenced this pull request Dec 20, 2024
This commit should have been part of #120580.
matthias-springer added a commit that referenced this pull request Dec 20, 2024
This commit should have been part of #120580.
matthias-springer added a commit that referenced this pull request Dec 20, 2024
This commit should have been part of #120580.
matthias-springer added a commit that referenced this pull request Dec 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir]: Remove arith/cf converion patterns from FuncToLLVM.cpp
4 participants