diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 65de25dd2f326..62246a87b291a 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -96,6 +96,9 @@ using namespace mlir;
 
 namespace {
 
+using BlockingUsesMap =
+    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
+
 /// Information computed during promotion analysis used to perform actual
 /// promotion.
 struct MemorySlotPromotionInfo {
@@ -106,7 +109,7 @@ struct MemorySlotPromotionInfo {
   /// its uses, it is because the defining ops of the blocking uses requested
   /// it. The defining ops therefore must also have blocking uses or be the
   /// starting point of the bloccking uses.
-  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+  BlockingUsesMap userToBlockingUses;
 };
 
 /// Computes information for basic slot promotion. This will check that direct
@@ -129,8 +132,7 @@ class MemorySlotPromotionAnalyzer {
   /// uses (typically, removing its users because it will delete itself to
   /// resolve its own blocking uses). This will fail if one of the transitive
   /// users cannot remove a requested use, and should prevent promotion.
-  LogicalResult computeBlockingUses(
-      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
 
   /// Computes in which blocks the value stored in the slot is actually used,
   /// meaning blocks leading to a load. This method uses `definingBlocks`, the
@@ -233,7 +235,7 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
 }
 
 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
-    DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses) {
+    BlockingUsesMap &userToBlockingUses) {
   // The promotion of an operation may require the promotion of further
   // operations (typically, removing operations that use an operation that must
   // delete itself). We thus need to start from the use of the slot pointer and
@@ -243,7 +245,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
   // use it.
   for (OpOperand &use : slot.ptr.getUses()) {
     SmallPtrSet<OpOperand *, 4> &blockingUses =
-        userToBlockingUses.getOrInsertDefault(use.getOwner());
+        userToBlockingUses[use.getOwner()];
     blockingUses.insert(&use);
   }
 
@@ -281,7 +283,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
       assert(llvm::is_contained(user->getResults(), blockingUse->get()));
 
       SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
-          userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
+          userToBlockingUses[blockingUse->getOwner()];
       newUserBlockingUseSet.insert(blockingUse);
     }
   }
@@ -516,14 +518,16 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
 }
 
 void MemorySlotPromoter::removeBlockingUses() {
-  llvm::SetVector<Operation *> usersToRemoveUses;
-  for (auto &user : llvm::make_first_range(info.userToBlockingUses))
-    usersToRemoveUses.insert(user);
-  SetVector<Operation *> sortedUsersToRemoveUses =
-      mlir::topologicalSort(usersToRemoveUses);
+  llvm::SmallVector<Operation *> usersToRemoveUses(
+      llvm::make_first_range(info.userToBlockingUses));
+  // The uses need to be traversed in *reverse dominance* order to ensure that
+  // transitive replacements are performed correctly.
+  llvm::sort(usersToRemoveUses, [&](Operation *lhs, Operation *rhs) {
+    return dominance.properlyDominates(rhs, lhs);
+  });
 
   llvm::SmallVector<Operation *> toErase;
-  for (Operation *toPromote : llvm::reverse(sortedUsersToRemoveUses)) {
+  for (Operation *toPromote : usersToRemoveUses) {
     if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
       Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
       // If no reaching definition is known, this use is outside the reach of
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 30ba459d07a49..32e3fed7e5485 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -683,3 +683,16 @@ llvm.func @no_inner_alloca_promotion(%arg: i64) -> i64 {
   // CHECK: llvm.return %[[RES]] : i64
   llvm.return %2 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @transitive_reaching_def
+llvm.func @transitive_reaching_def() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: alloca
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.store %2, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.return %3 : !llvm.ptr
+}