diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index caea9e111afed..69036e947ebdb 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2338,17 +2338,6 @@ struct OperationConverter { /// remaining artifacts and complete the conversion. LogicalResult finalize(ConversionPatternRewriter &rewriter); - /// Legalize the types of converted block arguments. - LogicalResult - legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl); - - /// Legalize the types of converted op results. - LogicalResult legalizeConvertedOpResultTypes( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - DenseMap> &inverseMapping); - /// Dialect conversion configuration. ConversionConfig config; @@ -2512,19 +2501,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { return success(); } -LogicalResult -OperationConverter::finalize(ConversionPatternRewriter &rewriter) { - ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) - return failure(); - DenseMap> inverseMapping = - rewriterImpl.mapping.getInverse(); - if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl, - inverseMapping))) - return failure(); - return success(); -} - /// Finds a user of the given value, or of any other value that the given value /// replaced, that was not replaced in the conversion process. static Operation *findLiveUserOfReplaced( @@ -2548,87 +2524,60 @@ static Operation *findLiveUserOfReplaced( return nullptr; } -LogicalResult OperationConverter::legalizeConvertedOpResultTypes( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - DenseMap> &inverseMapping) { - // Process requested operation replacements. - for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { - auto *opReplacement = - dyn_cast(rewriterImpl.rewrites[i].get()); - if (!opReplacement) - continue; - Operation *op = opReplacement->getOperation(); - for (OpResult result : op->getResults()) { - // If the type of this op result changed and the result is still live, - // we need to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(result, result.getType())) +/// Helper function that returns the replaced values and the type converter if +/// the given rewrite object is an "operation replacement" or a "block type +/// conversion" (which corresponds to a "block replacement"). Otherwise, return +/// an empty ValueRange and a null type converter pointer. +static std::pair +getReplacedValues(IRRewrite *rewrite) { + if (auto *opRewrite = dyn_cast(rewrite)) + return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()}; + if (auto *blockRewrite = dyn_cast(rewrite)) + return {blockRewrite->getOrigBlock()->getArguments(), + blockRewrite->getConverter()}; + return {}; +} + +LogicalResult +OperationConverter::finalize(ConversionPatternRewriter &rewriter) { + ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + DenseMap> inverseMapping = + rewriterImpl.mapping.getInverse(); + + // Process requested value replacements. + for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) { + ValueRange replacedValues; + const TypeConverter *converter; + std::tie(replacedValues, converter) = + getReplacedValues(rewriterImpl.rewrites[i].get()); + for (Value originalValue : replacedValues) { + // If the type of this value changed and the value is still live, we need + // to materialize a conversion. + if (rewriterImpl.mapping.lookupOrNull(originalValue, + originalValue.getType())) continue; Operation *liveUser = - findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); + findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping); if (!liveUser) continue; - // Legalize this result. - Value newValue = rewriterImpl.mapping.lookupOrNull(result); + // Legalize this value replacement. + Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); assert(newValue && "replacement value not found"); Value castValue = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(result), op->getLoc(), - /*inputs=*/newValue, /*outputType=*/result.getType(), - opReplacement->getConverter()); - rewriterImpl.mapping.map(result, castValue); - inverseMapping[castValue].push_back(result); - llvm::erase(inverseMapping[newValue], result); + MaterializationKind::Source, computeInsertPoint(newValue), + originalValue.getLoc(), + /*inputs=*/newValue, /*outputType=*/originalValue.getType(), + converter); + rewriterImpl.mapping.map(originalValue, castValue); + inverseMapping[castValue].push_back(originalValue); + llvm::erase(inverseMapping[newValue], originalValue); } } return success(); } -LogicalResult OperationConverter::legalizeConvertedArgumentTypes( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl) { - // Functor used to check if all users of a value will be dead after - // conversion. - // TODO: This should probably query the inverse mapping, same as in - // `legalizeConvertedOpResultTypes`. - auto findLiveUser = [&](Value val) { - auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - return liveUserIt == val.user_end() ? nullptr : *liveUserIt; - }; - // Note: `rewrites` may be reallocated as the loop is running. - for (int64_t i = 0; i < static_cast(rewriterImpl.rewrites.size()); - ++i) { - auto &rewrite = rewriterImpl.rewrites[i]; - if (auto *blockTypeConversionRewrite = - dyn_cast(rewrite.get())) { - // Process the remapping for each of the original arguments. - for (Value origArg : - blockTypeConversionRewrite->getOrigBlock()->getArguments()) { - // If the type of this argument changed and the argument is still live, - // we need to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) - continue; - Operation *liveUser = findLiveUser(origArg); - if (!liveUser) - continue; - - Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg); - assert(replacementValue && "replacement value not found"); - Value repl = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(replacementValue), - origArg.getLoc(), /*inputs=*/replacementValue, - /*outputType=*/origArg.getType(), - blockTypeConversionRewrite->getConverter()); - rewriterImpl.mapping.map(origArg, repl); - } - } - } - return success(); -} - //===----------------------------------------------------------------------===// // Reconcile Unrealized Casts //===----------------------------------------------------------------------===//