diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e74d184c0a35d..19e407414627d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13589,6 +13589,52 @@ static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
   return ActiveLanes.all();
 }
 
+/// Match the index of a gather or scatter operation as an operation
+/// with twice the element width and half the number of elements. This is
+/// generally profitable (if legal) because these operations are linear
+/// in VL, so even if we cause some extra VTYPE/VL toggles, we still
+/// come out ahead.
+static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
+                                Align BaseAlign, const RISCVSubtarget &ST) {
+  if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
+    return false;
+  if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
+    return false;
+
+  // Attempt a doubling. If we can use an element type 4x or 8x in
+  // size, this will happen via multiple iterations of the transform.
+  const unsigned NumElems = VT.getVectorNumElements();
+  if (NumElems % 2 != 0)
+    return false;
+
+  const unsigned ElementSize = VT.getScalarStoreSize();
+  const unsigned WiderElementSize = ElementSize * 2;
+  if (WiderElementSize > ST.getELen()/8)
+    return false;
+
+  if (!ST.enableUnalignedVectorMem() && BaseAlign < WiderElementSize)
+    return false;
+
+  for (unsigned i = 0; i < Index->getNumOperands(); i++) {
+    // TODO: We've found an active bit of UB, and could be
+    // more aggressive here if desired.
+    if (Index->getOperand(i)->isUndef())
+      return false;
+    // TODO: This offset check is too strict if we support fully
+    // misaligned memory operations.
+    uint64_t C = Index->getConstantOperandVal(i);
+    if (C % ElementSize != 0)
+      return false;
+    if (i % 2 == 0)
+      continue;
+    uint64_t Last = Index->getConstantOperandVal(i-1);
+    if (C != Last + ElementSize)
+      return false;
+  }
+  return true;
+}
+
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -14020,6 +14066,36 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
           DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask);
       return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL);
     }
+
+    if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
+        matchIndexAsWiderOp(VT, Index, MGN->getMask(),
+                            MGN->getMemOperand()->getBaseAlign(), Subtarget)) {
+      SmallVector<SDValue> NewIndices;
+      for (unsigned i = 0; i < Index->getNumOperands(); i += 2)
+        NewIndices.push_back(Index.getOperand(i));
+      EVT IndexVT = Index.getValueType()
+                        .getHalfNumVectorElementsVT(*DAG.getContext());
+      Index = DAG.getBuildVector(IndexVT, DL, NewIndices);
+
+      unsigned ElementSize = VT.getScalarStoreSize();
+      EVT WideScalarVT = MVT::getIntegerVT(ElementSize * 8 * 2);
+      auto EltCnt = VT.getVectorElementCount();
+      assert(EltCnt.isKnownEven() && "Splitting vector, but not in half!");
+      EVT WideVT = EVT::getVectorVT(*DAG.getContext(), WideScalarVT,
+                                    EltCnt.divideCoefficientBy(2));
+      SDValue Passthru = DAG.getBitcast(WideVT, MGN->getPassThru());
+      EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                    EltCnt.divideCoefficientBy(2));
+      SDValue Mask = DAG.getSplat(MaskVT, DL, DAG.getConstant(1, DL, MVT::i1));
+
+      SDValue Gather =
+          DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), WideVT, DL,
+                              {MGN->getChain(), Passthru, Mask, MGN->getBasePtr(),
+                               Index, ScaleOp},
+                              MGN->getMemOperand(), IndexType, ISD::NON_EXTLOAD);
+      SDValue Result = DAG.getBitcast(VT, Gather.getValue(0));
+      return DAG.getMergeValues({Result, Gather.getValue(1)}, DL);
+    }
     break;
   }
   case ISD::MSCATTER:{
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index ac5c11ca88df5..130d2c7613b32 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -13024,19 +13024,19 @@ define <4 x i32> @mgather_narrow_edge_case(ptr %base) {
 define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
 ; RV32-LABEL: mgather_strided_2xSEW:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    lui a1, %hi(.LCPI107_0)
-; RV32-NEXT:    addi a1, a1, %lo(.LCPI107_0)
-; RV32-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT:    vle8.v v9, (a1)
+; RV32-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; RV32-NEXT:    vid.v v8
+; RV32-NEXT:    vsll.vi v9, v8, 3
+; RV32-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
 ; RV32-NEXT:    vluxei8.v v8, (a0), v9
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_strided_2xSEW:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    lui a1, %hi(.LCPI107_0)
-; RV64V-NEXT:    addi a1, a1, %lo(.LCPI107_0)
-; RV64V-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV64V-NEXT:    vle8.v v9, (a1)
+; RV64V-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; RV64V-NEXT:    vid.v v8
+; RV64V-NEXT:    vsll.vi v9, v8, 3
+; RV64V-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
 ; RV64V-NEXT:    vluxei8.v v8, (a0), v9
 ; RV64V-NEXT:    ret
 ;
@@ -13141,19 +13141,19 @@ define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
 define <8 x i16> @mgather_gather_2xSEW(ptr %base) {
 ; RV32-LABEL: mgather_gather_2xSEW:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    lui a1, %hi(.LCPI108_0)
-; RV32-NEXT:    addi a1, a1, %lo(.LCPI108_0)
-; RV32-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT:    vle8.v v9, (a1)
+; RV32-NEXT:    lui a1, 82176
+; RV32-NEXT:    addi a1, a1, 1024
+; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV32-NEXT:    vmv.s.x v9, a1
 ; RV32-NEXT:    vluxei8.v v8, (a0), v9
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_gather_2xSEW:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    lui a1, %hi(.LCPI108_0)
-; RV64V-NEXT:    addi a1, a1, %lo(.LCPI108_0)
-; RV64V-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV64V-NEXT:    vle8.v v9, (a1)
+; RV64V-NEXT:    lui a1, 82176
+; RV64V-NEXT:    addiw a1, a1, 1024
+; RV64V-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV64V-NEXT:    vmv.s.x v9, a1
 ; RV64V-NEXT:    vluxei8.v v8, (a0), v9
 ; RV64V-NEXT:    ret
 ;
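
Note (illustrative, not part of the patch): the combine added above targets masked gathers whose constant indices address pairs of adjacent narrow elements, and rewrites them as a gather of half as many elements at twice the width, bitcast back to the original type. The IR below is a sketch of that shape under assumed names; the function @gather_adjacent_pairs and its index constants are invented for illustration and are not copied from fixed-vectors-masked-gather.ll.

declare <8 x i16> @llvm.masked.gather.v8i16.v8p0(<8 x ptr>, i32, <8 x i1>, <8 x i16>)

; Pairs of adjacent i16 elements at a stride of 8 bytes (i16 indices 0,1, 4,5, 8,9, 12,13).
; With an all-ones mask and a non-extending load, the combine can instead emit a <4 x i32>
; gather with byte offsets 0, 8, 16, 24 and bitcast the result back to <8 x i16>.
define <8 x i16> @gather_adjacent_pairs(ptr %base) {
  %ptrs = getelementptr inbounds i16, ptr %base, <8 x i32> <i32 0, i32 1, i32 4, i32 5, i32 8, i32 9, i32 12, i32 13>
  %v = call <8 x i16> @llvm.masked.gather.v8i16.v8p0(<8 x ptr> %ptrs, i32 2, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i16> poison)
  ret <8 x i16> %v
}

The guards in matchIndexAsWiderOp reflect this: every byte offset must be a multiple of the original element size, each odd-position offset must follow its even-position neighbor by exactly one element, and the widened access must either be allowed to be misaligned (enableUnalignedVectorMem) or be covered by the base alignment.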