diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 2277989bf8411be..90db42d479a193b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -28,7 +28,6 @@ #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -1468,47 +1467,6 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op, ValueRange replacement) { removeOp(op); } - -//===----------------------------------------------------------------------===// -// ReplacementListener -//===----------------------------------------------------------------------===// - -/// Listener that tracks updates replacements for values which can be mutated. -/// This listener runs on top of the existing listener for the rewriter, -/// to make sure external users can still run listeners. -class ReplacementListener : public RewriterBase::ForwardingListener { -public: - ReplacementListener(DenseMap &replacements, - OpBuilder::Listener *listener) - : ForwardingListener(listener), replacements(replacements) {} - - void updateReplacementValues(ValueRange origValues, - ValueRange replaceValues) { - // This can probably be written better, but just iterates over the map - // and the new replacements for now. - for (auto &[key, val] : replacements) { - for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) { - if (val == orig) { - val = replace; - } - } - } - } - - void notifyOperationReplaced(Operation *op, Operation *newOp) override { - ForwardingListener::notifyOperationReplaced(op, newOp); - updateReplacementValues(op->getResults(), newOp->getResults()); - } - - void notifyOperationReplaced(Operation *op, ValueRange values) override { - ForwardingListener::notifyOperationReplaced(op, values); - updateReplacementValues(op->getResults(), values); - } - -private: - DenseMap &replacements; -}; - } // namespace /// Implementation of tile consumer and fuse producer greedily. @@ -1535,27 +1493,26 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( for (auto *tiledOp : tilingResult->tiledOps) tiledAndFusedOps.insert(tiledOp); - DenseMap replacements; - for (auto [origVal, replacement] : llvm::zip_equal( - consumer->getResults(), tilingResult->mergeResult.replacements)) { - replacements[origVal] = replacement; - } - // If there are no loops generated, fusion is immaterial. auto &loops = tilingResult->loops; if (loops.empty()) { + DenseMap replacements; + for (auto [origVal, replacement] : llvm::zip_equal( + consumer->getResults(), tilingResult->mergeResult.replacements)) { + replacements[origVal] = replacement; + } return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, replacements}; } - // Since the loop gets potentially replaced during fusion, we need to track - // the mutation of replacement values. To do this, we attach a listener to - // update the replacements as they happen. - OpBuilder::Listener *previousListener = rewriter.getListener(); - auto resetListener = - llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); - ReplacementListener replaceListener(replacements, previousListener); - rewriter.setListener(&replaceListener); + // To keep track of replacements for now just record the map from the + // original untiled value to the result number of the for loop. Since the + // loop gets potentially replaced during fusion, keeping the value directly + // wont work. + DenseMap origValToResultNumber; + for (auto [index, result] : llvm::enumerate(consumer->getResults())) { + origValToResultNumber[result] = index; + } // 2. Typically, the operands of the tiled operation are slices of the // operands of the untiled operation. These are expressed in IR using @@ -1624,9 +1581,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( worklistCandidates.append(newSlices.value()); for (auto [index, result] : llvm::enumerate(fusableProducerOp->getResults())) { - replacements[result] = loops.front()->getResult( - loops.front()->getNumResults() - - fusableProducerOp->getNumResults() + index); + origValToResultNumber[result] = loops.front()->getNumResults() - + fusableProducerOp->getNumResults() + + index; } } if (Operation *tiledAndFusedOp = @@ -1640,6 +1597,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( } } + DenseMap replacements; + for (auto [origVal, resultNumber] : origValToResultNumber) { + replacements[origVal] = loops.front()->getResult(resultNumber); + } + return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, replacements}; }