Skip to content

Commit abea562

Browse files
committed
[LV] Vectorize conditional scalar assignments
Based on Michael Maitland's previous work: #121222 This PR uses the existing recurrences code instead of introducing a new pass just for CSA autovec. I've also made recipes that are more generic. I've enabled it by default to see the impact on tests; if there are regressions we can put it behind a cli option.
1 parent f6d6d2d commit abea562

18 files changed

+1895
-266
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ enum class RecurKind {
7070
FindLastIVUMax, ///< FindLast reduction with select(cmp(),x,y) where one of
7171
///< (x,y) is increasing loop induction, and both x and y
7272
///< are integer type, producing a UMax reduction.
73+
FindLast, ///< FindLast reduction with select(cmp(),x,y) where x and y
74+
///< are an integer type, one is the current recurrence value,
75+
///< and the other is an arbitrary value.
7376
// clang-format on
7477
// TODO: Any_of and FindLast reduction need not be restricted to integer type
7578
// only.
@@ -175,13 +178,12 @@ class RecurrenceDescriptor {
175178
/// Returns a struct describing whether the instruction is either a
176179
/// Select(ICmp(A, B), X, Y), or
177180
/// Select(FCmp(A, B), X, Y)
178-
/// where one of (X, Y) is an increasing (FindLast) or decreasing (FindFirst)
179-
/// loop induction variable, and the other is a PHI value.
180-
// TODO: Support non-monotonic variable. FindLast does not need be restricted
181-
// to increasing loop induction variables.
182-
LLVM_ABI static InstDesc isFindIVPattern(RecurKind Kind, Loop *TheLoop,
183-
PHINode *OrigPhi, Instruction *I,
184-
ScalarEvolution &SE);
181+
/// where one of (X, Y) is an increasing (FindLastIV) or decreasing
182+
/// (FindFirstIV) loop induction variable, or an arbitrary integer value
183+
/// (FindLast), and the other is a PHI value.
184+
LLVM_ABI static InstDesc isFindPattern(RecurKind Kind, Loop *TheLoop,
185+
PHINode *OrigPhi, Instruction *I,
186+
ScalarEvolution &SE);
185187

186188
/// Returns a struct describing if the instruction is a
187189
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
@@ -305,6 +307,13 @@ class RecurrenceDescriptor {
305307
isFindLastIVRecurrenceKind(Kind);
306308
}
307309

310+
/// Returns true if the recurrence kind is of the form
311+
/// select(cmp(),x,y) where one of (x,y) is an arbitrary value and the
312+
/// other is a recurrence.
313+
static bool isFindLastRecurrenceKind(RecurKind Kind) {
314+
return Kind == RecurKind::FindLast;
315+
}
316+
308317
/// Returns the type of the recurrence. This type can be narrower than the
309318
/// actual type of the Phi if the recurrence has been type-promoted.
310319
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5656
case RecurKind::FindFirstIVUMin:
5757
case RecurKind::FindLastIVSMax:
5858
case RecurKind::FindLastIVUMax:
59+
// TODO: Make type-agnostic.
60+
case RecurKind::FindLast:
5961
return true;
6062
}
6163
return false;
@@ -691,9 +693,9 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
691693
// value of the data type or a non-constant value by using mask and multiple
692694
// reduction operations.
693695
RecurrenceDescriptor::InstDesc
694-
RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
695-
PHINode *OrigPhi, Instruction *I,
696-
ScalarEvolution &SE) {
696+
RecurrenceDescriptor::isFindPattern(RecurKind Kind, Loop *TheLoop,
697+
PHINode *OrigPhi, Instruction *I,
698+
ScalarEvolution &SE) {
697699
// TODO: Support the vectorization of FindLastIV when the reduction phi is
698700
// used by more than one select instruction. This vectorization is only
699701
// performed when the SCEV of each increasing induction variable used by the
@@ -702,8 +704,10 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
702704
return InstDesc(false, I);
703705

704706
// We are looking for selects of the form:
705-
// select(cmp(), phi, loop_induction) or
706-
// select(cmp(), loop_induction, phi)
707+
// select(cmp(), phi, value) or
708+
// select(cmp(), value, phi)
709+
// where 'value' might be a loop induction variable
710+
// (for FindFirstIV/FindLastIV) or an arbitrary value (for FindLast).
707711
// TODO: Match selects with multi-use cmp conditions.
708712
Value *NonRdxPhi = nullptr;
709713
if (!match(I, m_CombineOr(m_Select(m_OneUse(m_Cmp()), m_Value(NonRdxPhi),
@@ -712,6 +716,25 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
712716
m_Value(NonRdxPhi)))))
713717
return InstDesc(false, I);
714718

719+
if (isFindLastRecurrenceKind(Kind)) {
720+
// Must be an integer scalar.
721+
Type *Type = OrigPhi->getType();
722+
if (!Type->isIntegerTy() && !Type->isPointerTy())
723+
return InstDesc(false, I);
724+
725+
// FIXME: Support more complex patterns, including multiple selects.
726+
// The Select must be used only outside the loop and by the PHI.
727+
for (User *U : I->users()) {
728+
if (U == OrigPhi)
729+
continue;
730+
if (auto *UI = dyn_cast<Instruction>(U); UI && !TheLoop->contains(UI))
731+
continue;
732+
return InstDesc(false, I);
733+
}
734+
735+
return InstDesc(I, RecurKind::FindLast);
736+
}
737+
715738
// Returns either FindFirstIV/FindLastIV, if such a pattern is found, or
716739
// std::nullopt.
717740
auto GetRecurKind = [&](Value *V) -> std::optional<RecurKind> {
@@ -920,8 +943,8 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
920943
Kind == RecurKind::Add || Kind == RecurKind::Mul ||
921944
Kind == RecurKind::Sub || Kind == RecurKind::AddChainWithSubs)
922945
return isConditionalRdxPattern(I);
923-
if (isFindIVRecurrenceKind(Kind) && SE)
924-
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
946+
if ((isFindIVRecurrenceKind(Kind) || isFindLastRecurrenceKind(Kind)) && SE)
947+
return isFindPattern(Kind, L, OrigPhi, I, *SE);
925948
[[fallthrough]];
926949
case Instruction::FCmp:
927950
case Instruction::ICmp:
@@ -1118,7 +1141,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
11181141
<< "\n");
11191142
return true;
11201143
}
1121-
1144+
if (AddReductionVar(Phi, RecurKind::FindLast, TheLoop, FMF, RedDes, DB, AC,
1145+
DT, SE)) {
1146+
LLVM_DEBUG(dbgs() << "Found a FindLast reduction PHI." << *Phi << "\n");
1147+
return true;
1148+
}
11221149
// Not a reduction of known type.
11231150
return false;
11241151
}
@@ -1248,6 +1275,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12481275
case RecurKind::FMaximumNum:
12491276
case RecurKind::FMinimumNum:
12501277
return Instruction::FCmp;
1278+
case RecurKind::FindLast:
1279+
return Instruction::Select;
12511280
default:
12521281
llvm_unreachable("Unknown recurrence operation");
12531282
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5373,6 +5373,7 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
53735373
case RecurKind::FMax:
53745374
case RecurKind::FMulAdd:
53755375
case RecurKind::AnyOf:
5376+
case RecurKind::FindLast:
53765377
return true;
53775378
default:
53785379
return false;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4559,6 +4559,12 @@ LoopVectorizationPlanner::selectInterleaveCount(VPlan &Plan, ElementCount VF,
45594559
any_of(Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis(),
45604560
IsaPred<VPReductionPHIRecipe>);
45614561

4562+
// FIXME: implement interleaving for FindLast transform correctly.
4563+
for (auto &[_, RdxDesc] : Legal->getReductionVars())
4564+
if (RecurrenceDescriptor::isFindLastRecurrenceKind(
4565+
RdxDesc.getRecurrenceKind()))
4566+
return 1;
4567+
45624568
// If we did not calculate the cost for VF (because the user selected the VF)
45634569
// then we calculate the cost of VF here.
45644570
if (LoopCost == 0) {
@@ -8475,6 +8481,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
84758481
*Plan, Builder))
84768482
return nullptr;
84778483

8484+
// Create whole-vector selects for find-last recurrences.
8485+
VPlanTransforms::runPass(VPlanTransforms::convertFindLastRecurrences, *Plan,
8486+
RecipeBuilder, Legal);
8487+
84788488
if (useActiveLaneMask(Style)) {
84798489
// TODO: Move checks to VPlanTransforms::addActiveLaneMask once
84808490
// TailFoldingStyle is visible there.
@@ -8569,6 +8579,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
85698579

85708580
RecurKind Kind = PhiR->getRecurrenceKind();
85718581
assert(
8582+
!RecurrenceDescriptor::isFindLastRecurrenceKind(Kind) &&
85728583
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
85738584
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
85748585
"AnyOf and FindIV reductions are not allowed for in-loop reductions");
@@ -8872,7 +8883,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
88728883
RecurKind RK = RdxDesc.getRecurrenceKind();
88738884
if ((!RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) &&
88748885
!RecurrenceDescriptor::isFindIVRecurrenceKind(RK) &&
8875-
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))) {
8886+
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) &&
8887+
!RecurrenceDescriptor::isFindLastRecurrenceKind(RK))) {
88768888
VPBuilder PHBuilder(Plan->getVectorPreheader());
88778889
VPValue *Iden = Plan->getOrAddLiveIn(
88788890
getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags()));
@@ -9996,6 +10008,21 @@ bool LoopVectorizePass::processLoop(Loop *L) {
999610008
// Override IC if user provided an interleave count.
999710009
IC = UserIC > 0 ? UserIC : IC;
999810010

10011+
// FIXME: Enable interleaving for last_active reductions.
10012+
if (any_of(make_second_range(LVL.getReductionVars()), [&](auto &RdxDesc) {
10013+
return RecurrenceDescriptor::isFindLastRecurrenceKind(
10014+
RdxDesc.getRecurrenceKind());
10015+
})) {
10016+
LLVM_DEBUG(dbgs() << "LV: Not interleaving without vectorization due "
10017+
<< "to conditional scalar assignments.\n");
10018+
IntDiagMsg = {
10019+
"ConditionalAssignmentPreventsScalarInterleaving",
10020+
"Unable to interleave without vectorization due to conditional "
10021+
"assignments"};
10022+
InterleaveLoop = false;
10023+
IC = 1;
10024+
}
10025+
999910026
// Emit diagnostic messages, if any.
1000010027
const char *VAPassName = Hints.vectorizeAnalysisPassName();
1000110028
if (!VectorizeLoop && !InterleaveLoop) {

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25123,6 +25123,7 @@ class HorizontalReduction {
2512325123
case RecurKind::FindFirstIVUMin:
2512425124
case RecurKind::FindLastIVSMax:
2512525125
case RecurKind::FindLastIVUMax:
25126+
case RecurKind::FindLast:
2512625127
case RecurKind::FMaxNum:
2512725128
case RecurKind::FMinNum:
2512825129
case RecurKind::FMaximumNum:
@@ -25264,6 +25265,7 @@ class HorizontalReduction {
2526425265
case RecurKind::FindFirstIVUMin:
2526525266
case RecurKind::FindLastIVSMax:
2526625267
case RecurKind::FindLastIVUMax:
25268+
case RecurKind::FindLast:
2526725269
case RecurKind::FMaxNum:
2526825270
case RecurKind::FMinNum:
2526925271
case RecurKind::FMaximumNum:
@@ -25370,6 +25372,7 @@ class HorizontalReduction {
2537025372
case RecurKind::FindFirstIVUMin:
2537125373
case RecurKind::FindLastIVSMax:
2537225374
case RecurKind::FindLastIVUMax:
25375+
case RecurKind::FindLast:
2537325376
case RecurKind::FMaxNum:
2537425377
case RecurKind::FMinNum:
2537525378
case RecurKind::FMaximumNum:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,8 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10671067
/// Returns the value for vscale.
10681068
VScale,
10691069
OpsEnd = VScale,
1070+
/// Extracts the last active lane based on a predicate vector operand.
1071+
ExtractLastActive,
10701072
};
10711073

10721074
/// Returns true if this VPInstruction generates scalar values for all lanes.

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
118118
return Type::getIntNTy(Ctx, 64);
119119
case VPInstruction::ExtractLastElement:
120120
case VPInstruction::ExtractLastLanePerPart:
121-
case VPInstruction::ExtractPenultimateElement: {
121+
case VPInstruction::ExtractPenultimateElement:
122+
case VPInstruction::ExtractLastActive: {
122123
Type *BaseTy = inferScalarType(R->getOperand(0));
123124
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
124125
return VecTy->getElementType();

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
546546
case VPInstruction::ActiveLaneMask:
547547
case VPInstruction::ComputeAnyOfResult:
548548
case VPInstruction::ReductionStartVector:
549+
case VPInstruction::ExtractLastActive:
549550
return 3;
550551
case VPInstruction::ComputeFindIVResult:
551552
return 4;
@@ -999,6 +1000,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
9991000
}
10001001
case VPInstruction::ResumeForEpilogue:
10011002
return State.get(getOperand(0), true);
1003+
case VPInstruction::ExtractLastActive: {
1004+
Value *Data = State.get(getOperand(0));
1005+
Value *Mask = State.get(getOperand(1));
1006+
Value *Default = State.get(getOperand(2), /*IsScalar=*/true);
1007+
Type *VTy = Data->getType();
1008+
1009+
Module *M = State.Builder.GetInsertBlock()->getModule();
1010+
Function *ExtractLast = Intrinsic::getOrInsertDeclaration(
1011+
M, Intrinsic::experimental_vector_extract_last_active, {VTy});
1012+
return Builder.CreateCall(ExtractLast, {Data, Mask, Default});
1013+
}
10021014
default:
10031015
llvm_unreachable("Unsupported opcode for instruction");
10041016
}
@@ -1135,6 +1147,15 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11351147
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
11361148
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
11371149
}
1150+
case VPInstruction::ExtractLastActive: {
1151+
Type *ScalarTy = Ctx.Types.inferScalarType(this);
1152+
Type *VecTy = toVectorTy(ScalarTy, VF);
1153+
Type *MaskTy = toVectorTy(Type::getInt1Ty(Ctx.LLVMCtx), VF);
1154+
IntrinsicCostAttributes ICA(
1155+
Intrinsic::experimental_vector_extract_last_active, ScalarTy,
1156+
{VecTy, MaskTy, ScalarTy});
1157+
return Ctx.TTI.getIntrinsicInstrCost(ICA, Ctx.CostKind);
1158+
}
11381159
case VPInstruction::FirstOrderRecurrenceSplice: {
11391160
assert(VF.isVector() && "Scalar FirstOrderRecurrenceSplice?");
11401161
SmallVector<int> Mask(VF.getKnownMinValue());
@@ -1191,6 +1212,7 @@ bool VPInstruction::isVectorToScalar() const {
11911212
getOpcode() == VPInstruction::FirstActiveLane ||
11921213
getOpcode() == VPInstruction::ComputeAnyOfResult ||
11931214
getOpcode() == VPInstruction::ComputeFindIVResult ||
1215+
getOpcode() == VPInstruction::ExtractLastActive ||
11941216
getOpcode() == VPInstruction::ComputeReductionResult ||
11951217
getOpcode() == VPInstruction::AnyOf;
11961218
}
@@ -1252,6 +1274,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
12521274
case VPInstruction::ExtractPenultimateElement:
12531275
case VPInstruction::ActiveLaneMask:
12541276
case VPInstruction::FirstActiveLane:
1277+
case VPInstruction::ExtractLastActive:
12551278
case VPInstruction::FirstOrderRecurrenceSplice:
12561279
case VPInstruction::LogicalAnd:
12571280
case VPInstruction::Not:
@@ -1437,6 +1460,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
14371460
case VPInstruction::Unpack:
14381461
O << "unpack";
14391462
break;
1463+
case VPInstruction::ExtractLastActive:
1464+
O << "extract-last-active";
1465+
break;
14401466
default:
14411467
O << Instruction::getOpcodeName(getOpcode());
14421468
}

0 commit comments

Comments
 (0)