Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Dialect conversion: Unify materialization of value replacements #108381

Merged
merged 3 commits into from
Sep 21, 2024

Conversation

matthias-springer
Copy link
Member

PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (legalizeConvertedArgumentTypes and legalizeConvertedOpResultTypes). This PR merges the two functions and moves the implementation directly into finalize.

This PR simplifies the code base and improves the efficiency a bit: previously, finalize iterated over ConversionPatternRewriterImpl::rewrites twice. Now, only one iteration is needed.

@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2024

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (legalizeConvertedArgumentTypes and legalizeConvertedOpResultTypes). This PR merges the two functions and moves the implementation directly into finalize.

This PR simplifies the code base and improves the efficiency a bit: previously, finalize iterated over ConversionPatternRewriterImpl::rewrites twice. Now, only one iteration is needed.


Full diff: https://github.com/llvm/llvm-project/pull/108381.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+42-92)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+2-2)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ed15b571f01883..0556b4ab833c30 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2336,17 +2336,6 @@ struct OperationConverter {
   /// remaining artifacts and complete the conversion.
   LogicalResult finalize(ConversionPatternRewriter &rewriter);
 
-  /// Legalize the types of converted block arguments.
-  LogicalResult
-  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
-                                 ConversionPatternRewriterImpl &rewriterImpl);
-
-  /// Legalize the types of converted op results.
-  LogicalResult legalizeConvertedOpResultTypes(
-      ConversionPatternRewriter &rewriter,
-      ConversionPatternRewriterImpl &rewriterImpl,
-      DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
   /// Dialect conversion configuration.
   ConversionConfig config;
 
@@ -2510,19 +2499,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   return success();
 }
 
-LogicalResult
-OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
-  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
-    return failure();
-  DenseMap<Value, SmallVector<Value>> inverseMapping =
-      rewriterImpl.mapping.getInverse();
-  if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
-                                            inverseMapping)))
-    return failure();
-  return success();
-}
-
 /// Finds a user of the given value, or of any other value that the given value
 /// replaced, that was not replaced in the conversion process.
 static Operation *findLiveUserOfReplaced(
@@ -2546,87 +2522,61 @@ static Operation *findLiveUserOfReplaced(
   return nullptr;
 }
 
-LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  // Process requested operation replacements.
-  for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
-    auto *opReplacement =
-        dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
-    if (!opReplacement)
-      continue;
-    Operation *op = opReplacement->getOperation();
-    for (OpResult result : op->getResults()) {
-      // If the type of this op result changed and the result is still live,
-      // we need to materialize a conversion.
-      if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+/// Helper function that returns the replaced values and the type converter if
+/// the given rewrite object is an "operation replacement" or a "block type
+/// conversion" (which corresponds to a "block replacement"). Otherwise, return
+/// an empty ValueRange and a null type converter pointer.
+static std::pair<ValueRange, const TypeConverter *>
+getReplacedValues(IRRewrite *rewrite) {
+  if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
+    return std::make_pair(opRewrite->getOperation()->getResults(),
+                          opRewrite->getConverter());
+  if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
+    return std::make_pair(blockRewrite->getOrigBlock()->getArguments(),
+                          blockRewrite->getConverter());
+  return std::make_pair(ValueRange(), nullptr);
+}
+
+LogicalResult
+OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  DenseMap<Value, SmallVector<Value>> inverseMapping =
+      rewriterImpl.mapping.getInverse();
+
+  // Process requested value replacements.
+  for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
+    ValueRange replacedValues;
+    const TypeConverter *converter;
+    std::tie(replacedValues, converter) =
+        getReplacedValues(rewriterImpl.rewrites[i].get());
+    for (Value originalValue : replacedValues) {
+      // If the type of this value changed and the value is still live, we need
+      // to materialize a conversion.
+      if (rewriterImpl.mapping.lookupOrNull(originalValue,
+                                            originalValue.getType()))
         continue;
       Operation *liveUser =
-          findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
+          findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
       if (!liveUser)
         continue;
 
-      // Legalize this result.
-      Value newValue = rewriterImpl.mapping.lookupOrNull(result);
+      // Legalize this value replacement.
+      Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
       assert(newValue && "replacement value not found");
       Value castValue = rewriterImpl.buildUnresolvedMaterialization(
-          MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
-          /*inputs=*/newValue, /*outputType=*/result.getType(),
-          opReplacement->getConverter());
-      rewriterImpl.mapping.map(result, castValue);
-      inverseMapping[castValue].push_back(result);
-      llvm::erase(inverseMapping[newValue], result);
+          MaterializationKind::Source, computeInsertPoint(newValue),
+          originalValue.getLoc(),
+          /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
+          converter);
+      rewriterImpl.mapping.map(originalValue, castValue);
+      inverseMapping[castValue].push_back(originalValue);
+      llvm::erase(inverseMapping[newValue], originalValue);
     }
   }
 
   return success();
 }
 
-LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl) {
-  // Functor used to check if all users of a value will be dead after
-  // conversion.
-  // TODO: This should probably query the inverse mapping, same as in
-  // `legalizeConvertedOpResultTypes`.
-  auto findLiveUser = [&](Value val) {
-    auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
-      return rewriterImpl.isOpIgnored(user);
-    });
-    return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
-  };
-  // Note: `rewrites` may be reallocated as the loop is running.
-  for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
-       ++i) {
-    auto &rewrite = rewriterImpl.rewrites[i];
-    if (auto *blockTypeConversionRewrite =
-            dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
-      // Process the remapping for each of the original arguments.
-      for (Value origArg :
-           blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
-        // If the type of this argument changed and the argument is still live,
-        // we need to materialize a conversion.
-        if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-          continue;
-        Operation *liveUser = findLiveUser(origArg);
-        if (!liveUser)
-          continue;
-
-        Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
-        assert(replacementValue && "replacement value not found");
-        Value repl = rewriterImpl.buildUnresolvedMaterialization(
-            MaterializationKind::Source, computeInsertPoint(replacementValue),
-            origArg.getLoc(), /*inputs=*/replacementValue,
-            /*outputType=*/origArg.getType(),
-            blockTypeConversionRewrite->getConverter());
-        rewriterImpl.mapping.map(origArg, repl);
-      }
-    }
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Reconcile Unrealized Casts
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index d8570bdaf4247f..25ec5d0159bd5d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -558,8 +558,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
 
 // CHECK-LABEL: func @deinterleave_scalar
 // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
-//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
-//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//   CHECK-DAG: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//   CHECK-DAG: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
 //   CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
 //   CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
 //       CHECK: return %[[CAST0]], %[[CAST1]]

@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_cache branch from a9c69d1 to 4cb4bcf Compare September 12, 2024 13:31
matthias-springer added a commit that referenced this pull request Sep 12, 2024
This commit is in preparation of #108381, which changes the insertion point source materializations during a block type conversion slightly.
matthias-springer added a commit that referenced this pull request Sep 12, 2024
This commit is in preparation of #108381, which changes the insertion
point source materializations during a block type conversion slightly.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_cache branch from 4cb4bcf to 066359e Compare September 12, 2024 13:36
@matthias-springer matthias-springer force-pushed the users/matthias-springer/bbarg_opresult_mat branch from 1f215ac to 0fd4cb8 Compare September 12, 2024 13:36
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_cache branch from 066359e to e724e44 Compare September 13, 2024 17:55
Base automatically changed from users/matthias-springer/mat_cache to main September 13, 2024 18:16
@matthias-springer matthias-springer force-pushed the users/matthias-springer/bbarg_opresult_mat branch from 0fd4cb8 to 22e4550 Compare September 13, 2024 18:17
matthias-springer and others added 2 commits September 21, 2024 09:52
… replacements

PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`). This PR merges the two functions and moves the implementation directly into `finalize`.

This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterates over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed.
Co-authored-by: Jakub Kuderski <jakub@nod-labs.com>
@matthias-springer matthias-springer force-pushed the users/matthias-springer/bbarg_opresult_mat branch from c3d664a to 0a06737 Compare September 21, 2024 07:53
Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 97aa8cc94d94e8f0adc85489f7832ba7c0a9b577 0a06737c45740e46f73d32e06bed12dc6d17cccc --extensions cpp -- mlir/lib/Transforms/Utils/DialectConversion.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7eabca4338..69036e947e 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2531,11 +2531,10 @@ static Operation *findLiveUserOfReplaced(
 static std::pair<ValueRange, const TypeConverter *>
 getReplacedValues(IRRewrite *rewrite) {
   if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
-    return {opRewrite->getOperation()->getResults(),
-                          opRewrite->getConverter()};
+    return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
   if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
     return {blockRewrite->getOrigBlock()->getArguments(),
-                          blockRewrite->getConverter()};
+            blockRewrite->getConverter()};
   return {};
 }
 

@matthias-springer matthias-springer merged commit 8527861 into main Sep 21, 2024
5 of 8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/bbarg_opresult_mat branch September 21, 2024 08:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:spirv mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants