@@ -833,7 +833,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
833833 setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
834834 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
835835 ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
836- ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
836+ ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD:: BITCAST});
837837
838838 // setcc for f16x2 and bf16x2 needs special handling to prevent
839839 // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3092,10 +3092,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
30923092 if (Op.getValueType () == MVT::i1)
30933093 return LowerLOADi1 (Op, DAG);
30943094
3095- // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3096- // unaligned loads and have to handle it here.
3095+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3096+ // handle unaligned loads and have to handle it here.
30973097 EVT VT = Op.getValueType ();
3098- if (Isv2x16VT (VT) || VT == MVT::v4i8) {
3098+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) {
30993099 LoadSDNode *Load = cast<LoadSDNode>(Op);
31003100 EVT MemVT = Load->getMemoryVT ();
31013101 if (!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
@@ -3139,15 +3139,15 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31393139 if (VT == MVT::i1)
31403140 return LowerSTOREi1 (Op, DAG);
31413141
3142- // v2f16 is legal, so we can't rely on legalizer to handle unaligned
3143- // stores and have to handle it here.
3144- if ((Isv2x16VT (VT) || VT == MVT::v4i8) &&
3142+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3143+ // handle unaligned stores and have to handle it here.
3144+ if ((Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) &&
31453145 !allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
31463146 VT, *Store->getMemOperand ()))
31473147 return expandUnalignedStore (Store, DAG);
31483148
3149- // v2f16, v2bf16 and v2i16 don't need special handling.
3150- if (Isv2x16VT (VT) || VT == MVT::v4i8)
3149+ // v2f16/ v2bf16/ v2i16/v4i8/v2f32 don't need special handling.
3150+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 )
31513151 return SDValue ();
31523152
31533153 if (VT.isVector ())
@@ -3156,8 +3156,8 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31563156 return SDValue ();
31573157}
31583158
3159- SDValue
3160- NVPTXTargetLowering::LowerSTOREVector (SDValue Op, SelectionDAG &DAG) const {
3159+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG,
3160+ const SmallVectorImpl<SDValue> &Elements) {
31613161 SDNode *N = Op.getNode ();
31623162 SDValue Val = N->getOperand (1 );
31633163 SDLoc DL (N);
@@ -3224,6 +3224,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32243224 SDValue SubVector = DAG.getBuildVector (EltVT, DL, SubVectorElts);
32253225 Ops.push_back (SubVector);
32263226 }
3227+ } else if (!Elements.empty ()) {
3228+ Ops.insert (Ops.end (), Elements.begin (), Elements.end ());
32273229 } else {
32283230 for (unsigned i = 0 ; i < NumElts; ++i) {
32293231 SDValue ExtVal = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
@@ -3241,10 +3243,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32413243 DAG.getMemIntrinsicNode (Opcode, DL, DAG.getVTList (MVT::Other), Ops,
32423244 MemSD->getMemoryVT (), MemSD->getMemOperand ());
32433245
3244- // return DCI.CombineTo(N, NewSt, true);
32453246 return NewSt;
32463247}
32473248
3249+ // Default variant where we don't pass in elements.
3250+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG) {
3251+ return convertVectorStore (Op, DAG, SmallVector<SDValue>{});
3252+ }
3253+
3254+ SDValue NVPTXTargetLowering::LowerSTOREVector (SDValue Op,
3255+ SelectionDAG &DAG) const {
3256+ return convertVectorStore (Op, DAG);
3257+ }
3258+
32483259// st i1 v, addr
32493260// =>
32503261// v1 = zxt v to i16
@@ -5400,6 +5411,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54005411 // -->
54015412 // StoreRetvalV2 {a, b}
54025413 // likewise for V2 -> V4 case
5414+ //
5415+ // We also handle target-independent stores, which require us to first
5416+ // convert to StoreV2.
54035417
54045418 std::optional<NVPTXISD::NodeType> NewOpcode;
54055419 switch (N->getOpcode ()) {
@@ -5425,8 +5439,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54255439 SDValue CurrentOp = N->getOperand (I);
54265440 if (CurrentOp->getOpcode () == ISD::BUILD_VECTOR) {
54275441 assert (CurrentOp.getValueType () == MVT::v2f32);
5428- NewOps.push_back (CurrentOp.getNode ()-> getOperand (0 ));
5429- NewOps.push_back (CurrentOp.getNode ()-> getOperand (1 ));
5442+ NewOps.push_back (CurrentOp.getOperand (0 ));
5443+ NewOps.push_back (CurrentOp.getOperand (1 ));
54305444 } else {
54315445 NewOps.clear ();
54325446 break ;
@@ -6197,6 +6211,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
61976211 return SDValue ();
61986212}
61996213
6214+ static SDValue PerformStoreCombine (SDNode *N,
6215+ TargetLowering::DAGCombinerInfo &DCI) {
6216+ // check if the store'd value can be scalarized
6217+ SDValue StoredVal = N->getOperand (1 );
6218+ if (StoredVal.getValueType () == MVT::v2f32 &&
6219+ StoredVal.getOpcode () == ISD::BUILD_VECTOR) {
6220+ SmallVector<SDValue> Elements (StoredVal->op_values ());
6221+ return convertVectorStore (SDValue (N, 0 ), DCI.DAG , Elements);
6222+ }
6223+ return SDValue ();
6224+ }
6225+
62006226SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
62016227 DAGCombinerInfo &DCI) const {
62026228 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -6226,6 +6252,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
62266252 case NVPTXISD::LoadParam:
62276253 case NVPTXISD::LoadParamV2:
62286254 return PerformLoadCombine (N, DCI);
6255+ case ISD::STORE:
6256+ return PerformStoreCombine (N, DCI);
62296257 case NVPTXISD::StoreParam:
62306258 case NVPTXISD::StoreParamV2:
62316259 case NVPTXISD::StoreParamV4:
0 commit comments