Skip to content

Commit af0b992

Browse files
committed
[X86] Mark the FMA nodes as commutable so tablegen will auto generate the patterns.
This uses the capability introduced in r312464 to make SDNode patterns commutable on the first two operands. This allows us to remove some of the extra FMA patterns that have to put loads and mask operands in different places to cover all cases. This even includes patterns that were missing to support match a load in the first operand with FMA4. Non-broadcast loads with masking for AVX512. I believe this is causing us to generate some duplicate patterns because tablegen's isomorphism checks don't catch isomorphism between the patterns as written in the td. It only detects isomorphism in the commuted variants it tries to create. The the unmasked 231 and 132 memory forms are isomorphic as written in the td file so we end up keeping both. I think we precommute the 132 pattern to fix this. We also need a follow up patch to go back to the legacy FMA3 instructions and add patterns to the 231 and 132 forms which we currently don't have. llvm-svn: 312469
1 parent 561f0de commit af0b992

File tree

3 files changed

+28
-62
lines changed

3 files changed

+28
-62
lines changed

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6680,14 +6680,6 @@ multiclass avx512_fma3p_213_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
66806680
_.RC:$src1,(_.VT (X86VBroadcast (_.ScalarLdFrag addr:$src3)))), 1, 0>,
66816681
AVX512FMA3Base, EVEX_B;
66826682
}
6683-
6684-
// Additional pattern for folding broadcast nodes in other orders.
6685-
def : Pat<(_.VT (vselect _.KRCWM:$mask,
6686-
(OpNode _.RC:$src1, _.RC:$src2,
6687-
(X86VBroadcast (_.ScalarLdFrag addr:$src3))),
6688-
_.RC:$src1)),
6689-
(!cast<Instruction>(NAME#Suff#_.ZSuffix#mbk) _.RC:$src1,
6690-
_.KRCWM:$mask, _.RC:$src2, addr:$src3)>;
66916683
}
66926684

66936685
multiclass avx512_fma3_213_round<bits<8> opc, string OpcodeStr, SDNode OpNode,
@@ -6724,7 +6716,7 @@ multiclass avx512_fma3p_213_f<bits<8> opc, string OpcodeStr, SDNode OpNode,
67246716
avx512vl_f64_info, "PD">, VEX_W;
67256717
}
67266718

6727-
defm VFMADD213 : avx512_fma3p_213_f<0xA8, "vfmadd213", fma, X86FmaddRnd>;
6719+
defm VFMADD213 : avx512_fma3p_213_f<0xA8, "vfmadd213", X86Fmadd, X86FmaddRnd>;
67286720
defm VFMSUB213 : avx512_fma3p_213_f<0xAA, "vfmsub213", X86Fmsub, X86FmsubRnd>;
67296721
defm VFMADDSUB213 : avx512_fma3p_213_f<0xA6, "vfmaddsub213", X86Fmaddsub, X86FmaddsubRnd>;
67306722
defm VFMSUBADD213 : avx512_fma3p_213_f<0xA7, "vfmsubadd213", X86Fmsubadd, X86FmsubaddRnd>;
@@ -6755,24 +6747,6 @@ multiclass avx512_fma3p_231_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
67556747
(_.VT (X86VBroadcast(_.ScalarLdFrag addr:$src3))),
67566748
_.RC:$src1)), 1, 0>, AVX512FMA3Base, EVEX_B;
67576749
}
6758-
6759-
// Additional patterns for folding broadcast nodes in other orders.
6760-
def : Pat<(_.VT (OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)),
6761-
_.RC:$src2, _.RC:$src1)),
6762-
(!cast<Instruction>(NAME#Suff#_.ZSuffix#mb) _.RC:$src1,
6763-
_.RC:$src2, addr:$src3)>;
6764-
def : Pat<(_.VT (vselect _.KRCWM:$mask,
6765-
(OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)),
6766-
_.RC:$src2, _.RC:$src1),
6767-
_.RC:$src1)),
6768-
(!cast<Instruction>(NAME#Suff#_.ZSuffix#mbk) _.RC:$src1,
6769-
_.KRCWM:$mask, _.RC:$src2, addr:$src3)>;
6770-
def : Pat<(_.VT (vselect _.KRCWM:$mask,
6771-
(OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)),
6772-
_.RC:$src2, _.RC:$src1),
6773-
_.ImmAllZerosV)),
6774-
(!cast<Instruction>(NAME#Suff#_.ZSuffix#mbkz) _.RC:$src1,
6775-
_.KRCWM:$mask, _.RC:$src2, addr:$src3)>;
67766750
}
67776751

67786752
multiclass avx512_fma3_231_round<bits<8> opc, string OpcodeStr, SDNode OpNode,
@@ -6810,7 +6784,7 @@ multiclass avx512_fma3p_231_f<bits<8> opc, string OpcodeStr, SDNode OpNode,
68106784
avx512vl_f64_info, "PD">, VEX_W;
68116785
}
68126786

6813-
defm VFMADD231 : avx512_fma3p_231_f<0xB8, "vfmadd231", fma, X86FmaddRnd>;
6787+
defm VFMADD231 : avx512_fma3p_231_f<0xB8, "vfmadd231", X86Fmadd, X86FmaddRnd>;
68146788
defm VFMSUB231 : avx512_fma3p_231_f<0xBA, "vfmsub231", X86Fmsub, X86FmsubRnd>;
68156789
defm VFMADDSUB231 : avx512_fma3p_231_f<0xB6, "vfmaddsub231", X86Fmaddsub, X86FmaddsubRnd>;
68166790
defm VFMSUBADD231 : avx512_fma3p_231_f<0xB7, "vfmsubadd231", X86Fmsubadd, X86FmsubaddRnd>;
@@ -6840,14 +6814,6 @@ multiclass avx512_fma3p_132_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
68406814
(_.VT (X86VBroadcast(_.ScalarLdFrag addr:$src3))),
68416815
_.RC:$src2)), 1, 0>, AVX512FMA3Base, EVEX_B;
68426816
}
6843-
6844-
// Additional patterns for folding broadcast nodes in other orders.
6845-
def : Pat<(_.VT (vselect _.KRCWM:$mask,
6846-
(OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)),
6847-
_.RC:$src1, _.RC:$src2),
6848-
_.RC:$src1)),
6849-
(!cast<Instruction>(NAME#Suff#_.ZSuffix#mbk) _.RC:$src1,
6850-
_.KRCWM:$mask, _.RC:$src2, addr:$src3)>;
68516817
}
68526818

68536819
multiclass avx512_fma3_132_round<bits<8> opc, string OpcodeStr, SDNode OpNode,
@@ -6885,7 +6851,7 @@ multiclass avx512_fma3p_132_f<bits<8> opc, string OpcodeStr, SDNode OpNode,
68856851
avx512vl_f64_info, "PD">, VEX_W;
68866852
}
68876853

6888-
defm VFMADD132 : avx512_fma3p_132_f<0x98, "vfmadd132", fma, X86FmaddRnd>;
6854+
defm VFMADD132 : avx512_fma3p_132_f<0x98, "vfmadd132", X86Fmadd, X86FmaddRnd>;
68896855
defm VFMSUB132 : avx512_fma3p_132_f<0x9A, "vfmsub132", X86Fmsub, X86FmsubRnd>;
68906856
defm VFMADDSUB132 : avx512_fma3p_132_f<0x96, "vfmaddsub132", X86Fmaddsub, X86FmaddsubRnd>;
68916857
defm VFMSUBADD132 : avx512_fma3p_132_f<0x97, "vfmsubadd132", X86Fmsubadd, X86FmsubaddRnd>;
@@ -6984,7 +6950,7 @@ multiclass avx512_fma3s<bits<8> opc213, bits<8> opc231, bits<8> opc132,
69846950
}
69856951
}
69866952

6987-
defm VFMADD : avx512_fma3s<0xA9, 0xB9, 0x99, "vfmadd", fma, X86FmaddRnds1,
6953+
defm VFMADD : avx512_fma3s<0xA9, 0xB9, 0x99, "vfmadd", X86Fmadd, X86FmaddRnds1,
69886954
X86FmaddRnds3>;
69896955
defm VFMSUB : avx512_fma3s<0xAB, 0xBB, 0x9B, "vfmsub", X86Fmsub, X86FmsubRnds1,
69906956
X86FmsubRnds3>;

llvm/lib/Target/X86/X86InstrFMA.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ multiclass fma3p_forms<bits<8> opc132, bits<8> opc213, bits<8> opc231,
8989
// Fused Multiply-Add
9090
let ExeDomain = SSEPackedSingle in {
9191
defm VFMADD : fma3p_forms<0x98, 0xA8, 0xB8, "vfmadd", "ps", "PS",
92-
loadv4f32, loadv8f32, fma, v4f32, v8f32>;
92+
loadv4f32, loadv8f32, X86Fmadd, v4f32, v8f32>;
9393
defm VFMSUB : fma3p_forms<0x9A, 0xAA, 0xBA, "vfmsub", "ps", "PS",
9494
loadv4f32, loadv8f32, X86Fmsub, v4f32, v8f32>;
9595
defm VFMADDSUB : fma3p_forms<0x96, 0xA6, 0xB6, "vfmaddsub", "ps", "PS",
@@ -102,7 +102,7 @@ let ExeDomain = SSEPackedSingle in {
102102

103103
let ExeDomain = SSEPackedDouble in {
104104
defm VFMADD : fma3p_forms<0x98, 0xA8, 0xB8, "vfmadd", "pd", "PD",
105-
loadv2f64, loadv4f64, fma, v2f64,
105+
loadv2f64, loadv4f64, X86Fmadd, v2f64,
106106
v4f64>, VEX_W;
107107
defm VFMSUB : fma3p_forms<0x9A, 0xAA, 0xBA, "vfmsub", "pd", "PD",
108108
loadv2f64, loadv4f64, X86Fmsub, v2f64,
@@ -271,7 +271,7 @@ multiclass fma3s<bits<8> opc132, bits<8> opc213, bits<8> opc231,
271271
}
272272

273273
defm VFMADD : fma3s<0x99, 0xA9, 0xB9, "vfmadd", int_x86_fma_vfmadd_ss,
274-
int_x86_fma_vfmadd_sd, fma>, VEX_LIG;
274+
int_x86_fma_vfmadd_sd, X86Fmadd>, VEX_LIG;
275275
defm VFMSUB : fma3s<0x9B, 0xAB, 0xBB, "vfmsub", int_x86_fma_vfmsub_ss,
276276
int_x86_fma_vfmsub_sd, X86Fmsub>, VEX_LIG;
277277

@@ -407,7 +407,7 @@ let isCodeGenOnly = 1, ForceDisassemble = 1, hasSideEffects = 0 in {
407407

408408
let ExeDomain = SSEPackedSingle in {
409409
// Scalar Instructions
410-
defm VFMADDSS4 : fma4s<0x6A, "vfmaddss", FR32, f32mem, f32, fma, loadf32>,
410+
defm VFMADDSS4 : fma4s<0x6A, "vfmaddss", FR32, f32mem, f32, X86Fmadd, loadf32>,
411411
fma4s_int<0x6A, "vfmaddss", ssmem, sse_load_f32,
412412
int_x86_fma_vfmadd_ss>;
413413
defm VFMSUBSS4 : fma4s<0x6E, "vfmsubss", FR32, f32mem, f32, X86Fmsub, loadf32>,
@@ -422,7 +422,7 @@ let ExeDomain = SSEPackedSingle in {
422422
fma4s_int<0x7E, "vfnmsubss", ssmem, sse_load_f32,
423423
int_x86_fma_vfnmsub_ss>;
424424
// Packed Instructions
425-
defm VFMADDPS4 : fma4p<0x68, "vfmaddps", fma, v4f32, v8f32,
425+
defm VFMADDPS4 : fma4p<0x68, "vfmaddps", X86Fmadd, v4f32, v8f32,
426426
loadv4f32, loadv8f32>;
427427
defm VFMSUBPS4 : fma4p<0x6C, "vfmsubps", X86Fmsub, v4f32, v8f32,
428428
loadv4f32, loadv8f32>;
@@ -438,7 +438,7 @@ let ExeDomain = SSEPackedSingle in {
438438

439439
let ExeDomain = SSEPackedDouble in {
440440
// Scalar Instructions
441-
defm VFMADDSD4 : fma4s<0x6B, "vfmaddsd", FR64, f64mem, f64, fma, loadf64>,
441+
defm VFMADDSD4 : fma4s<0x6B, "vfmaddsd", FR64, f64mem, f64, X86Fmadd, loadf64>,
442442
fma4s_int<0x6B, "vfmaddsd", sdmem, sse_load_f64,
443443
int_x86_fma_vfmadd_sd>;
444444
defm VFMSUBSD4 : fma4s<0x6F, "vfmsubsd", FR64, f64mem, f64, X86Fmsub, loadf64>,
@@ -453,7 +453,7 @@ let ExeDomain = SSEPackedDouble in {
453453
fma4s_int<0x7F, "vfnmsubsd", sdmem, sse_load_f64,
454454
int_x86_fma_vfnmsub_sd>;
455455
// Packed Instructions
456-
defm VFMADDPD4 : fma4p<0x69, "vfmaddpd", fma, v2f64, v4f64,
456+
defm VFMADDPD4 : fma4p<0x69, "vfmaddpd", X86Fmadd, v2f64, v4f64,
457457
loadv2f64, loadv4f64>;
458458
defm VFMSUBPD4 : fma4p<0x6D, "vfmsubpd", X86Fmsub, v2f64, v4f64,
459459
loadv2f64, loadv4f64>;

llvm/lib/Target/X86/X86InstrFragmentsSIMD.td

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -481,19 +481,19 @@ def X86fsqrtRnds : SDNode<"X86ISD::FSQRTS_RND", SDTFPBinOpRound>;
481481
def X86fgetexpRnd : SDNode<"X86ISD::FGETEXP_RND", SDTFPUnaryOpRound>;
482482
def X86fgetexpRnds : SDNode<"X86ISD::FGETEXPS_RND", SDTFPBinOpRound>;
483483

484-
// No need for FMADD because we use ISD::FMA.
485-
def X86Fnmadd : SDNode<"X86ISD::FNMADD", SDTFPTernaryOp>;
486-
def X86Fmsub : SDNode<"X86ISD::FMSUB", SDTFPTernaryOp>;
487-
def X86Fnmsub : SDNode<"X86ISD::FNMSUB", SDTFPTernaryOp>;
488-
def X86Fmaddsub : SDNode<"X86ISD::FMADDSUB", SDTFPTernaryOp>;
489-
def X86Fmsubadd : SDNode<"X86ISD::FMSUBADD", SDTFPTernaryOp>;
490-
491-
def X86FmaddRnd : SDNode<"X86ISD::FMADD_RND", SDTFmaRound>;
492-
def X86FnmaddRnd : SDNode<"X86ISD::FNMADD_RND", SDTFmaRound>;
493-
def X86FmsubRnd : SDNode<"X86ISD::FMSUB_RND", SDTFmaRound>;
494-
def X86FnmsubRnd : SDNode<"X86ISD::FNMSUB_RND", SDTFmaRound>;
495-
def X86FmaddsubRnd : SDNode<"X86ISD::FMADDSUB_RND", SDTFmaRound>;
496-
def X86FmsubaddRnd : SDNode<"X86ISD::FMSUBADD_RND", SDTFmaRound>;
484+
def X86Fmadd : SDNode<"ISD::FMA", SDTFPTernaryOp, [SDNPCommutative]>;
485+
def X86Fnmadd : SDNode<"X86ISD::FNMADD", SDTFPTernaryOp, [SDNPCommutative]>;
486+
def X86Fmsub : SDNode<"X86ISD::FMSUB", SDTFPTernaryOp, [SDNPCommutative]>;
487+
def X86Fnmsub : SDNode<"X86ISD::FNMSUB", SDTFPTernaryOp, [SDNPCommutative]>;
488+
def X86Fmaddsub : SDNode<"X86ISD::FMADDSUB", SDTFPTernaryOp, [SDNPCommutative]>;
489+
def X86Fmsubadd : SDNode<"X86ISD::FMSUBADD", SDTFPTernaryOp, [SDNPCommutative]>;
490+
491+
def X86FmaddRnd : SDNode<"X86ISD::FMADD_RND", SDTFmaRound, [SDNPCommutative]>;
492+
def X86FnmaddRnd : SDNode<"X86ISD::FNMADD_RND", SDTFmaRound, [SDNPCommutative]>;
493+
def X86FmsubRnd : SDNode<"X86ISD::FMSUB_RND", SDTFmaRound, [SDNPCommutative]>;
494+
def X86FnmsubRnd : SDNode<"X86ISD::FNMSUB_RND", SDTFmaRound, [SDNPCommutative]>;
495+
def X86FmaddsubRnd : SDNode<"X86ISD::FMADDSUB_RND", SDTFmaRound, [SDNPCommutative]>;
496+
def X86FmsubaddRnd : SDNode<"X86ISD::FMSUBADD_RND", SDTFmaRound, [SDNPCommutative]>;
497497

498498
// Scalar FMA intrinsics with passthru bits in operand 1.
499499
def X86FmaddRnds1 : SDNode<"X86ISD::FMADDS1_RND", SDTFmaRound>;
@@ -502,10 +502,10 @@ def X86FmsubRnds1 : SDNode<"X86ISD::FMSUBS1_RND", SDTFmaRound>;
502502
def X86FnmsubRnds1 : SDNode<"X86ISD::FNMSUBS1_RND", SDTFmaRound>;
503503

504504
// Scalar FMA intrinsics with passthru bits in operand 3.
505-
def X86FmaddRnds3 : SDNode<"X86ISD::FMADDS3_RND", SDTFmaRound>;
506-
def X86FnmaddRnds3 : SDNode<"X86ISD::FNMADDS3_RND", SDTFmaRound>;
507-
def X86FmsubRnds3 : SDNode<"X86ISD::FMSUBS3_RND", SDTFmaRound>;
508-
def X86FnmsubRnds3 : SDNode<"X86ISD::FNMSUBS3_RND", SDTFmaRound>;
505+
def X86FmaddRnds3 : SDNode<"X86ISD::FMADDS3_RND", SDTFmaRound, [SDNPCommutative]>;
506+
def X86FnmaddRnds3 : SDNode<"X86ISD::FNMADDS3_RND", SDTFmaRound, [SDNPCommutative]>;
507+
def X86FmsubRnds3 : SDNode<"X86ISD::FMSUBS3_RND", SDTFmaRound, [SDNPCommutative]>;
508+
def X86FnmsubRnds3 : SDNode<"X86ISD::FNMSUBS3_RND", SDTFmaRound, [SDNPCommutative]>;
509509

510510
def SDTIFma : SDTypeProfile<1, 3, [SDTCisInt<0>, SDTCisSameAs<0,1>,
511511
SDTCisSameAs<1,2>, SDTCisSameAs<1,3>]>;

0 commit comments

Comments
 (0)