@@ -102,6 +102,7 @@ class VectorCombine {
   bool foldInsExtFNeg(Instruction &I);
   bool foldBitcastShuf(Instruction &I);
   bool scalarizeBinopOrCmp(Instruction &I);
+  bool scalarizeVPIntrinsic(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
@@ -729,6 +730,111 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
   return true;
 }

+/// VP Intrinsics whose vector operands are both splat values may be simplified
+/// into the scalar version of the operation and the result splatted. This
+/// can lead to scalarization down the line.
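+///
+/// For example (an illustrative sketch only, not IR taken from this patch):
+/// given %x.splat and %y.splat, splats of the scalars %x and %y,
+///   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %x.splat,
+///                                          <4 x i32> %y.splat,
+///                                          <4 x i1> %alltrue, i32 %evl)
+/// can become a scalar `add i32 %x, %y` whose result is splatted back into a
+/// <4 x i32> vector.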
+bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
+  if (!isa<VPIntrinsic>(I))
+    return false;
+  VPIntrinsic &VPI = cast<VPIntrinsic>(I);
+  Value *Op0 = VPI.getArgOperand(0);
+  Value *Op1 = VPI.getArgOperand(1);
+
+  if (!isSplatValue(Op0) || !isSplatValue(Op1))
+    return false;
+
+  // For the binary VP intrinsics supported here, the result on disabled lanes
+  // is a poison value. For now, only do this simplification if all lanes
+  // are active.
+  // TODO: Relax the condition that all lanes are active by using insertelement
+  // on inactive lanes.
+  auto IsAllTrueMask = [](Value *MaskVal) {
+    if (Value *SplattedVal = getSplatValue(MaskVal))
+      if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
+        return ConstValue->isAllOnesValue();
+    return false;
+  };
+  if (!IsAllTrueMask(VPI.getArgOperand(2)))
+    return false;
+
+  // Check to make sure we support scalarization of the intrinsic
+  Intrinsic::ID IntrID = VPI.getIntrinsicID();
+  if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
+    return false;
+
+  // Calculate cost of splatting both operands into vectors and the vector
+  // intrinsic
+  VectorType *VecTy = cast<VectorType>(VPI.getType());
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost SplatCost =
+      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);
+
+  // Calculate the cost of the VP Intrinsic
+  SmallVector<Type *, 4> Args;
+  for (Value *V : VPI.args())
+    Args.push_back(V->getType());
+  IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
+  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
+  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
+
+  // Determine scalar opcode
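+  // A VP op with a plain instruction equivalent (e.g. vp.add) is scalarized
+  // to that opcode; otherwise (e.g. vp.smax) we fall back to the matching
+  // non-VP scalar intrinsic below.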
+  std::optional<unsigned> FunctionalOpcode =
+      VPI.getFunctionalOpcode();
+  std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
+  if (!FunctionalOpcode) {
+    ScalarIntrID = VPI.getFunctionalIntrinsicID();
+    if (!ScalarIntrID)
+      return false;
+  }
+
+  // Calculate cost of scalarizing
+  InstructionCost ScalarOpCost = 0;
+  if (ScalarIntrID) {
+    IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
+    ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
+  } else {
+    ScalarOpCost =
+        TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
+  }
+
+  // The existing splats may be kept around if other instructions use them.
+  InstructionCost CostToKeepSplats =
+      (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
+  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
+
+  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
+                    << "\n");
+  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
+                    << ", Cost of scalarizing:" << NewCost << "\n");
+
+  // We want to scalarize unless the vector variant actually has lower cost.
+  if (OldCost < NewCost || !NewCost.isValid())
+    return false;
+
+  // Scalarize the intrinsic
+  ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
+  Value *EVL = VPI.getArgOperand(3);
+  const DataLayout &DL = VPI.getModule()->getDataLayout();
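+  // The VP division/remainder ops below do no work when the EVL is zero, but
+  // a scalarized division executes unconditionally and could introduce UB
+  // (e.g. division by zero) that the VP form avoided, so require a provably
+  // non-zero EVL for them.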
+  bool MustHaveNonZeroVL =
+      IntrID == Intrinsic::vp_sdiv || IntrID == Intrinsic::vp_udiv ||
+      IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem;
+
+  if (!MustHaveNonZeroVL || isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT)) {
+    Value *ScalarOp0 = getSplatValue(Op0);
+    Value *ScalarOp1 = getSplatValue(Op1);
+    Value *ScalarVal =
+        ScalarIntrID
+            ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
+                                      {ScalarOp0, ScalarOp1})
+            : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
+                                  ScalarOp0, ScalarOp1);
+    replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
+    return true;
+  }
+  return false;
+}
+
 /// Match a vector binop or compare instruction with at least one inserted
 /// scalar operand and convert to scalar binop/cmp followed by insertelement.
 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
@@ -1737,6 +1843,7 @@ bool VectorCombine::run() {
     if (isa<VectorType>(I.getType())) {
       MadeChange |= scalarizeBinopOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
+      MadeChange |= scalarizeVPIntrinsic(I);
     }

     if (Opcode == Instruction::Store)