diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 0880f9c65aa45..fa331c93d712e 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -460,6 +460,10 @@ class ScalarEvolution { LoopComputable ///< The SCEV varies predictably with the loop. }; + bool AssumeLoopFinite = false; + void setAssumeLoopExits(); + SmallPtrSet GuaranteedUnreachable; + /// An enum describing the relationship between a SCEV and a basic block. enum BlockDisposition { DoesNotDominateBlock, ///< The SCEV does not dominate the block. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 4b2db80bc1ec3..f261623dcd069 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -509,6 +509,8 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) { return S; } +void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopFinite = true; } + SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} @@ -7413,7 +7415,8 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. 
- return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); + return AssumeLoopFinite || isFinite(L) || + (isMustProgress(L) && loopHasNoSideEffects(L)); } const SCEV *ScalarEvolution::createSCEVIter(Value *V) { @@ -8828,6 +8831,26 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates) { + if (AssumeLoopFinite) { // NOTE(review): the filtered list is never read + SmallVector ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (auto &ExitingBlock : ExitingBlocks) { + BasicBlock *Exit = nullptr; + for (auto *SBB : successors(ExitingBlock)) { + if (!L->contains(SBB)) { + if (GuaranteedUnreachable.count(SBB)) + continue; + Exit = SBB; + break; + } + } + if (!Exit) + ExitingBlock = nullptr; + } + ExitingBlocks.erase( + std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr), + ExitingBlocks.end()); + } assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); // If our exiting block does not dominate the latch, then its connection with // loop's exit limit may be far from trivial. @@ -8853,6 +8876,8 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, BasicBlock *Exit = nullptr; for (auto *SBB : successors(ExitingBlock)) if (!L->contains(SBB)) { + if (AssumeLoopFinite && GuaranteedUnreachable.count(SBB)) + continue; if (Exit) // Multiple exit successors. return getCouldNotCompute(); Exit = SBB; @@ -8923,6 +8948,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { + // Handle BinOp conditions (And, Or). 
if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) @@ -8950,6 +8976,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( if (ExitIfTrue == !CI->getZExtValue()) // The backedge is always taken. return getCouldNotCompute(); + // The backedge is never taken. return getZero(CI->getType()); } @@ -8961,9 +8988,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( const APInt *C; if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) && match(WO->getRHS(), m_APInt(C))) { - ConstantRange NWR = - ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C, - WO->getNoWrapKind()); + ConstantRange NWR = ConstantRange::makeExactNoWrapRegion( + WO->getBinaryOp(), *C, WO->getNoWrapKind()); CmpInst::Predicate Pred; APInt NewRHSC, Offset; NWR.getEquivalentICmp(Pred, NewRHSC, Offset); @@ -9019,6 +9045,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( const SCEV *SymbolicMaxBECount = getCouldNotCompute(); if (EitherMayExit) { bool UseSequentialUMin = !isa(ExitCond); + // Both conditions must be same for the loop to continue executing. // Choose the less conservative count. if (EL0.ExactNotTaken != getCouldNotCompute() && @@ -9026,6 +9053,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken, UseSequentialUMin); } + if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) ConstantMaxBECount = EL1.ConstantMaxNotTaken; else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) @@ -9045,6 +9073,12 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( // For now, be conservative. 
if (EL0.ExactNotTaken == EL1.ExactNotTaken) BECount = EL0.ExactNotTaken; + // This was executed in Enzyme's must exit code under the + // logic for when the binary op was OR + if (AssumeLoopFinite && !IsAnd) { + if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken) + ConstantMaxBECount = EL0.ConstantMaxNotTaken; + } } // There are cases (e.g. PR26207) where computeExitLimitFromCond is able @@ -9053,12 +9087,14 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( // and // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and // EL1.ConstantMaxNotTaken to not. - if (isa(ConstantMaxBECount) && - !isa(BECount)) - ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - if (isa(SymbolicMaxBECount)) - SymbolicMaxBECount = - isa(BECount) ? ConstantMaxBECount : BECount; + if (!AssumeLoopFinite || !IsAnd) { // NOTE(review): meant to skip for OR, but this skips for AND; confirm + if (isa(ConstantMaxBECount) && + !isa(BECount)) + ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); + if (isa(SymbolicMaxBECount)) + SymbolicMaxBECount = + isa(BECount) ? 
ConstantMaxBECount : BECount; + } return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false, { &EL0.Predicates, &EL1.Predicates }); } @@ -9082,8 +9118,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( if (EL.hasAnyInfo()) return EL; - auto *ExhaustiveCount = - computeExitCountExhaustively(L, ExitCond, ExitIfTrue); + auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); if (!isa(ExhaustiveCount)) return ExhaustiveCount; @@ -9094,7 +9129,31 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, bool ControlsOnlyExit, bool AllowPredicates) { - + if (AssumeLoopFinite) { +#define PROP_PHI(LHS) \ + if (auto un = dyn_cast(LHS)) { \ + if (auto pn = dyn_cast_or_null(un->getValue())) { \ + const SCEV *sc = nullptr; \ + bool failed = false; \ + for (auto &a : pn->incoming_values()) { \ + auto subsc = getSCEV(a); \ + if (sc == nullptr) { \ + sc = subsc; \ + continue; \ + } \ + if (subsc != sc) { \ + failed = true; \ + break; \ + } \ + } \ + if (!failed) { \ + LHS = sc; \ + } \ + } \ + } + PROP_PHI(LHS) + PROP_PHI(RHS) + } // Try to evaluate any dependencies out of the loop. LHS = getSCEVAtScope(LHS, L); RHS = getSCEVAtScope(RHS, L); @@ -9107,6 +9166,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( Pred = ICmpInst::getSwappedPredicate(Pred); } + // was not present in Enzyme code, the last condition is true if + // AssumeLoopExits is true + // will the first two checks cause enzyme to fail? bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L); // Simplify the operands before analyzing them. 
@@ -9184,15 +9246,19 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( if (EL.hasAnyInfo()) return EL; break; } + case ICmpInst::ICMP_SLE: case ICmpInst::ICMP_ULE: - // Since the loop is finite, an invariant RHS cannot include the boundary - // value, otherwise it would loop forever. - if (!EnableFiniteLoopControl || !ControllingFiniteLoop || - !isLoopInvariant(RHS, L)) - break; - RHS = getAddExpr(getOne(RHS->getType()), RHS); + if (!AssumeLoopFinite) { + // Since the loop is finite, an invariant RHS cannot include the boundary + // value, otherwise it would loop forever. NOTE(review): with AssumeLoopFinite set, RHS is not incremented before the fallthrough to the X < Y case; confirm this is not off by one. + if (!EnableFiniteLoopControl || !ControllingFiniteLoop || + !isLoopInvariant(RHS, L)) + break; + RHS = getAddExpr(getOne(RHS->getType()), RHS); + } [[fallthrough]]; + case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = ICmpInst::isSigned(Pred); @@ -9204,16 +9270,33 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( } case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_UGE: - // Since the loop is finite, an invariant RHS cannot include the boundary - // value, otherwise it would loop forever. - if (!EnableFiniteLoopControl || !ControllingFiniteLoop || - !isLoopInvariant(RHS, L)) - break; - RHS = getAddExpr(getMinusOne(RHS->getType()), RHS); + if (!AssumeLoopFinite) { + // Since the loop is finite, an invariant RHS cannot include the boundary + // value, otherwise it would loop forever. 
+ if (!EnableFiniteLoopControl || !ControllingFiniteLoop || + !isLoopInvariant(RHS, L)) + break; + RHS = getAddExpr(getMinusOne(RHS->getType()), RHS); + } [[fallthrough]]; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = ICmpInst::isSigned(Pred); + if (AssumeLoopFinite) { + if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) { + if (!isa(RHS->getType())) + break; + SmallVector sv = { + RHS, getConstant( + ConstantInt::get(cast(RHS->getType()), -1))}; + // Since this is not an infinite loop by induction, RHS cannot be + // int_min/uint_min Therefore subtracting 1 does not wrap. + if (IsSigned) + RHS = getAddExpr(sv, SCEV::FlagNSW); + else + RHS = getAddExpr(sv, SCEV::FlagNUW); + } + } ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) @@ -9238,8 +9321,14 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, if (Switch->getDefaultDest() == ExitingBlock) return getCouldNotCompute(); - assert(L->contains(Switch->getDefaultDest()) && - "Default case must not exit the loop!"); + // if not using enzyme executes by default + // if using enzyme and the code is guaranteed unreachable, + // the default destination doesn't matter + if (!AssumeLoopFinite || + !GuaranteedUnreachable.count(Switch->getDefaultDest())) { + assert(L->contains(Switch->getDefaultDest()) && + "Default case must not exit the loop!"); + } const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); @@ -12752,9 +12841,9 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // If RHS <=u Limit, then there must exist a value V in the sequence // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and // V <=u UINT_MAX. Thus, we must exit the loop before unsigned - // overflow occurs. 
This limit also implies that a signed comparison - // (in the wide bitwidth) is equivalent to an unsigned comparison as - // the high bits on both sides must be zero. + // overflow occurs. This limit also implies that a signed + // comparison (in the wide bitwidth) is equivalent to an unsigned + // comparison as the high bits on both sides must be zero. APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this)); APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1); Limit = Limit.zext(OuterBitWidth); @@ -12765,6 +12854,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, Flags = setFlags(Flags, SCEV::FlagNUW); setNoWrapFlags(const_cast(AR), Flags); + if (AR->hasNoUnsignedWrap()) { // Emulate what getZeroExtendExpr would have done during construction // if we'd been able to infer the fact just above at that time. @@ -12848,6 +12938,13 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, !loopHasNoAbnormalExits(L)) return getCouldNotCompute(); + // This bailout is protecting the logic in computeMaxBECountForLT which + // has not yet been sufficiently audited or tested with negative strides. + // We used to filter out all known-non-positive cases here, we're in the + // process of being less restrictive bit by bit. + if (AssumeLoopFinite && IsSigned && isKnownNonPositive(Stride)) + return getCouldNotCompute(); + if (!isKnownNonZero(Stride)) { // If we have a step of zero, and RHS isn't invariant in L, we don't know // if it might eventually be greater than start and if so, on which @@ -12977,13 +13074,20 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!BECount) { auto canProveRHSGreaterThanEqualStart = [&]() { auto CondGE = IsSigned ? 
ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); - const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); - if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) || - isKnownPredicate(CondGE, GuardedRHS, GuardedStart)) + if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) { return true; - + } + // In the Enzyme MustExitScalarEvolutionCode, this check was missing + // I do not have enough context to know if these two checks should be + // mutually Exclusive. If they aren't then this bool check is unnecessary + if (!AssumeLoopFinite) { + const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); + const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); + + if (isKnownPredicate(CondGE, GuardedRHS, GuardedStart)) + return true; + } // (RHS > Start - 1) implies RHS >= Start. // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if // "Start - 1" doesn't overflow. @@ -13120,7 +13224,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa(ConstantMaxBECount) && !isa(BECount)) ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - + if (AssumeLoopFinite) { + return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero, + Predicates); + } const SCEV *SymbolicMaxBECount = isa(BECount) ? ConstantMaxBECount : BECount; return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,