diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 90db42d479a193..2277989bf8411b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -28,6 +28,7 @@ #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -1467,6 +1468,47 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op, ValueRange replacement) { removeOp(op); } + +//===----------------------------------------------------------------------===// +// ReplacementListener +//===----------------------------------------------------------------------===// + +/// Listener that tracks updates replacements for values which can be mutated. +/// This listener runs on top of the existing listener for the rewriter, +/// to make sure external users can still run listeners. +class ReplacementListener : public RewriterBase::ForwardingListener { +public: + ReplacementListener(DenseMap &replacements, + OpBuilder::Listener *listener) + : ForwardingListener(listener), replacements(replacements) {} + + void updateReplacementValues(ValueRange origValues, + ValueRange replaceValues) { + // This can probably be written better, but just iterates over the map + // and the new replacements for now. + for (auto &[key, val] : replacements) { + for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) { + if (val == orig) { + val = replace; + } + } + } + } + + void notifyOperationReplaced(Operation *op, Operation *newOp) override { + ForwardingListener::notifyOperationReplaced(op, newOp); + updateReplacementValues(op->getResults(), newOp->getResults()); + } + + void notifyOperationReplaced(Operation *op, ValueRange values) override { + ForwardingListener::notifyOperationReplaced(op, values); + updateReplacementValues(op->getResults(), values); + } + +private: + DenseMap &replacements; +}; + } // namespace /// Implementation of tile consumer and fuse producer greedily. @@ -1493,26 +1535,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( for (auto *tiledOp : tilingResult->tiledOps) tiledAndFusedOps.insert(tiledOp); + DenseMap replacements; + for (auto [origVal, replacement] : llvm::zip_equal( + consumer->getResults(), tilingResult->mergeResult.replacements)) { + replacements[origVal] = replacement; + } + // If there are no loops generated, fusion is immaterial. auto &loops = tilingResult->loops; if (loops.empty()) { - DenseMap replacements; - for (auto [origVal, replacement] : llvm::zip_equal( - consumer->getResults(), tilingResult->mergeResult.replacements)) { - replacements[origVal] = replacement; - } return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, replacements}; } - // To keep track of replacements for now just record the map from the - // original untiled value to the result number of the for loop. Since the - // loop gets potentially replaced during fusion, keeping the value directly - // wont work. - DenseMap origValToResultNumber; - for (auto [index, result] : llvm::enumerate(consumer->getResults())) { - origValToResultNumber[result] = index; - } + // Since the loop gets potentially replaced during fusion, we need to track + // the mutation of replacement values. To do this, we attach a listener to + // update the replacements as they happen. + OpBuilder::Listener *previousListener = rewriter.getListener(); + auto resetListener = + llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); + ReplacementListener replaceListener(replacements, previousListener); + rewriter.setListener(&replaceListener); // 2. Typically, the operands of the tiled operation are slices of the // operands of the untiled operation. These are expressed in IR using @@ -1581,9 +1624,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( worklistCandidates.append(newSlices.value()); for (auto [index, result] : llvm::enumerate(fusableProducerOp->getResults())) { - origValToResultNumber[result] = loops.front()->getNumResults() - - fusableProducerOp->getNumResults() + - index; + replacements[result] = loops.front()->getResult( + loops.front()->getNumResults() - + fusableProducerOp->getNumResults() + index); } } if (Operation *tiledAndFusedOp = @@ -1597,11 +1640,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( } } - DenseMap replacements; - for (auto [origVal, resultNumber] : origValToResultNumber) { - replacements[origVal] = loops.front()->getResult(resultNumber); - } - return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, replacements}; }