
Commit

[LLVM][XTHeadVector] Implement intrinsics for vmerge and `vmv.v.x/i`. (llvm#72)

* [LLVM][XTHeadVector] Define intrinsic functions for vmerge and vmv.v.{x,i}.

* [LLVM][XTHeadVector] Define pseudos and pats for vmerge.

* [LLVM][XTHeadVector] Add test cases for vmerge.

* [LLVM][XTHeadVector] Define policy-free pseudo nodes for vmv.v.{v/x/i}. Define pats for vmv.v.v.

* [LLVM][XTHeadVector] Define ISD node for vmv.v.x and map it to pseudo nodes.

* [LLVM][XTHeadVector] Select vmv.v.x using logic shared with its 1.0 version.

* [LLVM][XTHeadVector] Don't add policy for xthead pseudo nodes.

* [LLVM][XTHeadVector] Add test cases for vmv.v.x.

* [LLVM][XTHeadVector] Update test cases now that pseudo vmv instructions no longer accept a policy operand.

* [NFC][XTHeadVector] Update readme.
AinsleySnow authored and imkiva committed Apr 1, 2024
1 parent ea1dc1b commit 99ef8b1
Showing 12 changed files with 3,613 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -51,6 +51,7 @@ Any feature not listed below but present in the specification should be consider
- (Done) `12.11. Vector Widening Integer Multiply Instructions`
- (Done) `12.12. Vector Single-Width Integer Multiply-Add Instructions`
- (Done) `12.13. Vector Widening Integer Multiply-Add Instructions`
- (Done) `12.14. Vector Integer Merge and Move Instructions`
- (WIP) Clang intrinsics related to the `XTHeadVector` extension:
- (WIP) `6. Configuration-Setting and Utility`
- (Done) `6.1. Set vl and vtype`
14 changes: 12 additions & 2 deletions llvm/include/llvm/IR/IntrinsicsRISCVXTHeadV.td
@@ -770,10 +770,10 @@ let TargetPrefix = "riscv" in {
defm th_vwmacc : XVTernaryWide;
defm th_vwmaccus : XVTernaryWide;
defm th_vwmaccsu : XVTernaryWide;
} // TargetPrefix = "riscv"

let TargetPrefix = "riscv" in {
// 12.14. Vector Integer Merge and Move Instructions
defm th_vmerge : RISCVBinaryWithV0;

// Output: (vector)
// Input: (passthru, vector_in, vl)
def int_riscv_th_vmv_v_v : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
@@ -783,4 +783,14 @@ let TargetPrefix = "riscv" in {
[IntrNoMem]>, RISCVVIntrinsic {
let VLOperand = 2;
}
// Output: (vector)
// Input: (passthru, scalar, vl)
def int_riscv_th_vmv_v_x : DefaultAttrsIntrinsic<[llvm_anyint_ty],
[LLVMMatchType<0>,
LLVMVectorElementType<0>,
llvm_anyint_ty],
[IntrNoMem]>, RISCVVIntrinsic {
let ScalarOperand = 1;
let VLOperand = 2;
}
} // TargetPrefix = "riscv"
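For reference, a minimal IR-level sketch of how the two intrinsics defined above might be called. Only the (passthru, operand, vl) shapes are taken from the definitions; the mangled suffixes and the riscv64/+xtheadvector target are assumptions for illustration.

; Hypothetical use of the intrinsics above; the mangled names follow the
; usual RVV convention and are assumptions.
declare <vscale x 4 x i32> @llvm.riscv.th.vmv.v.v.nxv4i32.i64(
    <vscale x 4 x i32>, <vscale x 4 x i32>, i64)
declare <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
    <vscale x 4 x i32>, i32, i64)

; Splat a scalar into all active elements: (passthru, scalar, vl).
define <vscale x 4 x i32> @splat_x(i32 %x, i64 %vl) {
  %v = call <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
      <vscale x 4 x i32> undef, i32 %x, i64 %vl)
  ret <vscale x 4 x i32> %v
}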
32 changes: 21 additions & 11 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3554,7 +3554,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,

static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
SDValue Lo, SDValue Hi, SDValue VL,
-SelectionDAG &DAG) {
+SelectionDAG &DAG, bool HasVendorXTHeadV) {
if (!Passthru)
Passthru = DAG.getUNDEF(VT);
if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
@@ -3563,7 +3563,9 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
// If Hi constant is all the same sign bit as Lo, lower this as a custom
// node in order to try and match RVV vector/scalar instructions.
if ((LoC >> 31) == HiC)
-return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
+return DAG.getNode(HasVendorXTHeadV ?
+RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+DL, VT, Passthru, Lo, VL);

// If vl is equal to XLEN_MAX and Hi constant is equal to Lo, we could use
// vmv.v.x whose EEW = 32 to lower it.
@@ -3572,8 +3574,8 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
// TODO: if vl <= min(VLMAX), we can also do this. But we could not
// access the subtarget here now.
auto InterVec = DAG.getNode(
-RISCVISD::VMV_V_X_VL, DL, InterVT, DAG.getUNDEF(InterVT), Lo,
-DAG.getRegister(RISCV::X0, MVT::i32));
+HasVendorXTHeadV ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+DL, InterVT, DAG.getUNDEF(InterVT), Lo, DAG.getRegister(RISCV::X0, MVT::i32));
return DAG.getNode(ISD::BITCAST, DL, VT, InterVec);
}
}
@@ -3588,11 +3590,11 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
// of the halves.
static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
SDValue Scalar, SDValue VL,
-SelectionDAG &DAG) {
+SelectionDAG &DAG, bool HasVendorXTHeadV) {
assert(Scalar.getValueType() == MVT::i64 && "Unexpected VT!");
SDValue Lo, Hi;
std::tie(Lo, Hi) = DAG.SplitScalar(Scalar, DL, MVT::i32, MVT::i32);
-return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG);
+return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG, HasVendorXTHeadV);
}

// This function lowers a splat of a scalar operand Splat with the vector
@@ -3628,7 +3630,9 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
if (isOneConstant(VL) &&
(!Const || isNullConstant(Scalar) || !isInt<5>(Const->getSExtValue())))
return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL);
-return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
+return DAG.getNode(
+Subtarget.hasVendorXTHeadV() ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+DL, VT, Passthru, Scalar, VL);
}

assert(XLenVT == MVT::i32 && Scalar.getValueType() == MVT::i64 &&
@@ -3639,7 +3643,8 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
DAG.getConstant(0, DL, XLenVT), VL);

// Otherwise use the more complicated splatting algorithm.
-return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
+return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL,
+DAG, Subtarget.hasVendorXTHeadV());
}

static MVT getLMUL1VT(MVT VT) {
@@ -6637,7 +6642,8 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;

SDValue Res =
-splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG);
+splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL,
+DAG, Subtarget.hasVendorXTHeadV());
return convertFromScalableVector(VecVT, Res, DAG, Subtarget);
}

@@ -7369,7 +7375,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
// We need to convert the scalar to a splat vector.
SDValue VL = getVLOperand(Op);
assert(VL.getValueType() == XLenVT);
-ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL, DAG);
+ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL,
+DAG, Subtarget.hasVendorXTHeadV());
return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
}

@@ -7483,6 +7490,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
Op.getOperand(1), DAG.getConstant(0, DL, XLenVT));
case Intrinsic::riscv_vmv_v_x:
case Intrinsic::riscv_th_vmv_v_x:
return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
Subtarget);
@@ -7519,7 +7527,8 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SDValue Vec = Op.getOperand(1);
SDValue VL = getVLOperand(Op);

-SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL, DAG);
+SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL,
+DAG, Subtarget.hasVendorXTHeadV());
if (Op.getOperand(1).isUndef())
return SplattedVal;
SDValue SplattedIdx =
@@ -16429,6 +16438,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(TH_SDD)
NODE_NAME_CASE(VMV_V_V_VL)
NODE_NAME_CASE(VMV_V_X_VL)
NODE_NAME_CASE(TH_VMV_V_X_VL)
NODE_NAME_CASE(VFMV_V_F_VL)
NODE_NAME_CASE(VMV_X_S)
NODE_NAME_CASE(VMV_S_X_VL)
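The splatPartsI64WithVL/splatSplitI64WithVL changes above only matter on RV32, where an i64 scalar cannot live in a single GPR and must be split into lo/hi halves before the splat is emitted. A minimal IR sketch that would exercise this path, assuming a riscv32 target with the XTHeadVector feature enabled:

; Assumed target: riscv32 with +xtheadvector. Splatting an i64 forces the
; splatSplitI64WithVL path, since XLEN=32 cannot hold the scalar directly.
define <vscale x 1 x i64> @splat_i64(i64 %x) {
  %head = insertelement <vscale x 1 x i64> poison, i64 %x, i64 0
  %v = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison,
                     <vscale x 1 x i32> zeroinitializer
  ret <vscale x 1 x i64> %v
}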
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -151,6 +151,7 @@ enum NodeType : unsigned {
// for the VL value to be used for the operation. The first operand is
// passthru operand.
VMV_V_X_VL,
TH_VMV_V_X_VL,
// VFMV_V_F_VL matches the semantics of vfmv.v.f but includes an extra operand
// for the VL value to be used for the operation. The first operand is
// passthru operand.
6 changes: 4 additions & 2 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -488,7 +488,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
const MCInstrDesc &Desc = DefMBBI->getDesc();
MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
-MIB.addImm(0); // tu, mu
+if (!XTHeadV)
+MIB.addImm(0); // tu, mu
MIB.addReg(RISCV::VL, RegState::Implicit);
MIB.addReg(RISCV::VTYPE, RegState::Implicit);
}
@@ -522,7 +523,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
const MCInstrDesc &Desc = DefMBBI->getDesc();
MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
-MIB.addImm(0); // tu, mu
+if (!XTHeadV)
+MIB.addImm(0); // tu, mu
MIB.addReg(RISCV::VL, RegState::Implicit);
MIB.addReg(RISCV::VTYPE, RegState::Implicit);
}
82 changes: 68 additions & 14 deletions llvm/lib/Target/RISCV/RISCVInstrInfoXTHeadVPseudos.td
@@ -1601,6 +1601,19 @@ class XVPseudoBinaryMaskNoPolicy<VReg RetClass,
let HasSEWOp = 1;
}

class XVPseudoUnaryNoMask<DAGOperand RetClass, DAGOperand OpClass,
string Constraint = ""> :
Pseudo<(outs RetClass:$rd),
(ins RetClass:$merge, OpClass:$rs2, AVL:$vl, ixlenimm:$sew), []>,
RISCVVPseudo {
let mayLoad = 0;
let mayStore = 0;
let hasSideEffects = 0;
let Constraints = !interleave([Constraint, "$rd = $merge"], ",");
let HasVLOp = 1;
let HasSEWOp = 1;
}

multiclass XVPseudoBinary<VReg RetClass,
VReg Op1Class,
DAGOperand Op2Class,
@@ -2907,6 +2920,30 @@ let Predicates = [HasVendorXTHeadV] in {
//===----------------------------------------------------------------------===//
// 12.14. Vector Integer Merge and Move Instructions
//===----------------------------------------------------------------------===//
multiclass XVPseudoVMRG_VM_XM_IM {
foreach m = MxListXTHeadV in {
defvar mx = m.MX;
defvar WriteVIMergeV_MX = !cast<SchedWrite>("WriteVIMergeV_" # mx);
defvar WriteVIMergeX_MX = !cast<SchedWrite>("WriteVIMergeX_" # mx);
defvar WriteVIMergeI_MX = !cast<SchedWrite>("WriteVIMergeI_" # mx);
defvar ReadVIMergeV_MX = !cast<SchedRead>("ReadVIMergeV_" # mx);
defvar ReadVIMergeX_MX = !cast<SchedRead>("ReadVIMergeX_" # mx);

def "_VVM" # "_" # m.MX:
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
m.vrclass, m.vrclass, m, 1, "">,
Sched<[WriteVIMergeV_MX, ReadVIMergeV_MX, ReadVIMergeV_MX, ReadVMask]>;
def "_VXM" # "_" # m.MX:
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
m.vrclass, GPR, m, 1, "">,
Sched<[WriteVIMergeX_MX, ReadVIMergeV_MX, ReadVIMergeX_MX, ReadVMask]>;
def "_VIM" # "_" # m.MX:
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
m.vrclass, simm5, m, 1, "">,
Sched<[WriteVIMergeI_MX, ReadVIMergeV_MX, ReadVMask]>;
}
}

multiclass XVPseudoUnaryVMV_V_X_I {
foreach m = MxListXTHeadV in {
let VLMul = m.value in {
@@ -2918,34 +2955,49 @@ multiclass XVPseudoUnaryVMV_V_X_I {
defvar ReadVIMovX_MX = !cast<SchedRead>("ReadVIMovX_" # mx);

let VLMul = m.value in {
def "_V_" # mx : VPseudoUnaryNoMask<m.vrclass, m.vrclass>,
def "_V_" # mx : XVPseudoUnaryNoMask<m.vrclass, m.vrclass>,
Sched<[WriteVIMovV_MX, ReadVIMovV_MX]>;
def "_X_" # mx : VPseudoUnaryNoMask<m.vrclass, GPR>,
def "_X_" # mx : XVPseudoUnaryNoMask<m.vrclass, GPR>,
Sched<[WriteVIMovX_MX, ReadVIMovX_MX]>;
def "_I_" # mx : VPseudoUnaryNoMask<m.vrclass, simm5>,
def "_I_" # mx : XVPseudoUnaryNoMask<m.vrclass, simm5>,
Sched<[WriteVIMovI_MX]>;
}
}
}
}

let Predicates = [HasVendorXTHeadV] in {
defm PseudoTH_VMERGE : XVPseudoVMRG_VM_XM_IM;
defm PseudoTH_VMV_V : XVPseudoUnaryVMV_V_X_I;
} // Predicates = [HasVendorXTHeadV]

-// Patterns for `int_riscv_vmv_v_v` -> `PseudoTH_VMV_V_V_<LMUL>`
-foreach vti = AllXVectors in {
-let Predicates = GetXVTypePredicates<vti>.Predicates in {
-// vmv.v.v
-def : Pat<(vti.Vector (int_riscv_th_vmv_v_v (vti.Vector vti.RegClass:$passthru),
-(vti.Vector vti.RegClass:$rs1),
-VLOpFrag)),
-(!cast<Instruction>("PseudoTH_VMV_V_V_"#vti.LMul.MX)
-$passthru, $rs1, GPR:$vl, vti.Log2SEW, TU_MU)>;
+let Predicates = [HasVendorXTHeadV] in {
+defm : XVPatBinaryV_VM_XM_IM<"int_riscv_th_vmerge", "PseudoTH_VMERGE">;
+// Define patterns for vmerge intrinsics with floating-point arguments.
+foreach vti = AllFloatXVectors in {
+let Predicates = GetXVTypePredicates<vti>.Predicates in {
+defm : VPatBinaryCarryInTAIL<"int_riscv_th_vmerge", "PseudoTH_VMERGE", "VVM",
+vti.Vector,
+vti.Vector, vti.Vector, vti.Mask,
+vti.Log2SEW, vti.LMul, vti.RegClass,
+vti.RegClass, vti.RegClass>;
+}
+}

-// TODO: vmv.v.x, vmv.v.i
+// Patterns for `int_riscv_vmv_v_v` -> `PseudoTH_VMV_V_V_<LMUL>`
+foreach vti = AllXVectors in {
+let Predicates = GetXVTypePredicates<vti>.Predicates in {
+// vmv.v.v
+def : Pat<(vti.Vector (int_riscv_th_vmv_v_v (vti.Vector vti.RegClass:$passthru),
+(vti.Vector vti.RegClass:$rs1),
+VLOpFrag)),
+(!cast<Instruction>("PseudoTH_VMV_V_V_"#vti.LMul.MX)
+$passthru, $rs1, GPR:$vl, vti.Log2SEW)>;
+// Patterns for vmv.v.x and vmv.v.i are defined
+// in RISCVInstrInfoXTHeadVVLPatterns.td
}
}
}
} // Predicates = [HasVendorXTHeadV]

//===----------------------------------------------------------------------===//
// 12.14. Vector Integer Merge and Move Instructions
Expand All @@ -2967,3 +3019,5 @@ let Predicates = [HasVendorXTHeadV] in {
def PseudoTH_VMV8R_V : XVPseudoWholeMove<TH_VMV_V_V, V_M8, VRM8>;
}
} // Predicates = [HasVendorXTHeadV]

include "RISCVInstrInfoXTHeadVVLPatterns.td"
32 changes: 32 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoXTHeadVVLPatterns.td
@@ -0,0 +1,32 @@
//===-- RISCVInstrInfoXTHeadVVLPatterns.td - RVV VL patterns -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------------===//
///
/// This file contains the required infrastructure and VL patterns to support
/// code generation for the 'XTHeadVector' extension (vector specification 0.7.1).
///
/// This file is included from RISCVInstrInfoXTHeadVPseudos.td
//===---------------------------------------------------------------------------===//

def riscv_th_vmv_v_x_vl : SDNode<"RISCVISD::TH_VMV_V_X_VL",
SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisSameAs<0, 1>,
SDTCisVT<2, XLenVT>,
SDTCisVT<3, XLenVT>]>>;

foreach vti = AllIntegerXVectors in {
def : Pat<(vti.Vector (riscv_th_vmv_v_x_vl vti.RegClass:$passthru, GPR:$rs2, VLOpFrag)),
(!cast<Instruction>("PseudoTH_VMV_V_X_"#vti.LMul.MX)
vti.RegClass:$passthru, GPR:$rs2, GPR:$vl, vti.Log2SEW)>;
defvar ImmPat = !cast<ComplexPattern>("sew"#vti.SEW#"simm5");
def : Pat<(vti.Vector (riscv_th_vmv_v_x_vl vti.RegClass:$passthru, (ImmPat simm5:$imm5),
VLOpFrag)),
(!cast<Instruction>("PseudoTH_VMV_V_I_"#vti.LMul.MX)
vti.RegClass:$passthru, simm5:$imm5, GPR:$vl, vti.Log2SEW)>;
}
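The second pattern above folds splat constants that fit in simm5 into the immediate form. A sketch under the same naming assumptions as earlier: a constant such as -3 arrives as a TH_VMV_V_X_VL node with a constant operand and should therefore select PseudoTH_VMV_V_I rather than PseudoTH_VMV_V_X.

declare <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
    <vscale x 4 x i32>, i32, i64)

define <vscale x 4 x i32> @splat_neg3(i64 %vl) {
  ; -3 fits in simm5, so selection is expected to pick the vmv.v.i form.
  %v = call <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
      <vscale x 4 x i32> undef, i32 -3, i64 %vl)
  ret <vscale x 4 x i32> %v
}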