From 4e24ef483e739a314f046f8a6797091cfd2d11c6 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Sat, 14 Oct 2023 13:14:48 -0400 Subject: [PATCH 1/4] [RISCV] Refactor performCONCAT_VECTORSCombine. NFC Instead of doing a forward pass for positive strides and a reverse pass for negative strides, we can just do one pass by negating the offset if the pointers do happen to be in reverse order. We can extend getPtrDiff later in #68726 to handle more constant offset sequences. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 81 +++++++-------------- 1 file changed, 25 insertions(+), 56 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index d7552317fd8bc..9912f19c9a501 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); EVT BaseLdVT = BaseLd->getValueType(0); - SDValue BasePtr = BaseLd->getBasePtr(); // Go through the loads and check that they're strided - SmallVector Ptrs; - Ptrs.push_back(BasePtr); + SmallVector Lds; + Lds.push_back(BaseLd); Align Align = BaseLd->getAlign(); for (SDValue Op : N->ops().drop_front()) { auto *Ld = dyn_cast(Op); @@ -13798,58 +13797,33 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, Ld->getValueType(0) != BaseLdVT) return SDValue(); - Ptrs.push_back(Ld->getBasePtr()); + Lds.push_back(Ld); // The common alignment is the most restrictive (smallest) of all the loads Align = std::min(Align, Ld->getAlign()); } - auto matchForwardStrided = [](ArrayRef Ptrs) { - SDValue Stride; - for (auto Idx : enumerate(Ptrs)) { - if (Idx.index() == 0) - continue; - SDValue Ptr = Idx.value(); - // Check that each load's pointer is (add LastPtr, Stride) - if (Ptr.getOpcode() != ISD::ADD || - Ptr.getOperand(0) != Ptrs[Idx.index()-1]) - return SDValue(); - SDValue Offset = Ptr.getOperand(1); - if (!Stride) - Stride = Offset; - else if (Offset != Stride) - return SDValue(); - } - return Stride; - }; - auto matchReverseStrided = [](ArrayRef Ptrs) { - SDValue Stride; - for (auto Idx : enumerate(Ptrs)) { - if (Idx.index() == Ptrs.size() - 1) - continue; - SDValue Ptr = Idx.value(); - // Check that each load's pointer is (add NextPtr, Stride) - if (Ptr.getOpcode() != ISD::ADD || - Ptr.getOperand(0) != Ptrs[Idx.index()+1]) - return SDValue(); - SDValue Offset = Ptr.getOperand(1); - if (!Stride) - Stride = Offset; - else if (Offset != Stride) - return SDValue(); - } - return Stride; + auto getPtrDiff = [&DAG, &DL](LoadSDNode *Ld1, LoadSDNode *Ld2) { + SDValue P1 = Ld1->getBasePtr(); + SDValue P2 = Ld2->getBasePtr(); + if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1) + return P2.getOperand(1); + if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2) + return DAG.getNegative(P1.getOperand(1), DL, + P1.getOperand(1).getValueType()); + return SDValue(); }; - bool Reversed = false; - SDValue Stride = matchForwardStrided(Ptrs); - if (!Stride) { - Stride = matchReverseStrided(Ptrs); - Reversed = true; - // TODO: At this point, we've successfully matched a generalized gather - // load. Maybe we should emit that, and then move the specialized - // matchers above and below into a DAG combine? + SDValue Stride; + for (auto [Idx, Ld] : enumerate(Lds)) { + if (Idx == 0) + continue; + SDValue Offset = getPtrDiff(Lds[Idx - 1], Ld); + if (!Offset) + return SDValue(); if (!Stride) + Stride = Offset; + else if (Offset != Stride) return SDValue(); } @@ -13871,22 +13845,17 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL, Subtarget.getXLenVT()); - if (Reversed) - Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0)); + SDValue AllOneMask = DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL, DAG.getConstant(1, DL, MVT::i1)); - SDValue Ops[] = {BaseLd->getChain(), - IntID, - DAG.getUNDEF(WideVecVT), - BasePtr, - Stride, - AllOneMask}; + SDValue Ops[] = {BaseLd->getChain(), IntID, DAG.getUNDEF(WideVecVT), + BaseLd->getBasePtr(), Stride, AllOneMask}; uint64_t MemSize; if (auto *ConstStride = dyn_cast(Stride); - ConstStride && !Reversed && ConstStride->getSExtValue() >= 0) + ConstStride && ConstStride->getSExtValue() >= 0) // total size = (elsize * n) + (stride - elsize) * (n-1) // = elsize + stride * (n-1) MemSize = WideScalarVT.getSizeInBits() + From 12abeba420ac37b510c6f353da5cd9802631829e Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Mon, 16 Oct 2023 09:22:40 -0400 Subject: [PATCH 2/4] Don't create node if combine will bail --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 35 +++++++++++---------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 9912f19c9a501..4b3ffe182d1cb 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13803,29 +13803,28 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, Align = std::min(Align, Ld->getAlign()); } - auto getPtrDiff = [&DAG, &DL](LoadSDNode *Ld1, LoadSDNode *Ld2) { + using PtrDiff = std::pair; + auto GetPtrDiff = [](LoadSDNode *Ld1, + LoadSDNode *Ld2) -> std::optional { SDValue P1 = Ld1->getBasePtr(); SDValue P2 = Ld2->getBasePtr(); if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1) - return P2.getOperand(1); + return {{P2.getOperand(1), false}}; if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2) - return DAG.getNegative(P1.getOperand(1), DL, - P1.getOperand(1).getValueType()); - return SDValue(); + return {{P1.getOperand(1), true}}; + + return std::nullopt; }; - SDValue Stride; - for (auto [Idx, Ld] : enumerate(Lds)) { - if (Idx == 0) - continue; - SDValue Offset = getPtrDiff(Lds[Idx - 1], Ld); - if (!Offset) - return SDValue(); - if (!Stride) - Stride = Offset; - else if (Offset != Stride) + // Get the distance between the first and second loads + auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]); + if (!BaseDiff) + return SDValue(); + + // Check all the loads are the same distance apart + for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++) + if (GetPtrDiff(*It, *std::next(It)) != BaseDiff) return SDValue(); - } // Get the widened scalar type, e.g. v4i8 -> i64 unsigned WideScalarBitWidth = @@ -13841,6 +13840,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, if (!TLI.isLegalStridedLoadStore(WideVecVT, Align)) return SDValue(); + auto [Stride, Reversed] = *BaseDiff; + if (Reversed) + Stride = DAG.getNegative(Stride, DL, Stride.getValueType()); + SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other}); SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL, From 470a73a1a2f0c2da7275dbfecbf8f49a2aeb6f60 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Mon, 16 Oct 2023 09:25:42 -0400 Subject: [PATCH 3/4] Add back comment --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 4b3ffe182d1cb..86b7c1f232f91 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13826,6 +13826,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, if (GetPtrDiff(*It, *std::next(It)) != BaseDiff) return SDValue(); + // TODO: At this point, we've successfully matched a generalized gather + // load. Maybe we should emit that, and then move the specialized + // matchers above and below into a DAG combine? + // Get the widened scalar type, e.g. v4i8 -> i64 unsigned WideScalarBitWidth = BaseLdVT.getScalarSizeInBits() * BaseLdVT.getVectorNumElements(); From 6ce338f2cf1ecca883510491a05cf664cdbc1870 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Mon, 16 Oct 2023 12:55:17 -0400 Subject: [PATCH 4/4] Reversed->MustNegateStride --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 86b7c1f232f91..580dbc66ce196 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13844,8 +13844,8 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, if (!TLI.isLegalStridedLoadStore(WideVecVT, Align)) return SDValue(); - auto [Stride, Reversed] = *BaseDiff; - if (Reversed) + auto [Stride, MustNegateStride] = *BaseDiff; + if (MustNegateStride) Stride = DAG.getNegative(Stride, DL, Stride.getValueType()); SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});