@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
1378513785 return SDValue();
1378613786
1378713787 EVT BaseLdVT = BaseLd->getValueType(0);
13788- SDValue BasePtr = BaseLd->getBasePtr();
1378913788
1379013789 // Go through the loads and check that they're strided
13791- SmallVector<SDValue> Ptrs ;
13792- Ptrs .push_back(BasePtr );
13790+ SmallVector<LoadSDNode *> Lds ;
13791+ Lds .push_back(BaseLd );
1379313792 Align Align = BaseLd->getAlign();
1379413793 for (SDValue Op : N->ops().drop_front()) {
1379513794 auto *Ld = dyn_cast<LoadSDNode>(Op);
@@ -13798,60 +13797,38 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
1379813797 Ld->getValueType(0) != BaseLdVT)
1379913798 return SDValue();
1380013799
13801- Ptrs .push_back(Ld->getBasePtr() );
13800+ Lds .push_back(Ld);
1380213801
1380313802 // The common alignment is the most restrictive (smallest) of all the loads
1380413803 Align = std::min(Align, Ld->getAlign());
1380513804 }
1380613805
13807- auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
13808- SDValue Stride;
13809- for (auto Idx : enumerate(Ptrs)) {
13810- if (Idx.index() == 0)
13811- continue;
13812- SDValue Ptr = Idx.value();
13813- // Check that each load's pointer is (add LastPtr, Stride)
13814- if (Ptr.getOpcode() != ISD::ADD ||
13815- Ptr.getOperand(0) != Ptrs[Idx.index()-1])
13816- return SDValue();
13817- SDValue Offset = Ptr.getOperand(1);
13818- if (!Stride)
13819- Stride = Offset;
13820- else if (Offset != Stride)
13821- return SDValue();
13822- }
13823- return Stride;
13824- };
13825- auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
13826- SDValue Stride;
13827- for (auto Idx : enumerate(Ptrs)) {
13828- if (Idx.index() == Ptrs.size() - 1)
13829- continue;
13830- SDValue Ptr = Idx.value();
13831- // Check that each load's pointer is (add NextPtr, Stride)
13832- if (Ptr.getOpcode() != ISD::ADD ||
13833- Ptr.getOperand(0) != Ptrs[Idx.index()+1])
13834- return SDValue();
13835- SDValue Offset = Ptr.getOperand(1);
13836- if (!Stride)
13837- Stride = Offset;
13838- else if (Offset != Stride)
13839- return SDValue();
13840- }
13841- return Stride;
13806+ using PtrDiff = std::pair<SDValue, bool>;
13807+ auto GetPtrDiff = [](LoadSDNode *Ld1,
13808+ LoadSDNode *Ld2) -> std::optional<PtrDiff> {
13809+ SDValue P1 = Ld1->getBasePtr();
13810+ SDValue P2 = Ld2->getBasePtr();
13811+ if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
13812+ return {{P2.getOperand(1), false}};
13813+ if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
13814+ return {{P1.getOperand(1), true}};
13815+
13816+ return std::nullopt;
1384213817 };
1384313818
13844- bool Reversed = false;
13845- SDValue Stride = matchForwardStrided(Ptrs);
13846- if (!Stride) {
13847- Stride = matchReverseStrided(Ptrs);
13848- Reversed = true;
13849- // TODO: At this point, we've successfully matched a generalized gather
13850- // load. Maybe we should emit that, and then move the specialized
13851- // matchers above and below into a DAG combine?
13852- if (!Stride)
13819+ // Get the distance between the first and second loads
13820+ auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]);
13821+ if (!BaseDiff)
13822+ return SDValue();
13823+
13824+ // Check all the loads are the same distance apart
13825+ for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++)
13826+ if (GetPtrDiff(*It, *std::next(It)) != BaseDiff)
1385313827 return SDValue();
13854- }
13828+
13829+ // TODO: At this point, we've successfully matched a generalized gather
13830+ // load. Maybe we should emit that, and then move the specialized
13831+ // matchers above and below into a DAG combine?
1385513832
1385613833 // Get the widened scalar type, e.g. v4i8 -> i64
1385713834 unsigned WideScalarBitWidth =
@@ -13867,26 +13844,25 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
1386713844 if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
1386813845 return SDValue();
1386913846
13847+ auto [Stride, MustNegateStride] = *BaseDiff;
13848+ if (MustNegateStride)
13849+ Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
13850+
1387013851 SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
1387113852 SDValue IntID =
1387213853 DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
1387313854 Subtarget.getXLenVT());
13874- if (Reversed)
13875- Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
13855+
1387613856 SDValue AllOneMask =
1387713857 DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
1387813858 DAG.getConstant(1, DL, MVT::i1));
1387913859
13880- SDValue Ops[] = {BaseLd->getChain(),
13881- IntID,
13882- DAG.getUNDEF(WideVecVT),
13883- BasePtr,
13884- Stride,
13885- AllOneMask};
13860+ SDValue Ops[] = {BaseLd->getChain(), IntID, DAG.getUNDEF(WideVecVT),
13861+ BaseLd->getBasePtr(), Stride, AllOneMask};
1388613862
1388713863 uint64_t MemSize;
1388813864 if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
13889- ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
13865+ ConstStride && ConstStride->getSExtValue() >= 0)
1389013866 // total size = (elsize * n) + (stride - elsize) * (n-1)
1389113867 // = elsize + stride * (n-1)
1389213868 MemSize = WideScalarVT.getSizeInBits() +
0 commit comments