Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCEV] Collect and merge loop guards through PHI nodes with multiple incoming values #113915

Merged
merged 9 commits into from
Nov 15, 2024
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,11 @@ class ScalarEvolution {

LoopGuards(ScalarEvolution &SE) : SE(SE) {}

static LoopGuards
collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
const BasicBlock *Block, const BasicBlock *Pred,
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks);
juliannagele marked this conversation as resolved.
Show resolved Hide resolved

public:
/// Collect rewrite map for loop guards for loop \p L, together with flags
/// indicating if NUW and NSW can be preserved during rewriting.
Expand Down
85 changes: 76 additions & 9 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10648,7 +10648,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
if (const Loop *L = LI.getLoopFor(BB))
return {L->getLoopPredecessor(), L->getHeader()};

return {nullptr, nullptr};
return {nullptr, BB};
}

/// SCEV structural equivalence is usually sufficient for testing whether two
Expand Down Expand Up @@ -15217,7 +15217,16 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,

/// Collect the loop guards for loop \p L. This is a thin entry point: it
/// creates a fresh LoopGuards object and hands off to collectFromBlock,
/// starting at the loop header with the loop predecessor as the incoming
/// block and an empty set of already-visited blocks.
ScalarEvolution::LoopGuards
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
  LoopGuards Result(SE);
  return collectFromBlock(SE, Result, /*Block=*/L->getHeader(),
                          /*Pred=*/L->getLoopPredecessor(),
                          /*VisitedBlocks=*/{});
}

ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks) {
SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
Expand Down Expand Up @@ -15556,14 +15565,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
}
};

BasicBlock *Header = L->getHeader();
SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
// First, collect information from assumptions dominating the loop.
for (auto &AssumeVH : SE.AC.assumptions()) {
if (!AssumeVH)
continue;
auto *AssumeI = cast<CallInst>(AssumeVH);
if (!SE.DT.dominates(AssumeI, Header))
if (!SE.DT.dominates(AssumeI, Block))
continue;
Terms.emplace_back(AssumeI->getOperand(0), true);
}
Expand All @@ -15574,20 +15582,19 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
if (GuardDecl)
for (const auto *GU : GuardDecl->users())
if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
if (Guard->getFunction() == Header->getParent() &&
SE.DT.dominates(Guard, Header))
if (Guard->getFunction() == Block->getParent() &&
SE.DT.dominates(Guard, Block))
Terms.emplace_back(Guard->getArgOperand(0), true);

// Third, collect conditions from dominating branches. Starting at the loop
// predecessor, climb up the predecessor chain, as long as there are
// predecessors that can be found that have unique successors leading to the
// original header.
// TODO: share this logic with isLoopEntryGuardedByCond.
for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
L->getLoopPredecessor(), Header);
Pair.first;
std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
for (; Pair.first;
Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {

VisitedBlocks.insert(Pair.second);
const BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
Expand All @@ -15596,6 +15603,66 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
Terms.emplace_back(LoopEntryPredicate->getCondition(),
LoopEntryPredicate->getSuccessor(0) == Pair.second);
}
// Finally, if we stopped climbing the predecessor chain because
// there wasn't a unique one to continue, try to collect conditions
// for PHINodes by recursively following all of their incoming
// blocks and try to merge the found conditions to build a new one
// for the Phi.
if (Pair.second->hasNPredecessorsOrMore(2)) {
for (auto &Phi : Pair.second->phis()) {
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
if (!SE.isSCEVable(Phi.getType()))
continue;

using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
auto GetMinMaxConst = [&SE, &VisitedBlocks, &Pair,
&Phi](unsigned int In) -> MinMaxPattern {
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
LoopGuards G(SE);
if (VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
collectFromBlock(SE, G, Pair.second, Phi.getIncomingBlock(In),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably allow only a single level of recursion to start with, i.e., once we have recursed here the first time, don't follow blocks with multiple predecessors again.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, added MaxLoopGuardCollectionDepth and defaulted it to 1

VisitedBlocks);
const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
if (!SM)
return {nullptr, scCouldNotCompute};
if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
return {C0, SM->getSCEVType()};
if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
return {C1, SM->getSCEVType()};
return {nullptr, scCouldNotCompute};
};
auto MergeMinMaxConst = [](MinMaxPattern P1,
MinMaxPattern P2) -> MinMaxPattern {
auto [C1, T1] = P1;
auto [C2, T2] = P2;
if (!C1 || !C2 || T1 != T2)
return {nullptr, scCouldNotCompute};
switch (T1) {
case scUMaxExpr:
return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
case scSMaxExpr:
return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
case scUMinExpr:
return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
case scSMinExpr:
return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
default:
llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
}
};
auto P = GetMinMaxConst(0);
for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
if (!P.first)
break;
P = MergeMinMaxConst(P, GetMinMaxConst(In));
}
if (P.first) {
const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
SmallVector<const SCEV *, 2> Ops({P.first, LHS});
const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
Guards.RewriteMap.insert({LHS, RHS});
}
}
}

// Now apply the information from the collected conditions to
// Guards.RewriteMap. Conditions are processed in reverse order, so the
Expand Down
214 changes: 214 additions & 0 deletions llvm/test/Analysis/ScalarEvolution/trip-count.ll
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,217 @@ for.body:
exit:
ret void
}

; Main loop (%while.body) that steps %iv down by 8, followed by a remainder
; loop (%epilogue) that steps down by 1 until it reaches 0.
; %epilogue.preheader has two predecessors, so the start value of the
; remainder loop is the phi %count.epilogue. Both incoming edges are the false
; edge of an 'icmp ugt ..., 7', and the expected output below reports a
; constant max backedge-taken count of 6 for %epilogue.
define void @epilogue(i64 %count) {
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
; CHECK-LABEL: 'epilogue'
; CHECK-NEXT: Determining loop execution counts for: @epilogue
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 6
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
entry:
; Skip the main loop entirely unless %count is ugt 7.
%cmp = icmp ugt i64 %count, 7
br i1 %cmp, label %while.body, label %epilogue.preheader

while.body:
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
%sub = add i64 %iv, -8
%exitcond.not = icmp ugt i64 %sub, 7
br i1 %exitcond.not, label %while.body, label %while.loopexit

while.loopexit:
%sub.exit = phi i64 [ %sub, %while.body ]
br label %epilogue.preheader

epilogue.preheader:
; Merge point: the value comes either straight from %entry or from the main
; loop's exit block. The remainder loop is skipped when it is 0.
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
br i1 %epilogue.cmp, label %exit, label %epilogue

epilogue:
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
%dec = add i64 %iv.epilogue, -1
%exitcond.epilogue = icmp eq i64 %dec, 0
br i1 %exitcond.epilogue, label %exit, label %epilogue

exit:
ret void

juliannagele marked this conversation as resolved.
Show resolved Hide resolved
}

; Same shape as a main loop plus remainder loop, but the two edges into
; %epilogue.preheader are guarded by different constants: the %entry edge is
; the false edge of 'icmp ugt %count, 9', while the main-loop exit edge is the
; false edge of 'icmp ugt %sub, 7'. The expected output below reports a
; constant max backedge-taken count of 8 for %epilogue.
define void @epilogue2(i64 %count) {
; CHECK-LABEL: 'epilogue2'
; CHECK-NEXT: Determining loop execution counts for: @epilogue2
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 8
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
entry:
; Entry guard uses 9 here, unlike the main loop's exit condition, which uses 7.
%cmp = icmp ugt i64 %count, 9
br i1 %cmp, label %while.body, label %epilogue.preheader

while.body:
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
%sub = add i64 %iv, -8
%exitcond.not = icmp ugt i64 %sub, 7
br i1 %exitcond.not, label %while.body, label %while.loopexit

while.loopexit:
%sub.exit = phi i64 [ %sub, %while.body ]
br label %epilogue.preheader

epilogue.preheader:
; Merge point of the two differently-guarded edges; the remainder loop is
; skipped when the merged value is 0.
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
br i1 %epilogue.cmp, label %exit, label %epilogue

epilogue:
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
%dec = add i64 %iv.epilogue, -1
%exitcond.epilogue = icmp eq i64 %dec, 0
br i1 %exitcond.epilogue, label %exit, label %epilogue

exit:
ret void

}

; Two-predecessor preheader with signed less-than guards: %preheader is
; reached from %b1 on the true edge of 'icmp slt %a, 8' and from %b2 on the
; true edge of 'icmp slt %b, 8'. The loop start value %count is the phi of %a
; and %b; the loop increments by 1 and continues while %iv.next slt 64.
define void @slt(i16 %a, i16 %b, i1 %c) {
; CHECK-LABEL: 'slt'
; CHECK-NEXT: Determining loop execution counts for: @slt
; CHECK-NEXT: Loop %loop: backedge-taken count is (63 + (-1 * %count))
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 -32704
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (63 + (-1 * %count))
; CHECK-NEXT: Loop %loop: Trip multiple is 1
entry:
br i1 %c, label %b1, label %b2

b1:
%cmp1 = icmp slt i16 %a, 8
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
br i1 %cmp1, label %preheader, label %exit

b2:
%cmp2 = icmp slt i16 %b, 8
br i1 %cmp2, label %preheader, label %exit

preheader:
; Merge of the two guarded values; each incoming value has its own guard.
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
br label %loop

loop:
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
%iv.next = add i16 %iv, 1
%exitcond = icmp slt i16 %iv.next, 64
br i1 %exitcond, label %loop, label %exit

exit:
ret void

}

; Two-predecessor preheader with unsigned guards on the false edge: %b1 and
; %b2 branch to %exit when their value is ult 8, so %preheader is reached only
; when the comparison is false. The loop start value %count is the phi of %a
; and %b; the loop decrements by 1 and exits when %iv.next equals 0.
define void @ult(i16 %a, i16 %b, i1 %c) {
; CHECK-LABEL: 'ult'
; CHECK-NEXT: Determining loop execution counts for: @ult
; CHECK-NEXT: Loop %loop: backedge-taken count is (-1 + %count)
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 -2
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (-1 + %count)
; CHECK-NEXT: Loop %loop: Trip multiple is 1
entry:
br i1 %c, label %b1, label %b2

b1:
; Note the successor order: the true edge leaves the function.
%cmp1 = icmp ult i16 %a, 8
br i1 %cmp1, label %exit, label %preheader

b2:
%cmp2 = icmp ult i16 %b, 8
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
br i1 %cmp2, label %exit, label %preheader

preheader:
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
br label %loop

loop:
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
%iv.next = add i16 %iv, -1
%exitcond = icmp eq i16 %iv.next, 0
br i1 %exitcond, label %exit, label %loop

exit:
ret void

}

; Two-predecessor preheader with signed greater-than guards: %preheader is
; reached from %b1 on the true edge of 'icmp sgt %a, 8' and from %b2 on the
; true edge of 'icmp sgt %b, 8'. The loop start value %count is the phi of %a
; and %b; the loop decrements by 1 and exits when %iv.next slt 0.
define void @sgt(i16 %a, i16 %b, i1 %c) {
; CHECK-LABEL: 'sgt'
; CHECK-NEXT: Determining loop execution counts for: @sgt
; CHECK-NEXT: Loop %loop: backedge-taken count is %count
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 32767
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is %count
; CHECK-NEXT: Loop %loop: Trip multiple is 1
entry:
br i1 %c, label %b1, label %b2

b1:
%cmp1 = icmp sgt i16 %a, 8
br i1 %cmp1, label %preheader, label %exit

b2:
%cmp2 = icmp sgt i16 %b, 8
juliannagele marked this conversation as resolved.
Show resolved Hide resolved
br i1 %cmp2, label %preheader, label %exit

preheader:
; Merge of the two guarded values; each incoming value has its own guard.
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
br label %loop

loop:
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
%iv.next = add i16 %iv, -1
%exitcond = icmp slt i16 %iv.next, 0
br i1 %exitcond, label %exit, label %loop

exit:
ret void
}


; Two-predecessor preheader whose guards use different predicate kinds: %b1
; guards %a with a signed compare ('icmp slt %a, 8') and %b2 guards %b with an
; unsigned compare ('icmp ult %b, 8'). The loop is the same incrementing loop
; that continues while %iv.next slt 64. The expected backedge-taken count
; below keeps the '64 smax (1 + %count)' term.
; NOTE(review): presumably this is the negative case where the two guards
; cannot be merged into one condition on %count - confirm against the analysis.
define void @mixed(i16 %a, i16 %b, i1 %c) {
; CHECK-LABEL: 'mixed'
; CHECK-NEXT: Determining loop execution counts for: @mixed
; CHECK-NEXT: Loop %loop: backedge-taken count is (-1 + (-1 * %count) + (64 smax (1 + %count)))
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 -32704
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (-1 + (-1 * %count) + (64 smax (1 + %count)))
; CHECK-NEXT: Loop %loop: Trip multiple is 1
entry:
br i1 %c, label %b1, label %b2

b1:
; Signed guard on %a.
%cmp1 = icmp slt i16 %a, 8
br i1 %cmp1, label %preheader, label %exit

b2:
; Unsigned guard on %b.
%cmp2 = icmp ult i16 %b, 8
br i1 %cmp2, label %preheader, label %exit

preheader:
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
br label %loop

loop:
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
%iv.next = add i16 %iv, 1
%exitcond = icmp slt i16 %iv.next, 64
br i1 %exitcond, label %loop, label %exit

exit:
ret void

juliannagele marked this conversation as resolved.
Show resolved Hide resolved
}
2 changes: 1 addition & 1 deletion llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ define void @apply_delta(ptr nocapture noundef %dst, ptr nocapture noundef reado
; CHECK-NEXT: [[INCDEC_PTR]] = getelementptr inbounds i8, ptr [[DST_ADDR_130]], i64 1
; CHECK-NEXT: [[INCDEC_PTR8]] = getelementptr inbounds i8, ptr [[SRC_ADDR_129]], i64 1
; CHECK-NEXT: [[TOBOOL_NOT:%.*]] = icmp eq i64 [[DEC]], 0
; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]]
; CHECK: while.end9:
; CHECK-NEXT: ret void
;
Expand Down
Loading