[RISCV] Move exact VLEN VLMAX transform to RISCVVectorPeephole (llvm#100551)

We can teach RISCVVectorPeephole to detect when an AVL is equal to the
VLMAX when the exact VLEN is known and use the VLMAX sentinel instead,
and in doing so remove the need for getVLOp in RISCVISelLowering. This
keeps all the VLMAX logic in one place.
lukel97 authored Jul 25, 2024
1 parent 7432ad6 commit 754dc9f
Showing 4 changed files with 52 additions and 47 deletions.
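
As a quick orientation before the diff: the new peephole check boils down to comparing the AVL against VLMAX computed from the known exact VLEN. A minimal sketch of that arithmetic is below; it mirrors the `(*VLen * LMULFixed) / SEW == *AVL * 8` comparison added to RISCVVectorPeephole.cpp, but uses plain integers rather than the MachineInstr/RISCVII accessors, and the function name and parameters are illustrative only, not part of the patch.

#include <cstdint>

// With a known exact VLEN (in bits), VLMAX = VLEN * LMUL / SEW. The pass keeps
// LMUL as a fixed-point value scaled by 8 (m1 -> 8, m4 -> 32, mf2 -> 4), so the
// equality "AVL == VLMAX" becomes (VLEN * LMULFixed) / SEW == AVL * 8.
constexpr bool avlIsVLMAX(uint64_t VLenBits, unsigned LMULFixed, unsigned SEW,
                          uint64_t AVL) {
  return (VLenBits * LMULFixed) / SEW == AVL * 8;
}

// vscale_range(2,2) implies VLEN = 128. For e8/m4, VLMAX = 128 * 4 / 8 = 64,
// so an AVL of 64 can be canonicalized to the VLMAX sentinel.
static_assert(avlIsVLMAX(/*VLenBits=*/128, /*LMULFixed=*/32, /*SEW=*/8, 64));
// For e32/m1, VLMAX = 128 / 32 = 4, so an AVL of 16 cannot.
static_assert(!avlIsVLMAX(128, /*LMULFixed=*/8, /*SEW=*/32, 16));

Scaling LMUL by 8 keeps the comparison in integer arithmetic even for fractional LMULs, which is why both sides of the equality end up multiplied through by 8.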
35 changes: 8 additions & 27 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2758,19 +2758,6 @@ static SDValue getAllOnesMask(MVT VecVT, SDValue VL, const SDLoc &DL,
return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
}

static SDValue getVLOp(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
// If we know the exact VLEN, and our VL is exactly equal to VLMAX,
// canonicalize the representation. InsertVSETVLI will pick the immediate
// encoding later if profitable.
const auto [MinVLMAX, MaxVLMAX] =
RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
if (MinVLMAX == MaxVLMAX && NumElts == MinVLMAX)
return DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());

return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
}

static std::pair<SDValue, SDValue>
getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
@@ -2784,7 +2771,7 @@ static std::pair<SDValue, SDValue>
getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
SDValue VL = getVLOp(NumElts, ContainerVT, DL, DAG, Subtarget);
SDValue VL = DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
return {Mask, VL};
}
@@ -9427,8 +9414,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
MVT VT = Op->getSimpleValueType(0);
MVT ContainerVT = getContainerForFixedLengthVector(VT);

SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
Subtarget);
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
auto *Load = cast<MemIntrinsicSDNode>(Op);
SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT);
@@ -9507,8 +9493,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
MVT VT = Op->getOperand(2).getSimpleValueType();
MVT ContainerVT = getContainerForFixedLengthVector(VT);

SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
Subtarget);
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
SDValue Ptr = Op->getOperand(NF + 2);

@@ -9974,7 +9959,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
// Set the vector length to only the number of elements we care about. Note
// that for slideup this includes the offset.
unsigned EndIndex = OrigIdx + SubVecVT.getVectorNumElements();
SDValue VL = getVLOp(EndIndex, ContainerVT, DL, DAG, Subtarget);
SDValue VL = DAG.getConstant(EndIndex, DL, XLenVT);

// Use tail agnostic policy if we're inserting over Vec's tail.
unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
@@ -10211,8 +10196,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
// Set the vector length to only the number of elements we care about. This
// avoids sliding down elements we're going to discard straight away.
SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), ContainerVT, DL, DAG,
Subtarget);
SDValue VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
SDValue Slidedown =
getVSlidedown(DAG, Subtarget, DL, ContainerVT,
@@ -10287,8 +10271,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
SDValue SlidedownAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
auto [Mask, VL] = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
if (SubVecVT.isFixedLengthVector())
VL = getVLOp(SubVecVT.getVectorNumElements(), InterSubVT, DL, DAG,
Subtarget);
VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
SDValue Slidedown =
getVSlidedown(DAG, Subtarget, DL, InterSubVT, DAG.getUNDEF(InterSubVT),
Vec, SlidedownAmt, Mask, VL);
@@ -10668,7 +10651,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
}

SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG, Subtarget);
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);

bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
SDValue IntID = DAG.getTargetConstant(
@@ -10715,7 +10698,6 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
SDValue NewValue =
convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);


// If we know the exact VLEN and our fixed length vector completely fills
// the container, use a whole register store instead.
const auto [MinVLMAX, MaxVLMAX] =
@@ -10728,8 +10710,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
MMO->getFlags(), MMO->getAAInfo());
}

SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
Subtarget);
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);

bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
SDValue IntID = DAG.getTargetConstant(
52 changes: 38 additions & 14 deletions llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -47,6 +47,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
const TargetInstrInfo *TII;
MachineRegisterInfo *MRI;
const TargetRegisterInfo *TRI;
const RISCVSubtarget *ST;
RISCVVectorPeephole() : MachineFunctionPass(ID) {}

bool runOnMachineFunction(MachineFunction &MF) override;
@@ -64,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
bool convertVMergeToVMv(MachineInstr &MI) const;

bool isAllOnesMask(const MachineInstr *MaskDef) const;
std::optional<unsigned> getConstant(const MachineOperand &VL) const;

/// Maps uses of V0 to the corresponding def of V0.
DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
@@ -76,13 +78,44 @@ char RISCVVectorPeephole::ID = 0;
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
false)

// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
// to the VLMAX sentinel value.
/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
std::optional<unsigned>
RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
if (VL.isImm())
return VL.getImm();

MachineInstr *Def = MRI->getVRegDef(VL.getReg());
if (!Def || Def->getOpcode() != RISCV::ADDI ||
Def->getOperand(1).getReg() != RISCV::X0)
return std::nullopt;
return Def->getOperand(2).getImm();
}

/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
!RISCVII::hasSEWOp(MI.getDesc().TSFlags))
return false;

auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
// Fixed-point value, denominator=8
unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
// A Log2SEW of 0 is an operation on mask registers only
unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
assert(8 * LMULFixed / SEW > 0);

// If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
VL.ChangeToImmediate(RISCV::VLMaxSentinel);
return true;
}

// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
// it to the VLMAX sentinel value.
if (!VL.isReg())
return false;
MachineInstr *Def = MRI->getVRegDef(VL.getReg());
@@ -105,15 +138,6 @@ bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
return false;

auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
// Fixed-point value, denominator=8
unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
// A Log2SEW of 0 is an operation on mask registers only
unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
assert(8 * LMULFixed / SEW > 0);

// AVL = (VLENB * Scale)
//
// VLMAX = (VLENB * 8 * LMUL) / SEW
@@ -302,11 +326,11 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
return false;

// Skip if the vector extension is not enabled.
const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
if (!ST.hasVInstructions())
ST = &MF.getSubtarget<RISCVSubtarget>();
if (!ST->hasVInstructions())
return false;

TII = ST.getInstrInfo();
TII = ST->getInstrInfo();
MRI = &MF.getRegInfo();
TRI = MRI->getTargetRegisterInfo();

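The convertToVLMAX changes in this file also retain the older path where the AVL is not a constant but a register computed as VLENB times a scale factor (see the `AVL = (VLENB * Scale)` / `VLMAX = (VLENB * 8 * LMUL) / SEW` comment near the end of that hunk); the register-walking code itself sits in the collapsed part of the diff. A hedged sketch of the underlying relationship, again with plain integers and illustrative names:

#include <cstdint>

// AVL = VLENB * Scale equals VLMAX = VLENB * LMULFixed / SEW (with
// LMULFixed = 8 * LMUL, as in the pass) for every possible VLENB exactly
// when Scale * SEW == LMULFixed.
constexpr bool scaledVlenbIsVLMAX(uint64_t Scale, unsigned LMULFixed,
                                  unsigned SEW) {
  return Scale * SEW == LMULFixed;
}

// e8/m4: LMULFixed = 32, so an AVL of VLENB * 4 is VLMAX for any VLEN.
static_assert(scaledVlenbIsVLMAX(/*Scale=*/4, /*LMULFixed=*/32, /*SEW=*/8));
// e32/m1: LMULFixed = 8, so VLENB * 4 overshoots (VLMAX is only VLENB / 4).
static_assert(!scaledVlenbIsVLMAX(/*Scale=*/4, /*LMULFixed=*/8, /*SEW=*/32));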
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/RISCV/rvv/pr83017.ll
@@ -35,11 +35,11 @@ define void @aliasing(ptr %p) {
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vs1r.v v8, (a2)
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
; CHECK-NEXT: vmv.v.i v12, 0
; CHECK-NEXT: vs4r.v v12, (a0)
; CHECK-NEXT: addi a2, a0, 64
; CHECK-NEXT: vs1r.v v8, (a2)
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vs4r.v v8, (a0)
; CHECK-NEXT: sw a1, 84(a0)
; CHECK-NEXT: ret
%q = getelementptr inbounds i8, ptr %p, i64 84
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/RISCV/rvv/pr90559.ll
@@ -32,11 +32,11 @@ define void @f(ptr %p) vscale_range(2,2) {
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vs1r.v v8, (a2)
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
; CHECK-NEXT: vmv.v.i v12, 0
; CHECK-NEXT: vs4r.v v12, (a0)
; CHECK-NEXT: addi a2, a0, 64
; CHECK-NEXT: vs1r.v v8, (a2)
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vs4r.v v8, (a0)
; CHECK-NEXT: sw a1, 84(a0)
; CHECK-NEXT: ret
%q = getelementptr inbounds i8, ptr %p, i64 84
