diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d40d4997d76149..0339b302fb2186 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2758,19 +2758,6 @@ static SDValue getAllOnesMask(MVT VecVT, SDValue VL, const SDLoc &DL,
   return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
 }
 
-static SDValue getVLOp(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
-                       SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
-  // If we know the exact VLEN, and our VL is exactly equal to VLMAX,
-  // canonicalize the representation. InsertVSETVLI will pick the immediate
-  // encoding later if profitable.
-  const auto [MinVLMAX, MaxVLMAX] =
-      RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
-  if (MinVLMAX == MaxVLMAX && NumElts == MinVLMAX)
-    return DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());
-
-  return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
-}
-
 static std::pair<SDValue, SDValue>
 getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG,
                         const RISCVSubtarget &Subtarget) {
@@ -2784,7 +2771,7 @@ static std::pair<SDValue, SDValue>
 getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
                 SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
   assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
-  SDValue VL = getVLOp(NumElts, ContainerVT, DL, DAG, Subtarget);
+  SDValue VL = DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
   SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
   return {Mask, VL};
 }
@@ -9427,8 +9414,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
     MVT VT = Op->getSimpleValueType(0);
     MVT ContainerVT = getContainerForFixedLengthVector(VT);
 
-    SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
-                         Subtarget);
+    SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
     SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
     auto *Load = cast<MemIntrinsicSDNode>(Op);
     SmallVector<MVT, 8> ContainerVTs(NF, ContainerVT);
@@ -9507,8 +9493,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
     MVT VT = Op->getOperand(2).getSimpleValueType();
     MVT ContainerVT = getContainerForFixedLengthVector(VT);
 
-    SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
-                         Subtarget);
+    SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
     SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
     SDValue Ptr = Op->getOperand(NF + 2);
 
@@ -9974,7 +9959,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
   // Set the vector length to only the number of elements we care about. Note
   // that for slideup this includes the offset.
   unsigned EndIndex = OrigIdx + SubVecVT.getVectorNumElements();
-  SDValue VL = getVLOp(EndIndex, ContainerVT, DL, DAG, Subtarget);
+  SDValue VL = DAG.getConstant(EndIndex, DL, XLenVT);
 
   // Use tail agnostic policy if we're inserting over Vec's tail.
   unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
@@ -10211,8 +10196,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
     // Set the vector length to only the number of elements we care about. This
     // avoids sliding down elements we're going to discard straight away.
-    SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), ContainerVT, DL, DAG,
-                         Subtarget);
+    SDValue VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
     SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
     SDValue Slidedown =
         getVSlidedown(DAG, Subtarget, DL, ContainerVT,
@@ -10287,8 +10271,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
   SDValue SlidedownAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
   auto [Mask, VL] = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
   if (SubVecVT.isFixedLengthVector())
-    VL = getVLOp(SubVecVT.getVectorNumElements(), InterSubVT, DL, DAG,
-                 Subtarget);
+    VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
   SDValue Slidedown =
       getVSlidedown(DAG, Subtarget, DL, InterSubVT, DAG.getUNDEF(InterSubVT),
                     Vec, SlidedownAmt, Mask, VL);
@@ -10668,7 +10651,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
     return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
   }
 
-  SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG, Subtarget);
+  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
 
   bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
   SDValue IntID = DAG.getTargetConstant(
@@ -10715,7 +10698,6 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
   SDValue NewValue =
       convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
 
-
   // If we know the exact VLEN and our fixed length vector completely fills
   // the container, use a whole register store instead.
   const auto [MinVLMAX, MaxVLMAX] =
@@ -10728,8 +10710,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
                         MMO->getFlags(), MMO->getAAInfo());
   }
 
-  SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
-                       Subtarget);
+  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
 
   bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
   SDValue IntID = DAG.getTargetConstant(
diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
index b083e64cfc8d7e..f328c55e1d3bac 100644
--- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -47,6 +47,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
   const TargetInstrInfo *TII;
   MachineRegisterInfo *MRI;
   const TargetRegisterInfo *TRI;
+  const RISCVSubtarget *ST;
   RISCVVectorPeephole() : MachineFunctionPass(ID) {}
   bool runOnMachineFunction(MachineFunction &MF) override;
 
@@ -64,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
   bool convertVMergeToVMv(MachineInstr &MI) const;
 
   bool isAllOnesMask(const MachineInstr *MaskDef) const;
+  std::optional<int64_t> getConstant(const MachineOperand &VL) const;
 
   /// Maps uses of V0 to the corresponding def of V0.
   DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
@@ -76,13 +78,44 @@ char RISCVVectorPeephole::ID = 0;
 INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
                 false)
 
-// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
-// to the VLMAX sentinel value.
+/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
+std::optional<int64_t>
+RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
+  if (VL.isImm())
+    return VL.getImm();
+
+  MachineInstr *Def = MRI->getVRegDef(VL.getReg());
+  if (!Def || Def->getOpcode() != RISCV::ADDI ||
+      Def->getOperand(1).getReg() != RISCV::X0)
+    return std::nullopt;
+  return Def->getOperand(2).getImm();
+}
+
+/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
 bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
   if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
       !RISCVII::hasSEWOp(MI.getDesc().TSFlags))
     return false;
+
+  auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
+  // Fixed-point value, denominator=8
+  unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
+  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
+  // A Log2SEW of 0 is an operation on mask registers only
+  unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
+  assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
+  assert(8 * LMULFixed / SEW > 0);
+
+  // If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
   MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
+  if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
+      VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
+    VL.ChangeToImmediate(RISCV::VLMaxSentinel);
+    return true;
+  }
+
+  // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
+  // it to the VLMAX sentinel value.
   if (!VL.isReg())
     return false;
   MachineInstr *Def = MRI->getVRegDef(VL.getReg());
@@ -105,15 +138,6 @@ bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
   if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
     return false;
 
-  auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
-  // Fixed-point value, denominator=8
-  unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
-  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
-  // A Log2SEW of 0 is an operation on mask registers only
-  unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
-  assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
-  assert(8 * LMULFixed / SEW > 0);
-
   // AVL = (VLENB * Scale)
   //
   // VLMAX = (VLENB * 8 * LMUL) / SEW
@@ -302,11 +326,11 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
     return false;
 
   // Skip if the vector extension is not enabled.
-  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
-  if (!ST.hasVInstructions())
+  ST = &MF.getSubtarget<RISCVSubtarget>();
+  if (!ST->hasVInstructions())
     return false;
 
-  TII = ST.getInstrInfo();
+  TII = ST->getInstrInfo();
   MRI = &MF.getRegInfo();
   TRI = MRI->getTargetRegisterInfo();
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/pr83017.ll b/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
index 3719a2ad994d6f..beca480378a358 100644
--- a/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
@@ -35,11 +35,11 @@ define void @aliasing(ptr %p) {
 ; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vs1r.v v8, (a2)
-; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
-; CHECK-NEXT:    vmv.v.i v12, 0
-; CHECK-NEXT:    vs4r.v v12, (a0)
 ; CHECK-NEXT:    addi a2, a0, 64
 ; CHECK-NEXT:    vs1r.v v8, (a2)
+; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vs4r.v v8, (a0)
 ; CHECK-NEXT:    sw a1, 84(a0)
 ; CHECK-NEXT:    ret
   %q = getelementptr inbounds i8, ptr %p, i64 84
diff --git a/llvm/test/CodeGen/RISCV/rvv/pr90559.ll b/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
index 8d330b12055ae9..7e109f307c4a53 100644
--- a/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
@@ -32,11 +32,11 @@ define void @f(ptr %p) vscale_range(2,2) {
 ; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vs1r.v v8, (a2)
-; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
-; CHECK-NEXT:    vmv.v.i v12, 0
-; CHECK-NEXT:    vs4r.v v12, (a0)
 ; CHECK-NEXT:    addi a2, a0, 64
 ; CHECK-NEXT:    vs1r.v v8, (a2)
+; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vs4r.v v8, (a0)
 ; CHECK-NEXT:    sw a1, 84(a0)
 ; CHECK-NEXT:    ret
   %q = getelementptr inbounds i8, ptr %p, i64 84
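Note (not part of the patch): the new convertToVLMAX() check in RISCVVectorPeephole decides that a constant AVL equals VLMAX by testing (VLEN * LMULFixed) / SEW == AVL * 8, where LMULFixed is LMUL scaled by 8 so fractional LMULs (mf2/mf4/mf8) stay integral. Below is a minimal standalone C++ sketch of just that arithmetic; avlIsVLMAX and its parameters are illustrative names, not LLVM APIs.

// Standalone sketch of the VLMAX test; assumes the exact VLEN is known
// (e.g. vscale_range(2,2) implies VLEN=128). Illustrative names only.
#include <cassert>
#include <cstdint>
#include <cstdio>

// LMULFixed = LMUL * 8, so mf8..m8 map to the integers 1..64.
static bool avlIsVLMAX(uint64_t VLen, unsigned LMULFixed, unsigned SEW,
                       int64_t AVL) {
  assert(SEW >= 8 && LMULFixed >= 1 && AVL >= 0);
  // VLMAX = (VLEN * LMUL) / SEW = (VLEN * LMULFixed) / (8 * SEW); compare with
  // both sides multiplied by 8, mirroring the check in the patch.
  return (VLen * LMULFixed) / SEW == uint64_t(AVL) * 8;
}

int main() {
  // VLEN=128, e8, m4: VLMAX = 128 / 8 * 4 = 64, so an AVL of 64 is VLMAX.
  std::printf("%d\n", avlIsVLMAX(128, /*m4*/ 32, /*e8*/ 8, 64)); // prints 1
  // An AVL of 16 under the same configuration is not.
  std::printf("%d\n", avlIsVLMAX(128, 32, 8, 16)); // prints 0
  return 0;
}

In the pr83017.ll and pr90559.ll tests VLEN is fixed at 128, so the 64-element e8/m4 store's AVL is VLMAX and it keeps the vsetvli-with-zero-AVL form; the visible churn there appears to be only instruction scheduling and register reuse (v12 becoming v8) from doing the conversion later in the pipeline.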