-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect conversion: Unify materialization of value replacements #108381
[mlir][Transforms] Dialect conversion: Unify materialization of value replacements #108381
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesPR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical ( This PR simplifies the code base and improves the efficiency a bit: previously, Full diff: https://github.com/llvm/llvm-project/pull/108381.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ed15b571f01883..0556b4ab833c30 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2336,17 +2336,6 @@ struct OperationConverter {
/// remaining artifacts and complete the conversion.
LogicalResult finalize(ConversionPatternRewriter &rewriter);
- /// Legalize the types of converted block arguments.
- LogicalResult
- legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl);
-
- /// Legalize the types of converted op results.
- LogicalResult legalizeConvertedOpResultTypes(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
/// Dialect conversion configuration.
ConversionConfig config;
@@ -2510,19 +2499,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}
-LogicalResult
-OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
- ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
- return failure();
- DenseMap<Value, SmallVector<Value>> inverseMapping =
- rewriterImpl.mapping.getInverse();
- if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
- inverseMapping)))
- return failure();
- return success();
-}
-
/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
@@ -2546,87 +2522,61 @@ static Operation *findLiveUserOfReplaced(
return nullptr;
}
-LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- // Process requested operation replacements.
- for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
- auto *opReplacement =
- dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!opReplacement)
- continue;
- Operation *op = opReplacement->getOperation();
- for (OpResult result : op->getResults()) {
- // If the type of this op result changed and the result is still live,
- // we need to materialize a conversion.
- if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+/// Helper function that returns the replaced values and the type converter if
+/// the given rewrite object is an "operation replacement" or a "block type
+/// conversion" (which corresponds to a "block replacement"). Otherwise, return
+/// an empty ValueRange and a null type converter pointer.
+static std::pair<ValueRange, const TypeConverter *>
+getReplacedValues(IRRewrite *rewrite) {
+ if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
+ return std::make_pair(opRewrite->getOperation()->getResults(),
+ opRewrite->getConverter());
+ if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
+ return std::make_pair(blockRewrite->getOrigBlock()->getArguments(),
+ blockRewrite->getConverter());
+ return std::make_pair(ValueRange(), nullptr);
+}
+
+LogicalResult
+OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ DenseMap<Value, SmallVector<Value>> inverseMapping =
+ rewriterImpl.mapping.getInverse();
+
+ // Process requested value replacements.
+ for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
+ ValueRange replacedValues;
+ const TypeConverter *converter;
+ std::tie(replacedValues, converter) =
+ getReplacedValues(rewriterImpl.rewrites[i].get());
+ for (Value originalValue : replacedValues) {
+ // If the type of this value changed and the value is still live, we need
+ // to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(originalValue,
+ originalValue.getType()))
continue;
Operation *liveUser =
- findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
+ findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
if (!liveUser)
continue;
- // Legalize this result.
- Value newValue = rewriterImpl.mapping.lookupOrNull(result);
+ // Legalize this value replacement.
+ Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
- /*inputs=*/newValue, /*outputType=*/result.getType(),
- opReplacement->getConverter());
- rewriterImpl.mapping.map(result, castValue);
- inverseMapping[castValue].push_back(result);
- llvm::erase(inverseMapping[newValue], result);
+ MaterializationKind::Source, computeInsertPoint(newValue),
+ originalValue.getLoc(),
+ /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
+ converter);
+ rewriterImpl.mapping.map(originalValue, castValue);
+ inverseMapping[castValue].push_back(originalValue);
+ llvm::erase(inverseMapping[newValue], originalValue);
}
}
return success();
}
-LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl) {
- // Functor used to check if all users of a value will be dead after
- // conversion.
- // TODO: This should probably query the inverse mapping, same as in
- // `legalizeConvertedOpResultTypes`.
- auto findLiveUser = [&](Value val) {
- auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
- return rewriterImpl.isOpIgnored(user);
- });
- return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
- };
- // Note: `rewrites` may be reallocated as the loop is running.
- for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
- ++i) {
- auto &rewrite = rewriterImpl.rewrites[i];
- if (auto *blockTypeConversionRewrite =
- dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
- // Process the remapping for each of the original arguments.
- for (Value origArg :
- blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
- // If the type of this argument changed and the argument is still live,
- // we need to materialize a conversion.
- if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
- assert(replacementValue && "replacement value not found");
- Value repl = rewriterImpl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(replacementValue),
- origArg.getLoc(), /*inputs=*/replacementValue,
- /*outputType=*/origArg.getType(),
- blockTypeConversionRewrite->getConverter());
- rewriterImpl.mapping.map(origArg, repl);
- }
- }
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index d8570bdaf4247f..25ec5d0159bd5d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -558,8 +558,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
// CHECK-LABEL: func @deinterleave_scalar
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
-// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
-// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+// CHECK-DAG: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+// CHECK-DAG: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK: return %[[CAST0]], %[[CAST1]]
|
a9c69d1
to
4cb4bcf
Compare
This commit is in preparation of #108381, which changes the insertion point source materializations during a block type conversion slightly.
This commit is in preparation of #108381, which changes the insertion point source materializations during a block type conversion slightly.
4cb4bcf
to
066359e
Compare
1f215ac
to
0fd4cb8
Compare
066359e
to
e724e44
Compare
0fd4cb8
to
22e4550
Compare
… replacements PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`). This PR merges the two functions and moves the implementation directly into `finalize`. This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterates over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed.
Co-authored-by: Jakub Kuderski <jakub@nod-labs.com>
c3d664a
to
0a06737
Compare
You can test this locally with the following command:git-clang-format --diff 97aa8cc94d94e8f0adc85489f7832ba7c0a9b577 0a06737c45740e46f73d32e06bed12dc6d17cccc --extensions cpp -- mlir/lib/Transforms/Utils/DialectConversion.cpp View the diff from clang-format here.diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7eabca4338..69036e947e 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2531,11 +2531,10 @@ static Operation *findLiveUserOfReplaced(
static std::pair<ValueRange, const TypeConverter *>
getReplacedValues(IRRewrite *rewrite) {
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
- return {opRewrite->getOperation()->getResults(),
- opRewrite->getConverter()};
+ return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
return {blockRewrite->getOrigBlock()->getArguments(),
- blockRewrite->getConverter()};
+ blockRewrite->getConverter()};
return {};
}
|
PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (
legalizeConvertedArgumentTypes
andlegalizeConvertedOpResultTypes
). This PR merges the two functions and moves the implementation directly intofinalize
.This PR simplifies the code base and improves the efficiency a bit: previously,
finalize
iterated overConversionPatternRewriterImpl::rewrites
twice. Now, only one iteration is needed.