1616#include " mlir/Interfaces/MemorySlotInterfaces.h"
1717#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
1818#include " mlir/Transforms/Passes.h"
19+ #include " mlir/Transforms/RegionUtils.h"
20+ #include " llvm/ADT/PostOrderIterator.h"
1921#include " llvm/ADT/STLExtras.h"
2022#include " llvm/Support/Casting.h"
2123#include " llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -96,6 +98,9 @@ using namespace mlir;
9698
9799namespace {
98100
101+ using BlockingUsesMap =
102+ llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4 >>;
103+
99104// / Information computed during promotion analysis used to perform actual
100105// / promotion.
101106struct MemorySlotPromotionInfo {
@@ -106,7 +111,7 @@ struct MemorySlotPromotionInfo {
106111 // / its uses, it is because the defining ops of the blocking uses requested
107112 // / it. The defining ops therefore must also have blocking uses or be the
108113 // / starting point of the bloccking uses.
109- DenseMap<Operation *, SmallPtrSet<OpOperand *, 4 >> userToBlockingUses;
114+ BlockingUsesMap userToBlockingUses;
110115};
111116
112117// / Computes information for basic slot promotion. This will check that direct
@@ -129,8 +134,7 @@ class MemorySlotPromotionAnalyzer {
129134 // / uses (typically, removing its users because it will delete itself to
130135 // / resolve its own blocking uses). This will fail if one of the transitive
131136 // / users cannot remove a requested use, and should prevent promotion.
132- LogicalResult computeBlockingUses (
133- DenseMap<Operation *, SmallPtrSet<OpOperand *, 4 >> &userToBlockingUses);
137+ LogicalResult computeBlockingUses (BlockingUsesMap &userToBlockingUses);
134138
135139 // / Computes in which blocks the value stored in the slot is actually used,
136140 // / meaning blocks leading to a load. This method uses `definingBlocks`, the
@@ -233,7 +237,7 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
233237}
234238
235239LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses (
236- DenseMap<Operation *, SmallPtrSet<OpOperand *, 4 >> &userToBlockingUses) {
240+ BlockingUsesMap &userToBlockingUses) {
237241 // The promotion of an operation may require the promotion of further
238242 // operations (typically, removing operations that use an operation that must
239243 // delete itself). We thus need to start from the use of the slot pointer and
@@ -243,7 +247,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
243247 // use it.
244248 for (OpOperand &use : slot.ptr .getUses ()) {
245249 SmallPtrSet<OpOperand *, 4 > &blockingUses =
246- userToBlockingUses. getOrInsertDefault ( use.getOwner ()) ;
250+ userToBlockingUses[ use.getOwner ()] ;
247251 blockingUses.insert (&use);
248252 }
249253
@@ -281,7 +285,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
281285 assert (llvm::is_contained (user->getResults (), blockingUse->get ()));
282286
283287 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
284- userToBlockingUses. getOrInsertDefault ( blockingUse->getOwner ()) ;
288+ userToBlockingUses[ blockingUse->getOwner ()] ;
285289 newUserBlockingUseSet.insert (blockingUse);
286290 }
287291 }
@@ -515,15 +519,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
515519 }
516520}
517521
522+ // / Sorts `ops` according to dominance. Relies on the topological order of basic
523+ // / blocks to get a deterministic ordering.
524+ static void dominanceSort (SmallVector<Operation *> &ops, Region ®ion) {
525+ // Produce a topological block order and construct a map to lookup the indices
526+ // of blocks.
527+ DenseMap<Block *, size_t > topoBlockIndices;
528+ SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks (region);
529+ for (auto [index, block] : llvm::enumerate (topologicalOrder))
530+ topoBlockIndices[block] = index;
531+
532+ // Combining the topological order of the basic blocks together with block
533+ // internal operation order guarantees a deterministic, dominance respecting
534+ // order.
535+ llvm::sort (ops, [&](Operation *lhs, Operation *rhs) {
536+ size_t lhsBlockIndex = topoBlockIndices.at (lhs->getBlock ());
537+ size_t rhsBlockIndex = topoBlockIndices.at (rhs->getBlock ());
538+ if (lhsBlockIndex == rhsBlockIndex)
539+ return lhs->isBeforeInBlock (rhs);
540+ return lhsBlockIndex < rhsBlockIndex;
541+ });
542+ }
543+
518544void MemorySlotPromoter::removeBlockingUses () {
519- llvm::SetVector <Operation *> usersToRemoveUses;
520- for ( auto &user : llvm::make_first_range (info.userToBlockingUses ))
521- usersToRemoveUses. insert (user);
522- SetVector<Operation *> sortedUsersToRemoveUses =
523- mlir::topologicalSort (usersToRemoveUses);
545+ llvm::SmallVector <Operation *> usersToRemoveUses (
546+ llvm::make_first_range (info.userToBlockingUses ));
547+
548+ // Sort according to dominance.
549+ dominanceSort (usersToRemoveUses, *slot. ptr . getParentBlock ()-> getParent () );
524550
525551 llvm::SmallVector<Operation *> toErase;
526- for (Operation *toPromote : llvm::reverse (sortedUsersToRemoveUses )) {
552+ for (Operation *toPromote : llvm::reverse (usersToRemoveUses )) {
527553 if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
528554 Value reachingDef = reachingDefs.lookup (toPromoteMemOp);
529555 // If no reaching definition is known, this use is outside the reach of
0 commit comments