Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVFeatures.td
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,7 @@ def HasVInstructionsF16Minimal : Predicate<"Subtarget->hasVInstructionsF16Minima

def HasVInstructionsBF16Minimal : Predicate<"Subtarget->hasVInstructionsBF16Minimal()">;
def HasVInstructionsF16 : Predicate<"Subtarget->hasVInstructionsF16()">;
def HasVInstructionsBF16 : Predicate<"Subtarget->hasVInstructionsBF16()">;
def HasVInstructionsF64 : Predicate<"Subtarget->hasVInstructionsF64()">;

def HasVInstructionsFullMultiply : Predicate<"Subtarget->hasVInstructionsFullMultiply()">;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoV.td
Original file line number Diff line number Diff line change
Expand Up @@ -1840,3 +1840,6 @@ let Predicates = [HasVInstructionsI64, IsRV64] in {

include "RISCVInstrInfoVPseudos.td"
include "RISCVInstrInfoZvfbf.td"
// Include the non-intrinsic ISel patterns
include "RISCVInstrInfoVVLPatterns.td"
include "RISCVInstrInfoVSDPatterns.td"
44 changes: 21 additions & 23 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -473,17 +473,27 @@ defset list<VTypeInfoToWide> AllWidenableIntVectors = {
def : VTypeInfoToWide<VI32M4, VI64M8>;
}

defset list<VTypeInfoToWide> AllWidenableFloatVectors = {
def : VTypeInfoToWide<VF16MF4, VF32MF2>;
def : VTypeInfoToWide<VF16MF2, VF32M1>;
def : VTypeInfoToWide<VF16M1, VF32M2>;
def : VTypeInfoToWide<VF16M2, VF32M4>;
def : VTypeInfoToWide<VF16M4, VF32M8>;
defset list<VTypeInfoToWide> AllWidenableFloatAndBF16Vectors = {
defset list<VTypeInfoToWide> AllWidenableFloatVectors = {
def : VTypeInfoToWide<VF16MF4, VF32MF2>;
def : VTypeInfoToWide<VF16MF2, VF32M1>;
def : VTypeInfoToWide<VF16M1, VF32M2>;
def : VTypeInfoToWide<VF16M2, VF32M4>;
def : VTypeInfoToWide<VF16M4, VF32M8>;

def : VTypeInfoToWide<VF32MF2, VF64M1>;
def : VTypeInfoToWide<VF32M1, VF64M2>;
def : VTypeInfoToWide<VF32M2, VF64M4>;
def : VTypeInfoToWide<VF32M4, VF64M8>;
def : VTypeInfoToWide<VF32MF2, VF64M1>;
def : VTypeInfoToWide<VF32M1, VF64M2>;
def : VTypeInfoToWide<VF32M2, VF64M4>;
def : VTypeInfoToWide<VF32M4, VF64M8>;
}

defset list<VTypeInfoToWide> AllWidenableBF16ToFloatVectors = {
def : VTypeInfoToWide<VBF16MF4, VF32MF2>;
def : VTypeInfoToWide<VBF16MF2, VF32M1>;
def : VTypeInfoToWide<VBF16M1, VF32M2>;
def : VTypeInfoToWide<VBF16M2, VF32M4>;
def : VTypeInfoToWide<VBF16M4, VF32M8>;
}
}

defset list<VTypeInfoToFraction> AllFractionableVF2IntVectors = {
Expand Down Expand Up @@ -543,14 +553,6 @@ defset list<VTypeInfoToWide> AllWidenableIntToFloatVectors = {
def : VTypeInfoToWide<VI32M4, VF64M8>;
}

defset list<VTypeInfoToWide> AllWidenableBF16ToFloatVectors = {
def : VTypeInfoToWide<VBF16MF4, VF32MF2>;
def : VTypeInfoToWide<VBF16MF2, VF32M1>;
def : VTypeInfoToWide<VBF16M1, VF32M2>;
def : VTypeInfoToWide<VBF16M2, VF32M4>;
def : VTypeInfoToWide<VBF16M4, VF32M8>;
}

// This class holds the record of the RISCVVPseudoTable below.
// This represents the information we need in codegen for each pseudo.
// The definition should be consistent with `struct PseudoInfo` in
Expand Down Expand Up @@ -780,7 +782,7 @@ class GetVRegNoV0<VReg VRegClass> {

class GetVTypePredicates<VTypeInfo vti> {
list<Predicate> Predicates = !cond(!eq(vti.Scalar, f16) : [HasVInstructionsF16],
!eq(vti.Scalar, bf16) : [HasVInstructionsBF16Minimal],
!eq(vti.Scalar, bf16) : [HasVInstructionsBF16],
!eq(vti.Scalar, f32) : [HasVInstructionsAnyF],
!eq(vti.Scalar, f64) : [HasVInstructionsF64],
!eq(vti.SEW, 64) : [HasVInstructionsI64],
Expand Down Expand Up @@ -7326,7 +7328,3 @@ defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16",
// 16.5. Vector Compress Instruction
//===----------------------------------------------------------------------===//
defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllVectors>;

// Include the non-intrinsic ISel patterns
include "RISCVInstrInfoVVLPatterns.td"
include "RISCVInstrInfoVSDPatterns.td"
106 changes: 74 additions & 32 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,17 @@ multiclass VPatBinaryFPSDNode_VV_VF<SDPatternOperator vop, string instruction_na
}

multiclass VPatBinaryFPSDNode_VV_VF_RM<SDPatternOperator vop, string instruction_name,
bit isSEWAware = 0, bit isBF16 = 0> {
foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
list<VTypeInfo> vtilist = AllFloatVectors,
bit isSEWAware = 0> {
foreach vti = vtilist in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
def : VPatBinarySDNode_VV_RM<vop, instruction_name,
def : VPatBinarySDNode_VV_RM<vop, instruction_name #
!if(!eq(vti.Scalar, bf16), "_ALT", ""),
vti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, vti.AVL, vti.RegClass, isSEWAware>;
def : VPatBinarySDNode_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix,
def : VPatBinarySDNode_VF_RM<vop, instruction_name#
!if(!eq(vti.Scalar, bf16), "_ALT", "")#
"_V"#vti.ScalarSuffix,
vti.Vector, vti.Vector, vti.Scalar,
vti.Log2SEW, vti.LMul, vti.AVL, vti.RegClass,
vti.ScalarRegClass, isSEWAware>;
Expand All @@ -246,14 +250,17 @@ multiclass VPatBinaryFPSDNode_R_VF<SDPatternOperator vop, string instruction_nam
}

multiclass VPatBinaryFPSDNode_R_VF_RM<SDPatternOperator vop, string instruction_name,
bit isSEWAware = 0, bit isBF16 = 0> {
foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in
list<VTypeInfo> vtilist = AllFloatVectors,
bit isSEWAware = 0> {
foreach fvti = vtilist in
let Predicates = GetVTypePredicates<fvti>.Predicates in
def : Pat<(fvti.Vector (vop (fvti.Vector (SplatFPOp fvti.Scalar:$rs2)),
(fvti.Vector fvti.RegClass:$rs1))),
(!cast<Instruction>(
!if(isSEWAware,
instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW,
instruction_name#
!if(!eq(fvti.Scalar, bf16), "_ALT", "")#
"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW,
instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX))
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs1,
Expand Down Expand Up @@ -664,19 +671,20 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates,
let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
!if(!eq(vti.Scalar, bf16),
[HasStdExtZvfbfwma],
[])) in {
GetVTypePredicates<vti>.Predicates)) in {
def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
(vti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask), (XLenVT srcvalue))),
(wti.Vector (riscv_fpextend_vl_sameuser
(vti.Vector vti.RegClass:$rs2),
(vti.Mask true_mask), (XLenVT srcvalue))),
(wti.Vector wti.RegClass:$rd)),
(!cast<Instruction>(instruction_name#"_VV_"#suffix)
(!cast<Instruction>(instruction_name#
!if(!eq(vti.Scalar, bf16), "BF16", "")#
"_VV_"#suffix)
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
Expand All @@ -688,7 +696,9 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
(vti.Vector vti.RegClass:$rs2),
(vti.Mask true_mask), (XLenVT srcvalue))),
(wti.Vector wti.RegClass:$rd)),
(!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix)
(!cast<Instruction>(instruction_name#
!if(!eq(vti.Scalar, bf16), "BF16", "")#
"_V"#vti.ScalarSuffix#"_"#suffix)
wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
Expand Down Expand Up @@ -1201,16 +1211,20 @@ foreach mti = AllMasks in {
// 13. Vector Floating-Point Instructions

// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD", isSEWAware=1>;
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB", isSEWAware=1>;
defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB", isSEWAware=1>;
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD", AllFloatAndBF16Vectors,
isSEWAware=1>;
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB", AllFloatAndBF16Vectors,
isSEWAware=1>;
defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB", AllFloatAndBF16Vectors,
isSEWAware=1>;

// 13.3. Vector Widening Floating-Point Add/Subtract Instructions
defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<fadd, "PseudoVFWADD">;
defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<fsub, "PseudoVFWSUB">;

// 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL", isSEWAware=1>;
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL", AllFloatAndBF16Vectors,
isSEWAware=1>;
defm : VPatBinaryFPSDNode_VV_VF_RM<any_fdiv, "PseudoVFDIV", isSEWAware=1>;
defm : VPatBinaryFPSDNode_R_VF_RM<any_fdiv, "PseudoVFRDIV", isSEWAware=1>;

Expand Down Expand Up @@ -1314,14 +1328,15 @@ foreach fvti = AllFloatVectors in {

// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACC",
AllWidenableFloatVectors>;
AllWidenableFloatAndBF16Vectors>;
defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">;
defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">;
defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">;

foreach vti = AllFloatVectors in {
foreach vti = AllFloatAndBF16Vectors in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
// 13.8. Vector Floating-Point Square-Root Instruction
if !ne(vti.Scalar, bf16) then
def : Pat<(any_fsqrt (vti.Vector vti.RegClass:$rs2)),
(!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
Expand All @@ -1333,34 +1348,46 @@ foreach vti = AllFloatVectors in {

// 13.12. Vector Floating-Point Sign-Injection Instructions
def : Pat<(fabs (vti.Vector vti.RegClass:$rs)),
(!cast<Instruction>("PseudoVFSGNJX_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(!cast<Instruction>("PseudoVFSGNJX"#
!if(!eq(vti.Scalar, bf16), "_ALT", "")#
"_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
// Handle fneg with VFSGNJN using the same input for both operands.
def : Pat<(fneg (vti.Vector vti.RegClass:$rs)),
(!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(!cast<Instruction>("PseudoVFSGNJN"#
!if(!eq(vti.Scalar, bf16), "_ALT", "")#
"_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;

def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector vti.RegClass:$rs2))),
(!cast<Instruction>("PseudoVFSGNJ_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(!cast<Instruction>("PseudoVFSGNJ"#
!if(!eq(vti.Scalar, bf16), "_ALT", "")#
"_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))),
(!cast<Instruction>("PseudoVFSGNJ_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
(!cast<Instruction>("PseudoVFSGNJ"#
!if(!eq(vti.Scalar, bf16), "_ALT", "")#
"_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;

def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector (fneg vti.RegClass:$rs2)))),
(!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(!cast<Instruction>("PseudoVFSGNJN"#
!if(!eq(vti.Scalar, bf16), "_ALT", "")#
"_VV_"# vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
(vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))),
(!cast<Instruction>("PseudoVFSGNJN_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
(!cast<Instruction>("PseudoVFSGNJN"#
!if(!eq(vti.Scalar, bf16), "_ALT", "")#
"_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
}
Expand Down Expand Up @@ -1446,13 +1473,28 @@ defm : VPatNConvertFP2ISDNode_W<any_fp_to_sint, "PseudoVFNCVT_RTZ_X_F_W">;
defm : VPatNConvertFP2ISDNode_W<any_fp_to_uint, "PseudoVFNCVT_RTZ_XU_F_W">;
defm : VPatNConvertI2FPSDNode_W_RM<any_sint_to_fp, "PseudoVFNCVT_F_X_W">;
defm : VPatNConvertI2FPSDNode_W_RM<any_uint_to_fp, "PseudoVFNCVT_F_XU_W">;
foreach fvtiToFWti = AllWidenableFloatVectors in {
foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in {
defvar fvti = fvtiToFWti.Vti;
defvar fwti = fvtiToFWti.Wti;
let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
GetVTypeMinimalPredicates<fwti>.Predicates) in
let Predicates = !listconcat(GetVTypeMinimalPredicates<fwti>.Predicates,
!if(!eq(fvti.Scalar, bf16),
[HasStdExtZvfbfmin],
GetVTypeMinimalPredicates<fvti>.Predicates)) in
def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
(!cast<Instruction>("PseudoVFNCVT"#
!if(!eq(fvti.Scalar, bf16), "BF16", "")#
"_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
(fvti.Vector (IMPLICIT_DEF)),
fwti.RegClass:$rs1,
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
FRM_DYN,
fvti.AVL, fvti.Log2SEW, TA_MA)>;
// Define vfncvt.f.f.w for bf16 when Zvfbfa is enabled.
if !eq(fvti.Scalar, bf16) then
let Predicates = [HasVInstructionsBF16] in
def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
(!cast<Instruction>("PseudoVFNCVT_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
(!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
(fvti.Vector (IMPLICIT_DEF)),
fwti.RegClass:$rs1,
// Value to indicate no rounding mode change in
Expand All @@ -1464,10 +1506,10 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
//===----------------------------------------------------------------------===//
// Vector Element Extracts
//===----------------------------------------------------------------------===//
foreach vti = NoGroupFloatVectors in {
defvar vfmv_f_s_inst = !cast<Instruction>(!strconcat("PseudoVFMV_",
vti.ScalarSuffix,
"_S"));
foreach vti = !listconcat(NoGroupFloatVectors, NoGroupBF16Vectors) in {
defvar vfmv_f_s_inst =
!cast<Instruction>(!strconcat("PseudoVFMV_", vti.ScalarSuffix,
"_S", !if(!eq(vti.Scalar, bf16), "_ALT", "")));
// Only pattern-match extract-element operations where the index is 0. Any
// other index will have been custom-lowered to slide the vector correctly
// into place.
Expand Down
Loading