diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 8a70883293d91..6a9316cbc690f 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -897,7 +897,18 @@ class ConversionPatternRewriter final : public PatternRewriter { /// Replace the given operation with the new value ranges. The number of op /// results and value ranges must match. The given operation is erased. - void replaceOpWithMultiple(Operation *op, ArrayRef newValues); + void replaceOpWithMultiple(Operation *op, + SmallVector> &&newValues); + template + void replaceOpWithMultiple(Operation *op, ArrayRef newValues) { + replaceOpWithMultiple(op, + llvm::to_vector_of>(newValues)); + } + template + void replaceOpWithMultiple(Operation *op, RangeT &&newValues) { + replaceOpWithMultiple(op, + ArrayRef(llvm::to_vector_of(newValues))); + } /// PatternRewriter hook for erasing a dead operation. The uses of this /// operation *must* be made dead by the end of the conversion process, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 6a66ad24a87b4..e5f9717c3fbaa 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern { } assert(packedResultVals.size() == op.getNumResults()); - rewriter.replaceOpWithMultiple( - op, llvm::to_vector_of(packedResultVals)); + rewriter.replaceOpWithMultiple(op, std::move(packedResultVals)); return success(); } }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index bca31f86683fa..b9475a7cc95a8 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -173,6 +173,10 @@ struct ConversionValueMapping { } } + void map(Value oldVal, SmallVector &&newVal) { + map(ValueVector{oldVal}, ValueVector(std::move(newVal))); + } + /// Drop the last mapping for the given values. void erase(const ValueVector &value) { mapping.erase(value); } @@ -946,7 +950,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { OpBuilder::InsertPoint previous) override; /// Notifies that an op is about to be replaced with the given values. - void notifyOpReplaced(Operation *op, ArrayRef newValues); + void notifyOpReplaced(Operation *op, + SmallVector> &&newValues); /// Notifies that a block is about to be erased. void notifyBlockIsBeingErased(Block *block); @@ -1519,7 +1524,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( } void ConversionPatternRewriterImpl::notifyOpReplaced( - Operation *op, ArrayRef newValues) { + Operation *op, SmallVector> &&newValues) { assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); @@ -1561,7 +1566,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Remap result to replacement value. if (repl.empty()) continue; - mapping.map(result, repl); + mapping.map(static_cast(result), std::move(repl)); } appendRewrite(op, currentTypeConverter); @@ -1639,26 +1644,22 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector newVals; - for (size_t i = 0; i < newValues.size(); ++i) { - if (newValues[i]) { - newVals.push_back(newValues.slice(i, 1)); - } else { - newVals.push_back(ValueRange()); - } - } - impl->notifyOpReplaced(op, newVals); + SmallVector> newVals = + llvm::map_to_vector(newValues, [](Value v) -> SmallVector { + return v ? SmallVector{v} : SmallVector(); + }); + impl->notifyOpReplaced(op, std::move(newVals)); } void ConversionPatternRewriter::replaceOpWithMultiple( - Operation *op, ArrayRef newValues) { + Operation *op, SmallVector> &&newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); LLVM_DEBUG({ impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - impl->notifyOpReplaced(op, newValues); + impl->notifyOpReplaced(op, std::move(newValues)); } void ConversionPatternRewriter::eraseOp(Operation *op) { @@ -1666,8 +1667,8 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector nullRepls(op->getNumResults(), {}); - impl->notifyOpReplaced(op, nullRepls); + SmallVector> nullRepls(op->getNumResults(), {}); + impl->notifyOpReplaced(op, std::move(nullRepls)); } void ConversionPatternRewriter::eraseBlock(Block *block) { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index b868f1a3a08da..bfdcaf431eeff 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1278,6 +1278,29 @@ class TestMultiple1ToNReplacement : public ConversionPattern { } }; +/// Test unambiguous overload resolution of replaceOpWithMultiple. This +/// function is just to trigger compiler errors. It is never executed. +[[maybe_unused]] void testReplaceOpWithMultipleOverloads( + ConversionPatternRewriter &rewriter, Operation *op, ArrayRef r1, + SmallVector r2, ArrayRef> r3, + SmallVector> r4, ArrayRef> r5, + SmallVector> r6, SmallVector> &&r7, + Value v, ValueRange vr, ArrayRef ar) { + rewriter.replaceOpWithMultiple(op, r1); + rewriter.replaceOpWithMultiple(op, r2); + rewriter.replaceOpWithMultiple(op, r3); + rewriter.replaceOpWithMultiple(op, r4); + rewriter.replaceOpWithMultiple(op, r5); + rewriter.replaceOpWithMultiple(op, r6); + rewriter.replaceOpWithMultiple(op, std::move(r7)); + rewriter.replaceOpWithMultiple(op, {vr}); + rewriter.replaceOpWithMultiple(op, {ar}); + rewriter.replaceOpWithMultiple(op, {{v}}); + rewriter.replaceOpWithMultiple(op, {{v, v}}); + rewriter.replaceOpWithMultiple(op, {{v, v}, vr}); + rewriter.replaceOpWithMultiple(op, {{v, v}, ar}); + rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr}); +} } // namespace namespace {