@@ -215,13 +215,17 @@ multiclass VPatBinaryFPSDNode_VV_VF<SDPatternOperator vop, string instruction_na
215215}
216216
217217multiclass VPatBinaryFPSDNode_VV_VF_RM<SDPatternOperator vop, string instruction_name,
218- bit isSEWAware = 0, bit isBF16 = 0> {
219- foreach vti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in {
218+ list<VTypeInfo> vtilist = AllFloatVectors,
219+ bit isSEWAware = 0> {
220+ foreach vti = vtilist in {
220221 let Predicates = GetVTypePredicates<vti>.Predicates in {
221- def : VPatBinarySDNode_VV_RM<vop, instruction_name,
222+ def : VPatBinarySDNode_VV_RM<vop, instruction_name #
223+ !if(!eq(vti.Scalar, bf16), "_ALT", ""),
222224 vti.Vector, vti.Vector, vti.Log2SEW,
223225 vti.LMul, vti.AVL, vti.RegClass, isSEWAware>;
224- def : VPatBinarySDNode_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix,
226+ def : VPatBinarySDNode_VF_RM<vop, instruction_name#
227+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
228+ "_V"#vti.ScalarSuffix,
225229 vti.Vector, vti.Vector, vti.Scalar,
226230 vti.Log2SEW, vti.LMul, vti.AVL, vti.RegClass,
227231 vti.ScalarRegClass, isSEWAware>;
@@ -246,14 +250,17 @@ multiclass VPatBinaryFPSDNode_R_VF<SDPatternOperator vop, string instruction_nam
246250}
247251
248252multiclass VPatBinaryFPSDNode_R_VF_RM<SDPatternOperator vop, string instruction_name,
249- bit isSEWAware = 0, bit isBF16 = 0> {
250- foreach fvti = !if(isBF16, AllBF16Vectors, AllFloatVectors) in
253+ list<VTypeInfo> vtilist = AllFloatVectors,
254+ bit isSEWAware = 0> {
255+ foreach fvti = vtilist in
251256 let Predicates = GetVTypePredicates<fvti>.Predicates in
252257 def : Pat<(fvti.Vector (vop (fvti.Vector (SplatFPOp fvti.Scalar:$rs2)),
253258 (fvti.Vector fvti.RegClass:$rs1))),
254259 (!cast<Instruction>(
255260 !if(isSEWAware,
256- instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW,
261+ instruction_name#
262+ !if(!eq(fvti.Scalar, bf16), "_ALT", "")#
263+ "_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_E"#fvti.SEW,
257264 instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX))
258265 (fvti.Vector (IMPLICIT_DEF)),
259266 fvti.RegClass:$rs1,
@@ -664,19 +671,20 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
664671 defvar vti = vtiToWti.Vti;
665672 defvar wti = vtiToWti.Wti;
666673 defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
667- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
668- GetVTypePredicates<wti>.Predicates,
674+ let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
669675 !if(!eq(vti.Scalar, bf16),
670676 [HasStdExtZvfbfwma],
671- [] )) in {
677+ GetVTypePredicates<vti>.Predicates )) in {
672678 def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
673679 (vti.Vector vti.RegClass:$rs1),
674680 (vti.Mask true_mask), (XLenVT srcvalue))),
675681 (wti.Vector (riscv_fpextend_vl_sameuser
676682 (vti.Vector vti.RegClass:$rs2),
677683 (vti.Mask true_mask), (XLenVT srcvalue))),
678684 (wti.Vector wti.RegClass:$rd)),
679- (!cast<Instruction>(instruction_name#"_VV_"#suffix)
685+ (!cast<Instruction>(instruction_name#
686+ !if(!eq(vti.Scalar, bf16), "BF16", "")#
687+ "_VV_"#suffix)
680688 wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
681689 // Value to indicate no rounding mode change in
682690 // RISCVInsertReadWriteCSR
@@ -688,7 +696,9 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
688696 (vti.Vector vti.RegClass:$rs2),
689697 (vti.Mask true_mask), (XLenVT srcvalue))),
690698 (wti.Vector wti.RegClass:$rd)),
691- (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix)
699+ (!cast<Instruction>(instruction_name#
700+ !if(!eq(vti.Scalar, bf16), "BF16", "")#
701+ "_V"#vti.ScalarSuffix#"_"#suffix)
692702 wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
693703 // Value to indicate no rounding mode change in
694704 // RISCVInsertReadWriteCSR
@@ -1201,16 +1211,20 @@ foreach mti = AllMasks in {
12011211// 13. Vector Floating-Point Instructions
12021212
12031213// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
1204- defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD", isSEWAware=1>;
1205- defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB", isSEWAware=1>;
1206- defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB", isSEWAware=1>;
1214+ defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD", AllFloatAndBF16Vectors,
1215+ isSEWAware=1>;
1216+ defm : VPatBinaryFPSDNode_VV_VF_RM<any_fsub, "PseudoVFSUB", AllFloatAndBF16Vectors,
1217+ isSEWAware=1>;
1218+ defm : VPatBinaryFPSDNode_R_VF_RM<any_fsub, "PseudoVFRSUB", AllFloatAndBF16Vectors,
1219+ isSEWAware=1>;
12071220
12081221// 13.3. Vector Widening Floating-Point Add/Subtract Instructions
12091222defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<fadd, "PseudoVFWADD">;
12101223defm : VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<fsub, "PseudoVFWSUB">;
12111224
12121225// 13.4. Vector Single-Width Floating-Point Multiply/Divide Instructions
1213- defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL", isSEWAware=1>;
1226+ defm : VPatBinaryFPSDNode_VV_VF_RM<any_fmul, "PseudoVFMUL", AllFloatAndBF16Vectors,
1227+ isSEWAware=1>;
12141228defm : VPatBinaryFPSDNode_VV_VF_RM<any_fdiv, "PseudoVFDIV", isSEWAware=1>;
12151229defm : VPatBinaryFPSDNode_R_VF_RM<any_fdiv, "PseudoVFRDIV", isSEWAware=1>;
12161230
@@ -1314,14 +1328,15 @@ foreach fvti = AllFloatVectors in {
13141328
13151329// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
13161330defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACC",
1317- AllWidenableFloatVectors >;
1331+ AllWidenableFloatAndBF16Vectors >;
13181332defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">;
13191333defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">;
13201334defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">;
13211335
1322- foreach vti = AllFloatVectors in {
1336+ foreach vti = AllFloatAndBF16Vectors in {
13231337 let Predicates = GetVTypePredicates<vti>.Predicates in {
13241338 // 13.8. Vector Floating-Point Square-Root Instruction
1339+ if !ne(vti.Scalar, bf16) then
13251340 def : Pat<(any_fsqrt (vti.Vector vti.RegClass:$rs2)),
13261341 (!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX#"_E"#vti.SEW)
13271342 (vti.Vector (IMPLICIT_DEF)),
@@ -1333,34 +1348,46 @@ foreach vti = AllFloatVectors in {
13331348
13341349 // 13.12. Vector Floating-Point Sign-Injection Instructions
13351350 def : Pat<(fabs (vti.Vector vti.RegClass:$rs)),
1336- (!cast<Instruction>("PseudoVFSGNJX_VV_"# vti.LMul.MX#"_E"#vti.SEW)
1351+ (!cast<Instruction>("PseudoVFSGNJX"#
1352+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
1353+ "_VV_"# vti.LMul.MX#"_E"#vti.SEW)
13371354 (vti.Vector (IMPLICIT_DEF)),
13381355 vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
13391356 // Handle fneg with VFSGNJN using the same input for both operands.
13401357 def : Pat<(fneg (vti.Vector vti.RegClass:$rs)),
1341- (!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW)
1358+ (!cast<Instruction>("PseudoVFSGNJN"#
1359+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
1360+ "_VV_"# vti.LMul.MX#"_E"#vti.SEW)
13421361 (vti.Vector (IMPLICIT_DEF)),
13431362 vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
13441363
13451364 def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
13461365 (vti.Vector vti.RegClass:$rs2))),
1347- (!cast<Instruction>("PseudoVFSGNJ_VV_"# vti.LMul.MX#"_E"#vti.SEW)
1366+ (!cast<Instruction>("PseudoVFSGNJ"#
1367+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
1368+ "_VV_"# vti.LMul.MX#"_E"#vti.SEW)
13481369 (vti.Vector (IMPLICIT_DEF)),
13491370 vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
13501371 def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
13511372 (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))),
1352- (!cast<Instruction>("PseudoVFSGNJ_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
1373+ (!cast<Instruction>("PseudoVFSGNJ"#
1374+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
1375+ "_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
13531376 (vti.Vector (IMPLICIT_DEF)),
13541377 vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
13551378
13561379 def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
13571380 (vti.Vector (fneg vti.RegClass:$rs2)))),
1358- (!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX#"_E"#vti.SEW)
1381+ (!cast<Instruction>("PseudoVFSGNJN"#
1382+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
1383+ "_VV_"# vti.LMul.MX#"_E"#vti.SEW)
13591384 (vti.Vector (IMPLICIT_DEF)),
13601385 vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
13611386 def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
13621387 (vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))),
1363- (!cast<Instruction>("PseudoVFSGNJN_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
1388+ (!cast<Instruction>("PseudoVFSGNJN"#
1389+ !if(!eq(vti.Scalar, bf16), "_ALT", "")#
1390+ "_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
13641391 (vti.Vector (IMPLICIT_DEF)),
13651392 vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
13661393 }
@@ -1446,13 +1473,28 @@ defm : VPatNConvertFP2ISDNode_W<any_fp_to_sint, "PseudoVFNCVT_RTZ_X_F_W">;
14461473defm : VPatNConvertFP2ISDNode_W<any_fp_to_uint, "PseudoVFNCVT_RTZ_XU_F_W">;
14471474defm : VPatNConvertI2FPSDNode_W_RM<any_sint_to_fp, "PseudoVFNCVT_F_X_W">;
14481475defm : VPatNConvertI2FPSDNode_W_RM<any_uint_to_fp, "PseudoVFNCVT_F_XU_W">;
1449- foreach fvtiToFWti = AllWidenableFloatVectors in {
1476+ foreach fvtiToFWti = AllWidenableFloatAndBF16Vectors in {
14501477 defvar fvti = fvtiToFWti.Vti;
14511478 defvar fwti = fvtiToFWti.Wti;
1452- let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
1453- GetVTypeMinimalPredicates<fwti>.Predicates) in
1479+ let Predicates = !listconcat(GetVTypeMinimalPredicates<fwti>.Predicates,
1480+ !if(!eq(fvti.Scalar, bf16),
1481+ [HasStdExtZvfbfmin],
1482+ GetVTypeMinimalPredicates<fvti>.Predicates)) in
1483+ def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
1484+ (!cast<Instruction>("PseudoVFNCVT"#
1485+ !if(!eq(fvti.Scalar, bf16), "BF16", "")#
1486+ "_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
1487+ (fvti.Vector (IMPLICIT_DEF)),
1488+ fwti.RegClass:$rs1,
1489+ // Value to indicate no rounding mode change in
1490+ // RISCVInsertReadWriteCSR
1491+ FRM_DYN,
1492+ fvti.AVL, fvti.Log2SEW, TA_MA)>;
1493+ // Define vfncvt.f.f.w for bf16 when Zvfbfa is enabled.
1494+ if !eq(fvti.Scalar, bf16) then
1495+ let Predicates = [HasVInstructionsBF16] in
14541496 def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
1455- (!cast<Instruction>("PseudoVFNCVT_F_F_W_ "#fvti.LMul.MX#"_E"#fvti.SEW)
1497+ (!cast<Instruction>("PseudoVFNCVT_F_F_ALT_W_ "#fvti.LMul.MX#"_E"#fvti.SEW)
14561498 (fvti.Vector (IMPLICIT_DEF)),
14571499 fwti.RegClass:$rs1,
14581500 // Value to indicate no rounding mode change in
@@ -1464,10 +1506,10 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
14641506//===----------------------------------------------------------------------===//
14651507// Vector Element Extracts
14661508//===----------------------------------------------------------------------===//
1467- foreach vti = NoGroupFloatVectors in {
1468- defvar vfmv_f_s_inst = !cast<Instruction>(!strconcat("PseudoVFMV_",
1469- vti.ScalarSuffix,
1470- "_S" ));
1509+ foreach vti = !listconcat( NoGroupFloatVectors, NoGroupBF16Vectors) in {
1510+ defvar vfmv_f_s_inst =
1511+ !cast<Instruction>(!strconcat("PseudoVFMV_", vti.ScalarSuffix,
1512+ "_S", !if(!eq(vti.Scalar, bf16), "_ALT", "") ));
14711513 // Only pattern-match extract-element operations where the index is 0. Any
14721514 // other index will have been custom-lowered to slide the vector correctly
14731515 // into place.
0 commit comments