Skip to content

Commit

Permalink
[LoopVectorize] Vectorize select-cmp reduction pattern for increasing…
Browse files Browse the repository at this point in the history
… integer induction variable (llvm#67812)

Consider the following loop:
```
  int rdx = init;
  for (int i = 0; i < n; ++i)
    rdx = (a[i] > b[i]) ? i : rdx;
```
We can vectorize this loop if `i` is an increasing induction variable.
The final reduced value will be the maximum of `i` that the condition
`a[i] > b[i]` is satisfied, or the start value `init`.

This patch added new RecurKind enums - IFindLastIV and FFindLastIV.

---------

Co-authored-by: Alexey Bataev <5361294+alexey-bataev@users.noreply.github.com>
  • Loading branch information
Mel-Chen and alexey-bataev authored Dec 12, 2024
1 parent 0876c11 commit b3cba9b
Show file tree
Hide file tree
Showing 12 changed files with 4,115 additions and 717 deletions.
38 changes: 35 additions & 3 deletions llvm/include/llvm/Analysis/IVDescriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,16 @@ enum class RecurKind {
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
IAnyOf, ///< Any_of reduction with select(icmp(),x,y) where one of (x,y) is
///< loop invariant, and both x and y are integer type.
FAnyOf ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
FAnyOf, ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
///< loop invariant, and both x and y are integer type.
// TODO: Any_of reduction need not be restricted to integer type only.
IFindLastIV, ///< FindLast reduction with select(icmp(),x,y) where one of
///< (x,y) is increasing loop induction, and both x and y are
///< integer type.
FFindLastIV ///< FindLast reduction with select(fcmp(),x,y) where one of (x,y)
///< is increasing loop induction, and both x and y are integer
///< type.
// TODO: Any_of and FindLast reduction need not be restricted to integer type
// only.
};

/// The RecurrenceDescriptor is used to identify recurrences variables in a
Expand Down Expand Up @@ -124,7 +131,7 @@ class RecurrenceDescriptor {
/// the returned struct.
static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
RecurKind Kind, InstDesc &Prev,
FastMathFlags FuncFMF);
FastMathFlags FuncFMF, ScalarEvolution *SE);

/// Returns true if instruction I has multiple uses in Insts
static bool hasMultipleUsesOf(Instruction *I,
Expand All @@ -151,6 +158,16 @@ class RecurrenceDescriptor {
static InstDesc isAnyOfPattern(Loop *Loop, PHINode *OrigPhi, Instruction *I,
InstDesc &Prev);

/// Returns a struct describing whether the instruction is either a
/// Select(ICmp(A, B), X, Y), or
/// Select(FCmp(A, B), X, Y)
/// where one of (X, Y) is an increasing loop induction variable, and the
/// other is a PHI value.
// TODO: Support non-monotonic variable. FindLast does not need be restricted
// to increasing loop induction variables.
static InstDesc isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
Instruction *I, ScalarEvolution &SE);

/// Returns a struct describing if the instruction is a
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
Expand Down Expand Up @@ -236,10 +253,25 @@ class RecurrenceDescriptor {
return Kind == RecurKind::IAnyOf || Kind == RecurKind::FAnyOf;
}

/// Returns true if the recurrence kind is of the form
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
return Kind == RecurKind::IFindLastIV || Kind == RecurKind::FFindLastIV;
}

/// Returns the type of the recurrence. This type can be narrower than the
/// actual type of the Phi if the recurrence has been type-promoted.
Type *getRecurrenceType() const { return RecurrenceType; }

/// Returns the sentinel value for FindLastIV recurrences to replace the start
/// value.
Value *getSentinelValue() const {
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
Type *Ty = StartValue->getType();
return ConstantInt::get(Ty,
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
}

/// Returns a reference to the instructions used for type-promoting the
/// recurrence.
const SmallPtrSet<Instruction *, 8> &getCastInsts() const { return CastInsts; }
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,12 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc,
PHINode *OrigPhi);

/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
/// operation is described by \p Desc.
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc);

/// Create a generic reduction using a recurrence descriptor \p Desc
/// Fast-math-flags are propagated using the RecurrenceDescriptor.
Value *createReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
Expand Down
113 changes: 108 additions & 5 deletions llvm/lib/Analysis/IVDescriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
case RecurKind::UMin:
case RecurKind::IAnyOf:
case RecurKind::FAnyOf:
case RecurKind::IFindLastIV:
case RecurKind::FFindLastIV:
return true;
}
return false;
Expand Down Expand Up @@ -372,7 +374,7 @@ bool RecurrenceDescriptor::AddReductionVar(
// type-promoted).
if (Cur != Start) {
ReduxDesc =
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
ExactFPMathInst = ExactFPMathInst == nullptr
? ReduxDesc.getExactFPMathInst()
: ExactFPMathInst;
Expand Down Expand Up @@ -658,6 +660,95 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
: RecurKind::FAnyOf);
}

// We are looking for loops that do something like this:
// int r = 0;
// for (int i = 0; i < n; i++) {
// if (src[i] > 3)
// r = i;
// }
// The reduction value (r) is derived from either the values of an increasing
// induction variable (i) sequence, or from the start value (0).
// The LLVM IR generated for such loops would be as follows:
// for.body:
// %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
// %i = phi i32 [ %inc, %for.body ], [ 0, %entry ]
// ...
// %cmp = icmp sgt i32 %5, 3
// %spec.select = select i1 %cmp, i32 %i, i32 %r
// %inc = add nsw i32 %i, 1
// ...
// Since 'i' is an increasing induction variable, the reduction value after the
// loop will be the maximum value of 'i' that the condition (src[i] > 3) is
// satisfied, or the start value (0 in the example above). When the start value
// of the increasing induction variable 'i' is greater than the minimum value of
// the data type, we can use the minimum value of the data type as a sentinel
// value to replace the start value. This allows us to perform a single
// reduction max operation to obtain the final reduction result.
// TODO: It is possible to solve the case where the start value is the minimum
// value of the data type or a non-constant value by using mask and multiple
// reduction operations.
RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
Instruction *I, ScalarEvolution &SE) {
// TODO: Support the vectorization of FindLastIV when the reduction phi is
// used by more than one select instruction. This vectorization is only
// performed when the SCEV of each increasing induction variable used by the
// select instructions is identical.
if (!OrigPhi->hasOneUse())
return InstDesc(false, I);

// TODO: Match selects with multi-use cmp conditions.
Value *NonRdxPhi = nullptr;
if (!match(I, m_CombineOr(m_Select(m_OneUse(m_Cmp()), m_Value(NonRdxPhi),
m_Specific(OrigPhi)),
m_Select(m_OneUse(m_Cmp()), m_Specific(OrigPhi),
m_Value(NonRdxPhi)))))
return InstDesc(false, I);

auto IsIncreasingLoopInduction = [&](Value *V) {
Type *Ty = V->getType();
if (!SE.isSCEVable(Ty))
return false;

auto *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(V));
if (!AR || AR->getLoop() != TheLoop)
return false;

const SCEV *Step = AR->getStepRecurrence(SE);
if (!SE.isKnownPositive(Step))
return false;

const ConstantRange IVRange = SE.getSignedRange(AR);
unsigned NumBits = Ty->getIntegerBitWidth();
// Keep the minimum value of the recurrence type as the sentinel value.
// The maximum acceptable range for the increasing induction variable,
// called the valid range, will be defined as
// [<sentinel value> + 1, <sentinel value>)
// where <sentinel value> is SignedMin(<recurrence type>)
// TODO: This range restriction can be lifted by adding an additional
// virtual OR reduction.
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
const ConstantRange ValidRange =
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
<< ", and the signed range of " << *AR << " is "
<< IVRange << "\n");
// Ensure the induction variable does not wrap around by verifying that its
// range is fully contained within the valid range.
return ValidRange.contains(IVRange);
};

// We are looking for selects of the form:
// select(cmp(), phi, increasing_loop_induction) or
// select(cmp(), increasing_loop_induction, phi)
// TODO: Support for monotonically decreasing induction variable
if (!IsIncreasingLoopInduction(NonRdxPhi))
return InstDesc(false, I);

return InstDesc(I, isa<ICmpInst>(I->getOperand(0)) ? RecurKind::IFindLastIV
: RecurKind::FFindLastIV);
}

RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
const InstDesc &Prev) {
Expand Down Expand Up @@ -756,10 +847,9 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
return InstDesc(true, SI);
}

RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
Instruction *I, RecurKind Kind,
InstDesc &Prev, FastMathFlags FuncFMF) {
RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
FastMathFlags FuncFMF, ScalarEvolution *SE) {
assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
switch (I->getOpcode()) {
default:
Expand Down Expand Up @@ -789,6 +879,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
Kind == RecurKind::Add || Kind == RecurKind::Mul)
return isConditionalRdxPattern(Kind, I);
if (isFindLastIVRecurrenceKind(Kind) && SE)
return isFindLastIVPattern(L, OrigPhi, I, *SE);
[[fallthrough]];
case Instruction::FCmp:
case Instruction::ICmp:
Expand Down Expand Up @@ -893,6 +985,15 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
<< *Phi << "\n");
return true;
}
if (AddReductionVar(Phi, RecurKind::IFindLastIV, TheLoop, FMF, RedDes, DB, AC,
DT, SE)) {
LLVM_DEBUG(dbgs() << "Found a "
<< (RedDes.getRecurrenceKind() == RecurKind::FFindLastIV
? "F"
: "I")
<< "FindLastIV reduction PHI." << *Phi << "\n");
return true;
}
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
SE)) {
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
Expand Down Expand Up @@ -1048,12 +1149,14 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
case RecurKind::UMax:
case RecurKind::UMin:
case RecurKind::IAnyOf:
case RecurKind::IFindLastIV:
return Instruction::ICmp;
case RecurKind::FMax:
case RecurKind::FMin:
case RecurKind::FMaximum:
case RecurKind::FMinimum:
case RecurKind::FAnyOf:
case RecurKind::FFindLastIV:
return Instruction::FCmp;
default:
llvm_unreachable("Unknown recurrence operation");
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,23 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateSelect(AnyOf, NewVal, InitVal, "rdx.select");
}

Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
const RecurrenceDescriptor &Desc) {
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
Desc.getRecurrenceKind()) &&
"Unexpected reduction kind");
Value *StartVal = Desc.getRecurrenceStartValue();
Value *Sentinel = Desc.getSentinelValue();
Value *MaxRdx = Src->getType()->isVectorTy()
? Builder.CreateIntMaxReduce(Src, true)
: Src;
// Correct the final reduction result back to the start value if the maximum
// reduction is sentinel value.
Value *Cmp =
Builder.CreateCmp(CmpInst::ICMP_NE, MaxRdx, Sentinel, "rdx.select.cmp");
return Builder.CreateSelect(Cmp, MaxRdx, StartVal, "rdx.select");
}

Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
FastMathFlags Flags) {
bool Negative = false;
Expand Down Expand Up @@ -1315,6 +1332,8 @@ Value *llvm::createReduction(IRBuilderBase &B,
RecurKind RK = Desc.getRecurrenceKind();
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
return createAnyOfReduction(B, Src, Desc, OrigPhi);
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
return createFindLastIVReduction(B, Src, Desc);

return createSimpleReduction(B, Src, RK);
}
Expand Down
11 changes: 7 additions & 4 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5185,8 +5185,9 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
HasReductions &&
any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool {
const RecurrenceDescriptor &RdxDesc = Reduction.second;
return RecurrenceDescriptor::isAnyOfRecurrenceKind(
RdxDesc.getRecurrenceKind());
RecurKind RK = RdxDesc.getRecurrenceKind();
return RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK);
});
if (HasSelectCmpReductions) {
LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n");
Expand Down Expand Up @@ -9449,8 +9450,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(

const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
RecurKind Kind = RdxDesc.getRecurrenceKind();
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
"AnyOf reductions are not allowed for in-loop reductions");
assert(
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
"AnyOf and FindLast reductions are not allowed for in-loop reductions");

// Collect the chain of "link" recipes for the reduction starting at PhiR.
SetVector<VPSingleDefRecipe *> Worklist;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20451,6 +20451,8 @@ class HorizontalReduction {
case RecurKind::FMulAdd:
case RecurKind::IAnyOf:
case RecurKind::FAnyOf:
case RecurKind::IFindLastIV:
case RecurKind::FFindLastIV:
case RecurKind::None:
llvm_unreachable("Unexpected reduction kind for repeated scalar.");
}
Expand Down Expand Up @@ -20548,6 +20550,8 @@ class HorizontalReduction {
case RecurKind::FMulAdd:
case RecurKind::IAnyOf:
case RecurKind::FAnyOf:
case RecurKind::IFindLastIV:
case RecurKind::FFindLastIV:
case RecurKind::None:
llvm_unreachable("Unexpected reduction kind for reused scalars.");
}
Expand Down
20 changes: 19 additions & 1 deletion llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
ReducedPartRdx = Builder.CreateBinOp(
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
ReducedPartRdx =
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
else
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
}
Expand All @@ -575,7 +578,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
// Create the reduction after the loop. Note that inloop reductions create
// the target reduction in the loop using a Reduction recipe.
if ((State.VF.isVector() ||
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
!PhiR->isInLoop()) {
ReducedPartRdx =
createReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
Expand Down Expand Up @@ -3398,6 +3402,20 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
Builder.SetInsertPoint(VectorPH->getTerminator());
StartV = Iden = State.get(StartVPV);
}
} else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) {
// [I|F]FindLastIV will use a sentinel value to initialize the reduction
// phi. In the exit block, ComputeReductionResult will generate checks to
// verify if the reduction result is the sentinel value. If the result is
// the sentinel value, it will be corrected back to the start value.
// TODO: The sentinel value is not always necessary. When the start value is
// a constant, and smaller than the start value of the induction variable,
// the start value can be directly used to initialize the reduction phi.
StartV = Iden = RdxDesc.getSentinelValue();
if (!ScalarPHI) {
IRBuilderBase::InsertPointGuard IPBuilder(Builder);
Builder.SetInsertPoint(VectorPH->getTerminator());
StartV = Iden = Builder.CreateVectorSplat(State.VF, Iden);
}
} else {
Iden = llvm::getRecurrenceIdentity(RK, VecTy->getScalarType(),
RdxDesc.getFastMathFlags());
Expand Down
Loading

0 comments on commit b3cba9b

Please sign in to comment.