diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index 0c75312847c87..05f50cba6e9be 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -908,6 +908,7 @@ def HasVInstructionsF16Minimal : Predicate<"Subtarget->hasVInstructionsF16Minima def HasVInstructionsBF16Minimal : Predicate<"Subtarget->hasVInstructionsBF16Minimal()">; def HasVInstructionsF16 : Predicate<"Subtarget->hasVInstructionsF16()">; +def HasVInstructionsBF16 : Predicate<"Subtarget->hasVInstructionsBF16()">; def HasVInstructionsF64 : Predicate<"Subtarget->hasVInstructionsF64()">; def HasVInstructionsFullMultiply : Predicate<"Subtarget->hasVInstructionsFullMultiply()">; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoV.td b/llvm/lib/Target/RISCV/RISCVInstrInfoV.td index 594a75a4746d4..9354b63bced53 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoV.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoV.td @@ -1840,3 +1840,6 @@ let Predicates = [HasVInstructionsI64, IsRV64] in { include "RISCVInstrInfoVPseudos.td" include "RISCVInstrInfoZvfbf.td" +// Include the non-intrinsic ISel patterns +include "RISCVInstrInfoVVLPatterns.td" +include "RISCVInstrInfoVSDPatterns.td" diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td index e36204c536c0d..cdbeb0c1046d2 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -473,17 +473,27 @@ defset list AllWidenableIntVectors = { def : VTypeInfoToWide; } -defset list AllWidenableFloatVectors = { - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; +defset list AllWidenableFloatAndBF16Vectors = { + defset list AllWidenableFloatVectors = { + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + } + + defset list AllWidenableBF16ToFloatVectors = { + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + def : VTypeInfoToWide; + } } defset list AllFractionableVF2IntVectors = { @@ -543,14 +553,6 @@ defset list AllWidenableIntToFloatVectors = { def : VTypeInfoToWide; } -defset list AllWidenableBF16ToFloatVectors = { - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; - def : VTypeInfoToWide; -} - // This class holds the record of the RISCVVPseudoTable below. // This represents the information we need in codegen for each pseudo. // The definition should be consistent with `struct PseudoInfo` in @@ -780,7 +782,7 @@ class GetVRegNoV0 { class GetVTypePredicates { list Predicates = !cond(!eq(vti.Scalar, f16) : [HasVInstructionsF16], - !eq(vti.Scalar, bf16) : [HasVInstructionsBF16Minimal], + !eq(vti.Scalar, bf16) : [HasVInstructionsBF16], !eq(vti.Scalar, f32) : [HasVInstructionsAnyF], !eq(vti.Scalar, f64) : [HasVInstructionsF64], !eq(vti.SEW, 64) : [HasVInstructionsI64], @@ -7326,7 +7328,3 @@ defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16", // 16.5. Vector Compress Instruction //===----------------------------------------------------------------------===// defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllVectors>; - -// Include the non-intrinsic ISel patterns -include "RISCVInstrInfoVVLPatterns.td" -include "RISCVInstrInfoVSDPatterns.td" diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td index a67112b9981b8..14ad7ca0eb35a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -215,13 +215,17 @@ multiclass VPatBinaryFPSDNode_VV_VF { - foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in { + list vtilist = AllFloatVectors, + bit isSEWAware = 0> { + foreach vti = vtilist in { let Predicates = GetVTypePredicates.Predicates in { - def : VPatBinarySDNode_VV_RM; - def : VPatBinarySDNode_VF_RM; @@ -246,14 +250,17 @@ multiclass VPatBinaryFPSDNode_R_VF { - foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in + list vtilist = AllFloatVectors, + bit isSEWAware = 0> { + foreach fvti = vtilist in let Predicates = GetVTypePredicates.Predicates in def : Pat<(fvti.Vector (vop (fvti.Vector (SplatFPOp fvti.Scalar:$rs2)), (fvti.Vector fvti.RegClass:$rs1))), (!cast( !if(isSEWAware, - instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW, + instruction_name# + !if(!eq(fvti.Scalar, bf16), "_ALT", "")# + "_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW, instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)) (fvti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1, @@ -664,11 +671,10 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM.Predicates, - GetVTypePredicates.Predicates, + let Predicates = !listconcat(GetVTypePredicates.Predicates, !if(!eq(vti.Scalar, bf16), [HasStdExtZvfbfwma], - [])) in { + GetVTypePredicates.Predicates)) in { def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), (XLenVT srcvalue))), @@ -676,7 +682,9 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM(instruction_name#"_VV_"#suffix) + (!cast(instruction_name# + !if(!eq(vti.Scalar, bf16), "BF16", "")# + "_VV_"#suffix) wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, // Value to indicate no rounding mode change in // RISCVInsertReadWriteCSR @@ -688,7 +696,9 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix) + (!cast(instruction_name# + !if(!eq(vti.Scalar, bf16), "BF16", "")# + "_V"#vti.ScalarSuffix#"_"#suffix) wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, // Value to indicate no rounding mode change in // RISCVInsertReadWriteCSR @@ -1201,16 +1211,20 @@ foreach mti = AllMasks in { // 13. Vector Floating-Point Instructions // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions -defm : VPatBinaryFPSDNode_VV_VF_RM; -defm : VPatBinaryFPSDNode_VV_VF_RM; -defm : VPatBinaryFPSDNode_R_VF_RM; +defm : VPatBinaryFPSDNode_VV_VF_RM; +defm : VPatBinaryFPSDNode_VV_VF_RM; +defm : VPatBinaryFPSDNode_R_VF_RM; // 13.3. Vector Widening Floating-Point Add/Subtract Instructions defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM; defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM; // 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions -defm : VPatBinaryFPSDNode_VV_VF_RM; +defm : VPatBinaryFPSDNode_VV_VF_RM; defm : VPatBinaryFPSDNode_VV_VF_RM; defm : VPatBinaryFPSDNode_R_VF_RM; @@ -1314,14 +1328,15 @@ foreach fvti = AllFloatVectors in { // 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACC", - AllWidenableFloatVectors>; + AllWidenableFloatAndBF16Vectors>; defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">; defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">; defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">; -foreach vti = AllFloatVectors in { +foreach vti = AllFloatAndBF16Vectors in { let Predicates = GetVTypePredicates.Predicates in { // 13.8. Vector Floating-Point Square-Root Instruction + if !ne(vti.Scalar, bf16) then def : Pat<(any_fsqrt (vti.Vector vti.RegClass:$rs2)), (!cast("PseudoVFSQRT_V_"# vti.LMul.MX#"_E"#vti.SEW) (vti.Vector (IMPLICIT_DEF)), @@ -1333,34 +1348,46 @@ foreach vti = AllFloatVectors in { // 13.12. Vector Floating-Point Sign-Injection Instructions def : Pat<(fabs (vti.Vector vti.RegClass:$rs)), - (!cast("PseudoVFSGNJX_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (!cast("PseudoVFSGNJX"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX#"_E"#vti.SEW) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>; // Handle fneg with VFSGNJN using the same input for both operands. def : Pat<(fneg (vti.Vector vti.RegClass:$rs)), - (!cast("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (!cast("PseudoVFSGNJN"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX#"_E"#vti.SEW) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>; def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), (vti.Vector vti.RegClass:$rs2))), - (!cast("PseudoVFSGNJ_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (!cast("PseudoVFSGNJ"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX#"_E"#vti.SEW) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))), - (!cast("PseudoVFSGNJ_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) + (!cast("PseudoVFSGNJ"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), (vti.Vector (fneg vti.RegClass:$rs2)))), - (!cast("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (!cast("PseudoVFSGNJN"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX#"_E"#vti.SEW) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), (vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))), - (!cast("PseudoVFSGNJN_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) + (!cast("PseudoVFSGNJN"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; } @@ -1446,13 +1473,28 @@ defm : VPatNConvertFP2ISDNode_W; defm : VPatNConvertFP2ISDNode_W; defm : VPatNConvertI2FPSDNode_W_RM; defm : VPatNConvertI2FPSDNode_W_RM; -foreach fvtiToFWti = AllWidenableFloatVectors in { +foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; - let Predicates = !listconcat(GetVTypeMinimalPredicates.Predicates, - GetVTypeMinimalPredicates.Predicates) in + let Predicates = !listconcat(GetVTypeMinimalPredicates.Predicates, + !if(!eq(fvti.Scalar, bf16), + [HasStdExtZvfbfmin], + GetVTypeMinimalPredicates.Predicates)) in + def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))), + (!cast("PseudoVFNCVT"# + !if(!eq(fvti.Scalar, bf16), "BF16", "")# + "_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW) + (fvti.Vector (IMPLICIT_DEF)), + fwti.RegClass:$rs1, + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + FRM_DYN, + fvti.AVL, fvti.Log2SEW, TA_MA)>; + // Define vfncvt.f.f.w for bf16 when Zvfbfa is enabled. + if !eq(fvti.Scalar, bf16) then + let Predicates = [HasVInstructionsBF16] in def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))), - (!cast("PseudoVFNCVT_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW) + (!cast("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW) (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, // Value to indicate no rounding mode change in @@ -1464,10 +1506,10 @@ foreach fvtiToFWti = AllWidenableFloatVectors in { //===----------------------------------------------------------------------===// // Vector Element Extracts //===----------------------------------------------------------------------===// -foreach vti = NoGroupFloatVectors in { - defvar vfmv_f_s_inst = !cast(!strconcat("PseudoVFMV_", - vti.ScalarSuffix, - "_S")); +foreach vti = !listconcat(NoGroupFloatVectors, NoGroupBF16Vectors) in { + defvar vfmv_f_s_inst = + !cast(!strconcat("PseudoVFMV_", vti.ScalarSuffix, + "_S", !if(!eq(vti.Scalar, bf16), "_ALT", ""))); // Only pattern-match extract-element operations where the index is 0. Any // other index will have been custom-lowered to slide the vector correctly // into place. diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index 38edab5400291..9273ce094eb0a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -1058,14 +1058,18 @@ multiclass VPatBinaryFPVL_VV_VF { - foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in { + list vtilist = AllFloatVectors, + bit isSEWAware = 0> { + foreach vti = vtilist in { let Predicates = GetVTypePredicates.Predicates in { - def : VPatBinaryVL_V_RM; - def : VPatBinaryVL_VF_RM; @@ -1093,8 +1097,9 @@ multiclass VPatBinaryFPVL_R_VF { - foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in { + list vtilist = AllFloatVectors, + bit isSEWAware = 0> { + foreach fvti = vtilist in { let Predicates = GetVTypePredicates.Predicates in def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2), fvti.RegClass:$rs1, @@ -1103,7 +1108,9 @@ multiclass VPatBinaryFPVL_R_VF_RM( !if(isSEWAware, - instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK", + instruction_name# + !if(!eq(fvti.Scalar, bf16), "_ALT", "")# + "_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK", instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_MASK")) fvti.RegClass:$passthru, fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2, @@ -1832,16 +1839,17 @@ multiclass VPatWidenFPMulAccVL_VV_VF_RM.Predicates, - GetVTypePredicates.Predicates, + let Predicates = !listconcat(GetVTypePredicates.Predicates, !if(!eq(vti.Scalar, bf16), [HasStdExtZvfbfwma], - [])) in { + GetVTypePredicates.Predicates)) in { def : Pat<(vop (vti.Vector vti.RegClass:$rs1), (vti.Vector vti.RegClass:$rs2), (wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm), VLOpFrag), - (!cast(instruction_name#"_VV_"#suffix#"_MASK") + (!cast(instruction_name# + !if(!eq(vti.Scalar, bf16), "BF16", "")# + "_VV_"#suffix#"_MASK") wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, (vti.Mask VMV0:$vm), // Value to indicate no rounding mode change in @@ -1852,7 +1860,9 @@ multiclass VPatWidenFPMulAccVL_VV_VF_RM(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix#"_MASK") + (!cast(instruction_name# + !if(!eq(vti.Scalar, bf16), "BF16", "")# + "_V"#vti.ScalarSuffix#"_"#suffix#"_MASK") wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, (vti.Mask VMV0:$vm), // Value to indicate no rounding mode change in @@ -2296,9 +2306,12 @@ foreach vtiTowti = AllWidenableIntVectors in { // 13. Vector Floating-Point Instructions // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions -defm : VPatBinaryFPVL_VV_VF_RM; -defm : VPatBinaryFPVL_VV_VF_RM; -defm : VPatBinaryFPVL_R_VF_RM; +defm : VPatBinaryFPVL_VV_VF_RM; +defm : VPatBinaryFPVL_VV_VF_RM; +defm : VPatBinaryFPVL_R_VF_RM; // 13.3. Vector Widening Floating-Point Add/Subtract Instructions defm : VPatBinaryFPWVL_VV_VF_WV_WF_RM; // 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions -defm : VPatBinaryFPVL_VV_VF_RM; +defm : VPatBinaryFPVL_VV_VF_RM; defm : VPatBinaryFPVL_VV_VF_RM; defm : VPatBinaryFPVL_R_VF_RM; @@ -2321,7 +2335,8 @@ defm : VPatFPMulAddVL_VV_VF_RM; defm : VPatFPMulAddVL_VV_VF_RM; // 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions -defm : VPatWidenFPMulAccVL_VV_VF_RM; +defm : VPatWidenFPMulAccVL_VV_VF_RM; defm : VPatWidenFPMulAccVL_VV_VF_RM; defm : VPatWidenFPMulAccVL_VV_VF_RM; defm : VPatWidenFPMulAccVL_VV_VF_RM; @@ -2423,6 +2438,66 @@ foreach vti = AllFloatVectors in { } } +foreach vti = AllBF16Vectors in { + let Predicates = GetVTypePredicates.Predicates in { + // 13.12. Vector Floating-Point Sign-Injection Instructions + def : Pat<(riscv_fabs_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast("PseudoVFSGNJX"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX #"_E"#vti.SEW#"_MASK") + (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, + vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TA_MA)>; + // Handle fneg with VFSGNJN using the same input for both operands. + def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast("PseudoVFSGNJN"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX#"_E"#vti.SEW #"_MASK") + (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, + vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TA_MA)>; + + def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), + (vti.Vector vti.RegClass:$rs2), + vti.RegClass:$passthru, + (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast("PseudoVFSGNJ"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK") + vti.RegClass:$passthru, vti.RegClass:$rs1, + vti.RegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TAIL_AGNOSTIC)>; + + def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), + (riscv_fneg_vl vti.RegClass:$rs2, + (vti.Mask true_mask), + VLOpFrag), + srcvalue, + (vti.Mask true_mask), + VLOpFrag), + (!cast("PseudoVFSGNJN"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), + vti.RegClass:$rs1, vti.RegClass:$rs2, GPR:$vl, vti.Log2SEW, TA_MA)>; + + def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), + (SplatFPOp vti.ScalarRegClass:$rs2), + vti.RegClass:$passthru, + (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast("PseudoVFSGNJ"# + !if(!eq(vti.Scalar, bf16), "_ALT", "")# + "_V"#vti.ScalarSuffix#"_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK") + vti.RegClass:$passthru, vti.RegClass:$rs1, + vti.ScalarRegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TAIL_AGNOSTIC)>; + } +} + // Floating-point vselects: // 11.15. Vector Integer Merge Instructions // 13.15. Vector Floating-Point Merge Instruction @@ -2476,7 +2551,7 @@ foreach fvti = AllFloatVectors in { } } -foreach fvti = AllFloatVectors in { +foreach fvti = AllFloatAndBF16Vectors in { defvar ivti = GetIntVTypeInfo.Vti; let Predicates = GetVTypePredicates.Predicates in { // 13.16. Vector Floating-Point Move Instruction @@ -2492,11 +2567,13 @@ foreach fvti = AllFloatVectors in { } } -foreach fvti = AllFloatVectors in { +foreach fvti = AllFloatAndBF16Vectors in { let Predicates = GetVTypePredicates.Predicates in { def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl fvti.Vector:$passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), VLOpFrag)), - (!cast("PseudoVFMV_V_" # fvti.ScalarSuffix # "_" # + (!cast("PseudoVFMV_V" # + !if(!eq(fvti.Scalar, bf16), "_ALT_", "_") # + fvti.ScalarSuffix # "_" # fvti.LMul.MX) $passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), GPR:$vl, fvti.Log2SEW, TU_MU)>; @@ -2526,20 +2603,37 @@ defm : VPatWConvertFP2IVL_V; defm : VPatWConvertI2FPVL_V; -foreach fvtiToFWti = AllWidenableFloatVectors in { +foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; - // Define vfwcvt.f.f.v for f16 when Zvfhmin is enable. - let Predicates = !listconcat(GetVTypeMinimalPredicates.Predicates, - GetVTypeMinimalPredicates.Predicates) in + // Define vfwcvt.f.f.v for f16 when Zvfhmin is enabled. + // Define vfwcvtbf16.f.f.v for bf16 when Zvfbfmin is enabled. + let Predicates = !listconcat(GetVTypeMinimalPredicates.Predicates, + !if(!eq(fvti.Scalar, bf16), + [HasStdExtZvfbfmin], + GetVTypeMinimalPredicates.Predicates)) in { def : Pat<(fwti.Vector (any_riscv_fpextend_vl (fvti.Vector fvti.RegClass:$rs1), (fvti.Mask VMV0:$vm), VLOpFrag)), - (!cast("PseudoVFWCVT_F_F_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") + (!cast("PseudoVFWCVT"# + !if(!eq(fvti.Scalar, bf16), "BF16", "")# + "_F_F_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1, (fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW, TA_MA)>; + + // Define vfwcvt.f.f.v for bf16 when Zvfbfa is enabled. + if !eq(fvti.Scalar, bf16) then + let Predicates = [HasVInstructionsBF16] in + def : Pat<(fwti.Vector (any_riscv_fpextend_vl + (fvti.Vector fvti.RegClass:$rs1), + (fvti.Mask VMV0:$vm), + VLOpFrag)), + (!cast("PseudoVFWCVT_F_F_ALT_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") + (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1, + (fvti.Mask VMV0:$vm), + GPR:$vl, fvti.Log2SEW, TA_MA)>; } // 13.19 Narrowing Floating-Point/Integer Type-Convert Instructions @@ -2555,16 +2649,21 @@ defm : VPatNConvertI2FPVL_W_RM; defm : VPatNConvertI2FP_RM_VL_W; defm : VPatNConvertI2FP_RM_VL_W; -foreach fvtiToFWti = AllWidenableFloatVectors in { +foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; - // Define vfncvt.f.f.w for f16 when Zvfhmin is enable. - let Predicates = !listconcat(GetVTypeMinimalPredicates.Predicates, - GetVTypeMinimalPredicates.Predicates) in { + // Define vfncvt.f.f.w for f16 when Zvfhmin is enabled. + // Define vfncvtbf16.f.f.w for bf16 when Zvfbfmin is enabled. + let Predicates = !listconcat(GetVTypeMinimalPredicates.Predicates, + !if(!eq(fvti.Scalar, bf16), + [HasStdExtZvfbfmin], + GetVTypeMinimalPredicates.Predicates)) in def : Pat<(fvti.Vector (any_riscv_fpround_vl (fwti.Vector fwti.RegClass:$rs1), (fwti.Mask VMV0:$vm), VLOpFrag)), - (!cast("PseudoVFNCVT_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") + (!cast("PseudoVFNCVT"# + !if(!eq(fvti.Scalar, bf16), "BF16", "")# + "_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, (fwti.Mask VMV0:$vm), // Value to indicate no rounding mode change in @@ -2581,6 +2680,20 @@ foreach fvtiToFWti = AllWidenableFloatVectors in { (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, (fwti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW, TA_MA)>; } + + // Define vfncvt.f.f.w for bf16 when Zvfbfa is enabled. + if !eq(fvti.Scalar, bf16) then + let Predicates = [HasVInstructionsBF16] in + def : Pat<(fvti.Vector (any_riscv_fpround_vl + (fwti.Vector fwti.RegClass:$rs1), + (fwti.Mask VMV0:$vm), VLOpFrag)), + (!cast("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") + (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, + (fwti.Mask VMV0:$vm), + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + FRM_DYN, + GPR:$vl, fvti.Log2SEW, TA_MA)>; } // 14. Vector Reduction Operations @@ -2751,7 +2864,7 @@ foreach vti = AllIntegerVectors in { } // 16.2. Floating-Point Scalar Move Instructions -foreach vti = NoGroupFloatVectors in { +foreach vti = !listconcat(NoGroupFloatVectors, NoGroupBF16Vectors) in { let Predicates = GetVTypePredicates.Predicates in { def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), (vti.Scalar (fpimm0)), @@ -2764,7 +2877,8 @@ foreach vti = NoGroupFloatVectors in { def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), vti.ScalarRegClass:$rs1, VLOpFrag)), - (!cast("PseudoVFMV_S_"#vti.ScalarSuffix) + (!cast("PseudoVFMV_S_"#vti.ScalarSuffix# + !if(!eq(vti.Scalar, bf16), "_ALT", "")) vti.RegClass:$passthru, (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>; } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td index e24e4a33288f7..866e831fdcd94 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td @@ -406,47 +406,11 @@ let Predicates = [HasStdExtZvfbfmin] in { "PseudoVFWCVTBF16_F_F", isSEWAware=1>; defm : VPatConversionVF_WF_BF_RM<"int_riscv_vfncvtbf16_f_f_w", "PseudoVFNCVTBF16_F_F", isSEWAware=1>; - - foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { - defvar fvti = fvtiToFWti.Vti; - defvar fwti = fvtiToFWti.Wti; - def : Pat<(fwti.Vector (any_riscv_fpextend_vl - (fvti.Vector fvti.RegClass:$rs1), - (fvti.Mask VMV0:$vm), - VLOpFrag)), - (!cast("PseudoVFWCVTBF16_F_F_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") - (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1, - (fvti.Mask VMV0:$vm), - GPR:$vl, fvti.Log2SEW, TA_MA)>; - - def : Pat<(fvti.Vector (any_riscv_fpround_vl - (fwti.Vector fwti.RegClass:$rs1), - (fwti.Mask VMV0:$vm), VLOpFrag)), - (!cast("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") - (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, - (fwti.Mask VMV0:$vm), - // Value to indicate no rounding mode change in - // RISCVInsertReadWriteCSR - FRM_DYN, - GPR:$vl, fvti.Log2SEW, TA_MA)>; - def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))), - (!cast("PseudoVFNCVTBF16_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW) - (fvti.Vector (IMPLICIT_DEF)), - fwti.RegClass:$rs1, - // Value to indicate no rounding mode change in - // RISCVInsertReadWriteCSR - FRM_DYN, - fvti.AVL, fvti.Log2SEW, TA_MA)>; - } } let Predicates = [HasStdExtZvfbfwma] in { defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwmaccbf16", "PseudoVFWMACCBF16", AllWidenableBF16ToFloatVectors, isSEWAware=1>; - defm : VPatWidenFPMulAccVL_VV_VF_RM; - defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16", - AllWidenableBF16ToFloatVectors>; } multiclass VPatConversionVI_VF_BF16 { @@ -614,191 +578,4 @@ defm : VPatConversionVF_WF_BF16<"int_riscv_vfncvt_rod_f_f_w", "PseudoVFNCVT_ROD_ isSEWAware=1>; defm : VPatBinaryV_VX<"int_riscv_vfslide1up", "PseudoVFSLIDE1UP_ALT", AllBF16Vectors>; defm : VPatBinaryV_VX<"int_riscv_vfslide1down", "PseudoVFSLIDE1DOWN_ALT", AllBF16Vectors>; - -foreach fvti = AllBF16Vectors in { - defvar ivti = GetIntVTypeInfo.Vti; - let Predicates = GetVTypePredicates.Predicates in { - // 13.16. Vector Floating-Point Move Instruction - // If we're splatting fpimm0, use vmv.v.x vd, x0. - def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl - fvti.Vector:$passthru, (fvti.Scalar (fpimm0)), VLOpFrag)), - (!cast("PseudoVMV_V_I_"#fvti.LMul.MX) - $passthru, 0, GPR:$vl, fvti.Log2SEW, TU_MU)>; - def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl - fvti.Vector:$passthru, (fvti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))), VLOpFrag)), - (!cast("PseudoVMV_V_X_"#fvti.LMul.MX) - $passthru, GPR:$imm, GPR:$vl, fvti.Log2SEW, TU_MU)>; - } - - let Predicates = GetVTypePredicates.Predicates in { - def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl - fvti.Vector:$passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), VLOpFrag)), - (!cast("PseudoVFMV_V_ALT_" # fvti.ScalarSuffix # "_" # - fvti.LMul.MX) - $passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), - GPR:$vl, fvti.Log2SEW, TU_MU)>; - } -} - -foreach vti = NoGroupBF16Vectors in { - let Predicates = GetVTypePredicates.Predicates in { - def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), - (vti.Scalar (fpimm0)), - VLOpFrag)), - (PseudoVMV_S_X $passthru, (XLenVT X0), GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), - (vti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))), - VLOpFrag)), - (PseudoVMV_S_X $passthru, GPR:$imm, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), - vti.ScalarRegClass:$rs1, - VLOpFrag)), - (!cast("PseudoVFMV_S_"#vti.ScalarSuffix#"_ALT") - vti.RegClass:$passthru, - (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>; - } - - defvar vfmv_f_s_inst = !cast(!strconcat("PseudoVFMV_", - vti.ScalarSuffix, - "_S_ALT")); - // Only pattern-match extract-element operations where the index is 0. Any - // other index will have been custom-lowered to slide the vector correctly - // into place. - let Predicates = GetVTypePredicates.Predicates in - def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), 0)), - (vfmv_f_s_inst vti.RegClass:$rs2, vti.Log2SEW)>; -} - -let Predicates = [HasStdExtZvfbfa] in { - foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { - defvar fvti = fvtiToFWti.Vti; - defvar fwti = fvtiToFWti.Wti; - def : Pat<(fwti.Vector (any_riscv_fpextend_vl - (fvti.Vector fvti.RegClass:$rs1), - (fvti.Mask VMV0:$vm), - VLOpFrag)), - (!cast("PseudoVFWCVT_F_F_ALT_V_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") - (fwti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs1, - (fvti.Mask VMV0:$vm), - GPR:$vl, fvti.Log2SEW, TA_MA)>; - - def : Pat<(fvti.Vector (any_riscv_fpround_vl - (fwti.Vector fwti.RegClass:$rs1), - (fwti.Mask VMV0:$vm), VLOpFrag)), - (!cast("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW#"_MASK") - (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1, - (fwti.Mask VMV0:$vm), - // Value to indicate no rounding mode change in - // RISCVInsertReadWriteCSR - FRM_DYN, - GPR:$vl, fvti.Log2SEW, TA_MA)>; - def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))), - (!cast("PseudoVFNCVT_F_F_ALT_W_"#fvti.LMul.MX#"_E"#fvti.SEW) - (fvti.Vector (IMPLICIT_DEF)), - fwti.RegClass:$rs1, - // Value to indicate no rounding mode change in - // RISCVInsertReadWriteCSR - FRM_DYN, - fvti.AVL, fvti.Log2SEW, TA_MA)>; - } - - foreach vti = AllBF16Vectors in { - // 13.12. Vector Floating-Point Sign-Injection Instructions - def : Pat<(fabs (vti.Vector vti.RegClass:$rs)), - (!cast("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) - (vti.Vector (IMPLICIT_DEF)), - vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>; - // Handle fneg with VFSGNJN using the same input for both operands. - def : Pat<(fneg (vti.Vector vti.RegClass:$rs)), - (!cast("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) - (vti.Vector (IMPLICIT_DEF)), - vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>; - - def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), - (vti.Vector vti.RegClass:$rs2))), - (!cast("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) - (vti.Vector (IMPLICIT_DEF)), - vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; - def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), - (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))), - (!cast("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) - (vti.Vector (IMPLICIT_DEF)), - vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; - - def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), - (vti.Vector (fneg vti.RegClass:$rs2)))), - (!cast("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) - (vti.Vector (IMPLICIT_DEF)), - vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; - def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), - (vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))), - (!cast("PseudoVFSGNJN_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) - (vti.Vector (IMPLICIT_DEF)), - vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; - - // 13.12. Vector Floating-Point Sign-Injection Instructions - def : Pat<(riscv_fabs_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm), - VLOpFrag), - (!cast("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX #"_E"#vti.SEW#"_MASK") - (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, - vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, - TA_MA)>; - // Handle fneg with VFSGNJN using the same input for both operands. - def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm), - VLOpFrag), - (!cast("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW #"_MASK") - (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, - vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, - TA_MA)>; - - def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), - (vti.Vector vti.RegClass:$rs2), - vti.RegClass:$passthru, - (vti.Mask VMV0:$vm), - VLOpFrag), - (!cast("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK") - vti.RegClass:$passthru, vti.RegClass:$rs1, - vti.RegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, - TAIL_AGNOSTIC)>; - - def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), - (riscv_fneg_vl vti.RegClass:$rs2, - (vti.Mask true_mask), - VLOpFrag), - srcvalue, - (vti.Mask true_mask), - VLOpFrag), - (!cast("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) - (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, - vti.RegClass:$rs2, GPR:$vl, vti.Log2SEW, TA_MA)>; - - def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), - (SplatFPOp vti.ScalarRegClass:$rs2), - vti.RegClass:$passthru, - (vti.Mask VMV0:$vm), - VLOpFrag), - (!cast("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK") - vti.RegClass:$passthru, vti.RegClass:$rs1, - vti.ScalarRegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, - TAIL_AGNOSTIC)>; - } - } - - defm : VPatBinaryFPSDNode_VV_VF_RM; - defm : VPatBinaryFPSDNode_VV_VF_RM; - defm : VPatBinaryFPSDNode_VV_VF_RM; - defm : VPatBinaryFPSDNode_R_VF_RM; - - defm : VPatBinaryFPVL_VV_VF_RM; - defm : VPatBinaryFPVL_VV_VF_RM; - defm : VPatBinaryFPVL_VV_VF_RM; - defm : VPatBinaryFPVL_R_VF_RM; } // Predicates = [HasStdExtZvfbfa]