Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCEV] Collect and merge loop guards through PHI nodes with multiple incoming values #113915

Merged
merged 9 commits into from
Nov 15, 2024
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,11 @@ class ScalarEvolution {

LoopGuards(ScalarEvolution &SE) : SE(SE) {}

/// Collect the guards that apply to \p Block when it is entered from \p Pred
/// and record the resulting rewrites in \p Guards.
/// \p VisitedBlocks is passed by value, so blocks inserted by a callee are
/// not visible to its caller.
static LoopGuards
collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks);

public:
/// Collect rewrite map for loop guards for loop \p L, together with flags
/// indicating if NUW and NSW can be preserved during rewriting.
Expand Down
85 changes: 76 additions & 9 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10648,7 +10648,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
if (const Loop *L = LI.getLoopFor(BB))
return {L->getLoopPredecessor(), L->getHeader()};

return {nullptr, nullptr};
return {nullptr, BB};
}

/// SCEV structural equivalence is usually sufficient for testing whether two
Expand Down Expand Up @@ -15217,7 +15217,16 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,

/// Collect the loop guards for loop \p L. This is a thin entry point: the
/// actual collection is delegated to collectFromBlock, seeded with the loop
/// header and the loop's predecessor block.
ScalarEvolution::LoopGuards
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
  BasicBlock *const LoopHeader = L->getHeader();
  BasicBlock *const EntryPred = L->getLoopPredecessor();
  // Start from a fresh LoopGuards object and an empty visited-block set.
  LoopGuards Result(SE);
  return collectFromBlock(SE, Result, LoopHeader, EntryPred, {});
}

ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks) {
SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
Expand Down Expand Up @@ -15556,14 +15565,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
}
};

BasicBlock *Header = L->getHeader();
SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
// First, collect information from assumptions dominating the loop.
for (auto &AssumeVH : SE.AC.assumptions()) {
if (!AssumeVH)
continue;
auto *AssumeI = cast<CallInst>(AssumeVH);
if (!SE.DT.dominates(AssumeI, Header))
if (!SE.DT.dominates(AssumeI, Block))
continue;
Terms.emplace_back(AssumeI->getOperand(0), true);
}
Expand All @@ -15574,20 +15582,19 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
if (GuardDecl)
for (const auto *GU : GuardDecl->users())
if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
if (Guard->getFunction() == Header->getParent() &&
SE.DT.dominates(Guard, Header))
if (Guard->getFunction() == Block->getParent() &&
SE.DT.dominates(Guard, Block))
Terms.emplace_back(Guard->getArgOperand(0), true);

// Third, collect conditions from dominating branches. Starting at the loop
// predecessor, climb up the predecessor chain, as long as there are
// predecessors that can be found that have unique successors leading to the
// original header.
// TODO: share this logic with isLoopEntryGuardedByCond.
for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
L->getLoopPredecessor(), Header);
Pair.first;
std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
for (; Pair.first;
Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {

VisitedBlocks.insert(Pair.second);
const BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
Expand All @@ -15596,6 +15603,66 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
Terms.emplace_back(LoopEntryPredicate->getCondition(),
LoopEntryPredicate->getSuccessor(0) == Pair.second);
}
// Finally, if we stopped climbing the predecessor chain because
// there wasn't a unique one to continue, try to collect conditions
// for PHINodes by recursively following all of their incoming
// blocks and try to merge the found conditions to build a new one
// for the Phi.
if (Pair.second->hasNPredecessorsOrMore(2)) {
for (auto &Phi : Pair.second->phis()) {
if (!SE.isSCEVable(Phi.getType()))
continue;

using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
auto GetMinMaxConst = [&SE, &VisitedBlocks, &Pair,
&Phi](unsigned int In) -> MinMaxPattern {
LoopGuards G(SE);
if (VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
collectFromBlock(SE, G, Pair.second, Phi.getIncomingBlock(In),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably allow only a single level of recursion to start with, i.e., after recursing here the first time, don't follow blocks with multiple predecessors again.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done: added MaxLoopGuardCollectionDepth and defaulted it to 1.

VisitedBlocks);
const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
if (!SM)
return {nullptr, scCouldNotCompute};
if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
return {C0, SM->getSCEVType()};
if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
return {C1, SM->getSCEVType()};
return {nullptr, scCouldNotCompute};
};
auto MergeMinMaxConst = [](MinMaxPattern P1,
MinMaxPattern P2) -> MinMaxPattern {
auto [C1, T1] = P1;
auto [C2, T2] = P2;
if (!C1 || !C2 || T1 != T2)
return {nullptr, scCouldNotCompute};
switch (T1) {
case scUMaxExpr:
return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
case scSMaxExpr:
return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
case scUMinExpr:
return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
case scSMinExpr:
return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
default:
llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
}
};
auto P = GetMinMaxConst(0);
for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
if (!P.first)
break;
P = MergeMinMaxConst(P, GetMinMaxConst(In));
}
if (P.first) {
const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
SmallVector<const SCEV *, 2> Ops({P.first, LHS});
const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
Guards.RewriteMap.insert({LHS, RHS});
}
}
}

// Now apply the information from the collected conditions to
// Guards.RewriteMap. Conditions are processed in reverse order, so the
Expand Down
82 changes: 82 additions & 0 deletions llvm/test/Analysis/ScalarEvolution/trip-count.ll
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,85 @@ for.body:
exit:
ret void
}

; Main-loop-plus-epilogue pattern. %while.body steps %count down by 8 per
; iteration and %epilogue steps the remainder down by 1.
; %count.epilogue is a phi with two incoming values:
;   - %count from %entry, on the edge taken when `%count ugt 7` is false;
;   - %sub.exit from %while.loopexit, reached when `%sub ugt 7` is false.
; The expectations below give %epilogue a constant max backedge-taken count
; of 6.
define void @epilogue(i64 %count) {
; CHECK-LABEL: 'epilogue'
; CHECK-NEXT: Determining loop execution counts for: @epilogue
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 6
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
entry:
%cmp = icmp ugt i64 %count, 7
br i1 %cmp, label %while.body, label %epilogue.preheader

while.body:
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
%sub = add i64 %iv, -8
%exitcond.not = icmp ugt i64 %sub, 7
br i1 %exitcond.not, label %while.body, label %while.loopexit

while.loopexit:
%sub.exit = phi i64 [ %sub, %while.body ]
br label %epilogue.preheader

epilogue.preheader:
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
br i1 %epilogue.cmp, label %exit, label %epilogue

epilogue:
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
%dec = add i64 %iv.epilogue, -1
%exitcond.epilogue = icmp eq i64 %dec, 0
br i1 %exitcond.epilogue, label %exit, label %epilogue

exit:
ret void

}

; Main-loop-plus-epilogue pattern in which the two edges into
; %epilogue.preheader carry different bounds:
;   - %count from %entry, on the edge taken when `%count ugt 9` is false;
;   - %sub.exit from %while.loopexit, reached when `%sub ugt 7` is false.
; The expectations below give %epilogue a constant max backedge-taken count
; of 8, i.e. the looser of the two bounds (9) minus one.
define void @epilogue2(i64 %count) {
; CHECK-LABEL: 'epilogue2'
; CHECK-NEXT: Determining loop execution counts for: @epilogue2
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 8
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
entry:
%cmp = icmp ugt i64 %count, 9
br i1 %cmp, label %while.body, label %epilogue.preheader

while.body:
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
%sub = add i64 %iv, -8
%exitcond.not = icmp ugt i64 %sub, 7
br i1 %exitcond.not, label %while.body, label %while.loopexit

while.loopexit:
%sub.exit = phi i64 [ %sub, %while.body ]
br label %epilogue.preheader

epilogue.preheader:
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
br i1 %epilogue.cmp, label %exit, label %epilogue

epilogue:
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
%dec = add i64 %iv.epilogue, -1
%exitcond.epilogue = icmp eq i64 %dec, 0
br i1 %exitcond.epilogue, label %exit, label %epilogue

exit:
ret void

}
2 changes: 1 addition & 1 deletion llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ define void @apply_delta(ptr nocapture noundef %dst, ptr nocapture noundef reado
; CHECK-NEXT: [[INCDEC_PTR]] = getelementptr inbounds i8, ptr [[DST_ADDR_130]], i64 1
; CHECK-NEXT: [[INCDEC_PTR8]] = getelementptr inbounds i8, ptr [[SRC_ADDR_129]], i64 1
; CHECK-NEXT: [[TOBOOL_NOT:%.*]] = icmp eq i64 [[DEC]], 0
; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]]
; CHECK: while.end9:
; CHECK-NEXT: ret void
;
Expand Down
Loading