Skip to content

Commit e0aaa19

Browse files
[VectorCombine][RISCV] Convert VPIntrinsics with splat operands to splats of the scalar operation (#65706)
VP intrinsics whose vector operands are both splat values may be simplified into the scalar version of the operation, with the result splatted back into a vector. This transform is the intrinsic dual of #65072.
1 parent db4ba21 commit e0aaa19

File tree

2 files changed

+1564
-0
lines changed

2 files changed

+1564
-0
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class VectorCombine {
102102
bool foldInsExtFNeg(Instruction &I);
103103
bool foldBitcastShuf(Instruction &I);
104104
bool scalarizeBinopOrCmp(Instruction &I);
105+
bool scalarizeVPIntrinsic(Instruction &I);
105106
bool foldExtractedCmps(Instruction &I);
106107
bool foldSingleElementStore(Instruction &I);
107108
bool scalarizeLoadExtract(Instruction &I);
@@ -729,6 +730,111 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
729730
return true;
730731
}
731732

733+
/// VP Intrinsics whose vector operands are both splat values may be simplified
734+
/// into the scalar version of the operation and the result splatted. This
735+
/// can lead to scalarization down the line.
736+
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
737+
if (!isa<VPIntrinsic>(I))
738+
return false;
739+
VPIntrinsic &VPI = cast<VPIntrinsic>(I);
740+
Value *Op0 = VPI.getArgOperand(0);
741+
Value *Op1 = VPI.getArgOperand(1);
742+
743+
if (!isSplatValue(Op0) || !isSplatValue(Op1))
744+
return false;
745+
746+
// For the binary VP intrinsics supported here, the result on disabled lanes
747+
// is a poison value. For now, only do this simplification if all lanes
748+
// are active.
749+
// TODO: Relax the condition that all lanes are active by using insertelement
750+
// on inactive lanes.
751+
auto IsAllTrueMask = [](Value *MaskVal) {
752+
if (Value *SplattedVal = getSplatValue(MaskVal))
753+
if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
754+
return ConstValue->isAllOnesValue();
755+
return false;
756+
};
757+
if (!IsAllTrueMask(VPI.getArgOperand(2)))
758+
return false;
759+
760+
// Check to make sure we support scalarization of the intrinsic
761+
Intrinsic::ID IntrID = VPI.getIntrinsicID();
762+
if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
763+
return false;
764+
765+
// Calculate cost of splatting both operands into vectors and the vector
766+
// intrinsic
767+
VectorType *VecTy = cast<VectorType>(VPI.getType());
768+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
769+
InstructionCost SplatCost =
770+
TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
771+
TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);
772+
773+
// Calculate the cost of the VP Intrinsic
774+
SmallVector<Type *, 4> Args;
775+
for (Value *V : VPI.args())
776+
Args.push_back(V->getType());
777+
IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
778+
InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
779+
InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
780+
781+
// Determine scalar opcode
782+
std::optional<unsigned> FunctionalOpcode =
783+
VPI.getFunctionalOpcode();
784+
std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
785+
if (!FunctionalOpcode) {
786+
ScalarIntrID = VPI.getFunctionalIntrinsicID();
787+
if (!ScalarIntrID)
788+
return false;
789+
}
790+
791+
// Calculate cost of scalarizing
792+
InstructionCost ScalarOpCost = 0;
793+
if (ScalarIntrID) {
794+
IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
795+
ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
796+
} else {
797+
ScalarOpCost =
798+
TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
799+
}
800+
801+
// The existing splats may be kept around if other instructions use them.
802+
InstructionCost CostToKeepSplats =
803+
(SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
804+
InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
805+
806+
LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
807+
<< "\n");
808+
LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
809+
<< ", Cost of scalarizing:" << NewCost << "\n");
810+
811+
// We want to scalarize unless the vector variant actually has lower cost.
812+
if (OldCost < NewCost || !NewCost.isValid())
813+
return false;
814+
815+
// Scalarize the intrinsic
816+
ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
817+
Value *EVL = VPI.getArgOperand(3);
818+
const DataLayout &DL = VPI.getModule()->getDataLayout();
819+
bool MustHaveNonZeroVL =
820+
IntrID == Intrinsic::vp_sdiv || IntrID == Intrinsic::vp_udiv ||
821+
IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem;
822+
823+
if (!MustHaveNonZeroVL || isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT)) {
824+
Value *ScalarOp0 = getSplatValue(Op0);
825+
Value *ScalarOp1 = getSplatValue(Op1);
826+
Value *ScalarVal =
827+
ScalarIntrID
828+
? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
829+
{ScalarOp0, ScalarOp1})
830+
: Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
831+
ScalarOp0, ScalarOp1);
832+
replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
833+
return true;
834+
}
835+
return false;
836+
}
837+
732838
/// Match a vector binop or compare instruction with at least one inserted
733839
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
734840
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
@@ -1737,6 +1843,7 @@ bool VectorCombine::run() {
17371843
if (isa<VectorType>(I.getType())) {
17381844
MadeChange |= scalarizeBinopOrCmp(I);
17391845
MadeChange |= scalarizeLoadExtract(I);
1846+
MadeChange |= scalarizeVPIntrinsic(I);
17401847
}
17411848

17421849
if (Opcode == Instruction::Store)

0 commit comments

Comments
 (0)