[X86][BF16] Improve vectorization of BF16 #88486
Conversation
phoebewang commented on Apr 12, 2024
- Move expansion to combineFP_EXTEND to help with small vectors (a scalar sketch of the expansion's bit trick is shown below);
- Combine FP_ROUND to reduce fptrunc followed by fpext after promotion;
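For context, the expansion being moved implements bf16 -> f32 widening as an integer zero-extend followed by a 16-bit left shift, since a bf16 value is exactly the high 16 bits of an f32. A minimal scalar sketch of that trick (illustrative only, not code from this patch; the function name is made up):

#include <cstdint>
#include <cstring>

// Scalar equivalent of the vector zext+shl expansion for bf16 -> f32:
// bf16 carries the top 16 bits of an f32, so widening is lossless.
float bf16_to_f32(uint16_t bits16) {
  uint32_t bits32 = static_cast<uint32_t>(bits16) << 16; // zero-extend, shift
  float f;
  std::memcpy(&f, &bits32, sizeof(f)); // bitcast i32 -> f32
  return f;
}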
@llvm/pr-subscribers-backend-x86
Author: Phoebe Wang (phoebewang)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/88486.diff
3 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b7cb4b7dafeb69..f66e6dcf9e8f63 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21433,25 +21433,9 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
return Res;
}
- if (!SVT.isVector())
+ if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
return Op;
- if (SVT.getVectorElementType() == MVT::bf16) {
- // FIXME: Do we need to support strict FP?
- assert(!IsStrict && "Strict FP doesn't support BF16");
- if (VT.getVectorElementType() == MVT::f64) {
- MVT TmpVT = VT.changeVectorElementType(MVT::f32);
- return DAG.getNode(ISD::FP_EXTEND, DL, VT,
- DAG.getNode(ISD::FP_EXTEND, DL, TmpVT, In));
- }
- assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
- MVT NVT = SVT.changeVectorElementType(MVT::i32);
- In = DAG.getBitcast(SVT.changeTypeToInteger(), In);
- In = DAG.getNode(ISD::ZERO_EXTEND, DL, NVT, In);
- In = DAG.getNode(ISD::SHL, DL, NVT, In, DAG.getConstant(16, DL, NVT));
- return DAG.getBitcast(VT, In);
- }
-
if (SVT.getVectorElementType() == MVT::f16) {
if (Subtarget.hasFP16() && isTypeLegal(SVT))
return Op;
@@ -56517,16 +56501,39 @@ static SDValue combineFP16_TO_FP(SDNode *N, SelectionDAG &DAG,
static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
+ EVT VT = N->getValueType(0);
+ bool IsStrict = N->isStrictFPOpcode();
+ SDValue Src = N->getOperand(IsStrict ? 1 : 0);
+ EVT SrcVT = Src.getValueType();
+
+ SDLoc dl(N);
+ if (SrcVT.getScalarType() == MVT::bf16) {
+ if (Src.getOpcode() == ISD::FP_ROUND &&
+ Src.getOperand(0).getValueType() == VT)
+ return Src.getOperand(0);
+
+ if (!SrcVT.isVector())
+ return SDValue();
+
+ if (VT.getVectorElementType() == MVT::f64) {
+ MVT TmpVT = VT.getSimpleVT().changeVectorElementType(MVT::f32);
+ return DAG.getNode(ISD::FP_EXTEND, dl, VT,
+ DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
+ }
+
+ MVT NVT = SrcVT.getSimpleVT().changeVectorElementType(MVT::i32);
+ Src = DAG.getBitcast(SrcVT.changeTypeToInteger(), Src);
+ Src = DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, Src);
+ Src = DAG.getNode(ISD::SHL, dl, NVT, Src, DAG.getConstant(16, dl, NVT));
+ return DAG.getBitcast(VT, Src);
+ }
+
if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
return SDValue();
if (Subtarget.hasFP16())
return SDValue();
- bool IsStrict = N->isStrictFPOpcode();
- EVT VT = N->getValueType(0);
- SDValue Src = N->getOperand(IsStrict ? 1 : 0);
- EVT SrcVT = Src.getValueType();
if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
return SDValue();
@@ -56539,8 +56546,6 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
if (NumElts == 1 || !isPowerOf2_32(NumElts))
return SDValue();
- SDLoc dl(N);
-
// Convert the input to vXi16.
EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
Src = DAG.getBitcast(IntVT, Src);
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 8a2109a1c78df9..39d8e2d50c91ea 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -1629,22 +1629,8 @@ define <4 x float> @pr64460_1(<4 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_1:
; SSE2: # %bb.0:
-; SSE2-NEXT: pextrw $1, %xmm0, %eax
-; SSE2-NEXT: shll $16, %eax
-; SSE2-NEXT: movd %eax, %xmm2
-; SSE2-NEXT: movd %xmm0, %eax
-; SSE2-NEXT: shll $16, %eax
-; SSE2-NEXT: movd %eax, %xmm1
-; SSE2-NEXT: pextrw $3, %xmm0, %eax
-; SSE2-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
-; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
-; SSE2-NEXT: shll $16, %eax
-; SSE2-NEXT: movd %eax, %xmm2
-; SSE2-NEXT: movd %xmm0, %eax
-; SSE2-NEXT: shll $16, %eax
-; SSE2-NEXT: movd %eax, %xmm0
-; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
-; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE2-NEXT: pxor %xmm1, %xmm1
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
; SSE2-NEXT: movdqa %xmm1, %xmm0
; SSE2-NEXT: retq
;
@@ -1666,41 +1652,11 @@ define <8 x float> @pr64460_2(<8 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_2:
; SSE2: # %bb.0:
-; SSE2-NEXT: movq %xmm0, %rdx
-; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT: movq %xmm0, %rcx
-; SSE2-NEXT: movq %rcx, %rax
-; SSE2-NEXT: shrq $32, %rax
-; SSE2-NEXT: movq %rdx, %rsi
-; SSE2-NEXT: shrq $32, %rsi
-; SSE2-NEXT: movl %edx, %edi
-; SSE2-NEXT: andl $-65536, %edi # imm = 0xFFFF0000
-; SSE2-NEXT: movd %edi, %xmm1
-; SSE2-NEXT: movl %edx, %edi
-; SSE2-NEXT: shll $16, %edi
-; SSE2-NEXT: movd %edi, %xmm0
-; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
-; SSE2-NEXT: shrq $48, %rdx
-; SSE2-NEXT: shll $16, %edx
-; SSE2-NEXT: movd %edx, %xmm1
-; SSE2-NEXT: shll $16, %esi
-; SSE2-NEXT: movd %esi, %xmm2
-; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; SSE2-NEXT: movl %ecx, %edx
-; SSE2-NEXT: andl $-65536, %edx # imm = 0xFFFF0000
-; SSE2-NEXT: movd %edx, %xmm2
-; SSE2-NEXT: movl %ecx, %edx
-; SSE2-NEXT: shll $16, %edx
-; SSE2-NEXT: movd %edx, %xmm1
-; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
-; SSE2-NEXT: shrq $48, %rcx
-; SSE2-NEXT: shll $16, %ecx
-; SSE2-NEXT: movd %ecx, %xmm2
-; SSE2-NEXT: shll $16, %eax
-; SSE2-NEXT: movd %eax, %xmm3
-; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm2[0],xmm3[1],xmm2[1]
-; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm3[0]
+; SSE2-NEXT: pxor %xmm1, %xmm1
+; SSE2-NEXT: pxor %xmm2, %xmm2
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
+; SSE2-NEXT: punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
+; SSE2-NEXT: movdqa %xmm2, %xmm0
; SSE2-NEXT: retq
;
; AVX-LABEL: pr64460_2:
@@ -1721,76 +1677,16 @@ define <16 x float> @pr64460_3(<16 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_3:
; SSE2: # %bb.0:
-; SSE2-NEXT: movq %xmm1, %rdi
-; SSE2-NEXT: punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
-; SSE2-NEXT: movq %xmm1, %rcx
-; SSE2-NEXT: movq %rcx, %rax
-; SSE2-NEXT: shrq $32, %rax
-; SSE2-NEXT: movq %xmm0, %r9
-; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT: movq %xmm0, %rsi
-; SSE2-NEXT: movq %rsi, %rdx
-; SSE2-NEXT: shrq $32, %rdx
-; SSE2-NEXT: movq %rdi, %r8
-; SSE2-NEXT: shrq $32, %r8
-; SSE2-NEXT: movq %r9, %r10
-; SSE2-NEXT: shrq $32, %r10
-; SSE2-NEXT: movl %r9d, %r11d
-; SSE2-NEXT: andl $-65536, %r11d # imm = 0xFFFF0000
-; SSE2-NEXT: movd %r11d, %xmm1
-; SSE2-NEXT: movl %r9d, %r11d
-; SSE2-NEXT: shll $16, %r11d
-; SSE2-NEXT: movd %r11d, %xmm0
-; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
-; SSE2-NEXT: shrq $48, %r9
-; SSE2-NEXT: shll $16, %r9d
-; SSE2-NEXT: movd %r9d, %xmm1
-; SSE2-NEXT: shll $16, %r10d
-; SSE2-NEXT: movd %r10d, %xmm2
-; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; SSE2-NEXT: movl %edi, %r9d
-; SSE2-NEXT: andl $-65536, %r9d # imm = 0xFFFF0000
-; SSE2-NEXT: movd %r9d, %xmm1
-; SSE2-NEXT: movl %edi, %r9d
-; SSE2-NEXT: shll $16, %r9d
-; SSE2-NEXT: movd %r9d, %xmm2
-; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT: shrq $48, %rdi
-; SSE2-NEXT: shll $16, %edi
-; SSE2-NEXT: movd %edi, %xmm1
-; SSE2-NEXT: shll $16, %r8d
-; SSE2-NEXT: movd %r8d, %xmm3
-; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1]
-; SSE2-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; SSE2-NEXT: movl %esi, %edi
-; SSE2-NEXT: andl $-65536, %edi # imm = 0xFFFF0000
-; SSE2-NEXT: movd %edi, %xmm3
-; SSE2-NEXT: movl %esi, %edi
-; SSE2-NEXT: shll $16, %edi
-; SSE2-NEXT: movd %edi, %xmm1
-; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1]
-; SSE2-NEXT: shrq $48, %rsi
-; SSE2-NEXT: shll $16, %esi
-; SSE2-NEXT: movd %esi, %xmm3
-; SSE2-NEXT: shll $16, %edx
-; SSE2-NEXT: movd %edx, %xmm4
-; SSE2-NEXT: punpckldq {{.*#+}} xmm4 = xmm4[0],xmm3[0],xmm4[1],xmm3[1]
-; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm4[0]
-; SSE2-NEXT: movl %ecx, %edx
-; SSE2-NEXT: andl $-65536, %edx # imm = 0xFFFF0000
-; SSE2-NEXT: movd %edx, %xmm4
-; SSE2-NEXT: movl %ecx, %edx
-; SSE2-NEXT: shll $16, %edx
-; SSE2-NEXT: movd %edx, %xmm3
-; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm4[0],xmm3[1],xmm4[1]
-; SSE2-NEXT: shrq $48, %rcx
-; SSE2-NEXT: shll $16, %ecx
-; SSE2-NEXT: movd %ecx, %xmm4
-; SSE2-NEXT: shll $16, %eax
-; SSE2-NEXT: movd %eax, %xmm5
-; SSE2-NEXT: punpckldq {{.*#+}} xmm5 = xmm5[0],xmm4[0],xmm5[1],xmm4[1]
-; SSE2-NEXT: punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm5[0]
+; SSE2-NEXT: pxor %xmm3, %xmm3
+; SSE2-NEXT: pxor %xmm5, %xmm5
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm5 = xmm5[0],xmm0[0],xmm5[1],xmm0[1],xmm5[2],xmm0[2],xmm5[3],xmm0[3]
+; SSE2-NEXT: pxor %xmm4, %xmm4
+; SSE2-NEXT: punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm0[4],xmm4[5],xmm0[5],xmm4[6],xmm0[6],xmm4[7],xmm0[7]
+; SSE2-NEXT: pxor %xmm2, %xmm2
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
+; SSE2-NEXT: punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm1[4],xmm3[5],xmm1[5],xmm3[6],xmm1[6],xmm3[7],xmm1[7]
+; SSE2-NEXT: movdqa %xmm5, %xmm0
+; SSE2-NEXT: movdqa %xmm4, %xmm1
; SSE2-NEXT: retq
;
; F16-LABEL: pr64460_3:
@@ -1822,47 +1718,17 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_4:
; SSE2: # %bb.0:
-; SSE2-NEXT: movq %xmm0, %rsi
-; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT: movq %xmm0, %rdx
-; SSE2-NEXT: movq %rdx, %rax
-; SSE2-NEXT: shrq $32, %rax
-; SSE2-NEXT: movq %rdx, %rcx
-; SSE2-NEXT: shrq $48, %rcx
-; SSE2-NEXT: movq %rsi, %rdi
-; SSE2-NEXT: shrq $32, %rdi
-; SSE2-NEXT: movq %rsi, %r8
-; SSE2-NEXT: shrq $48, %r8
-; SSE2-NEXT: movl %esi, %r9d
-; SSE2-NEXT: andl $-65536, %r9d # imm = 0xFFFF0000
-; SSE2-NEXT: movd %r9d, %xmm0
-; SSE2-NEXT: cvtss2sd %xmm0, %xmm1
-; SSE2-NEXT: shll $16, %esi
-; SSE2-NEXT: movd %esi, %xmm0
-; SSE2-NEXT: cvtss2sd %xmm0, %xmm0
-; SSE2-NEXT: movlhps {{.*#+}} xmm0 = xmm0[0],xmm1[0]
-; SSE2-NEXT: shll $16, %r8d
-; SSE2-NEXT: movd %r8d, %xmm1
-; SSE2-NEXT: cvtss2sd %xmm1, %xmm2
-; SSE2-NEXT: shll $16, %edi
-; SSE2-NEXT: movd %edi, %xmm1
-; SSE2-NEXT: cvtss2sd %xmm1, %xmm1
-; SSE2-NEXT: movlhps {{.*#+}} xmm1 = xmm1[0],xmm2[0]
-; SSE2-NEXT: movl %edx, %esi
-; SSE2-NEXT: andl $-65536, %esi # imm = 0xFFFF0000
-; SSE2-NEXT: movd %esi, %xmm2
-; SSE2-NEXT: cvtss2sd %xmm2, %xmm3
-; SSE2-NEXT: shll $16, %edx
-; SSE2-NEXT: movd %edx, %xmm2
-; SSE2-NEXT: cvtss2sd %xmm2, %xmm2
-; SSE2-NEXT: movlhps {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; SSE2-NEXT: shll $16, %ecx
-; SSE2-NEXT: movd %ecx, %xmm3
-; SSE2-NEXT: cvtss2sd %xmm3, %xmm4
-; SSE2-NEXT: shll $16, %eax
-; SSE2-NEXT: movd %eax, %xmm3
-; SSE2-NEXT: cvtss2sd %xmm3, %xmm3
-; SSE2-NEXT: movlhps {{.*#+}} xmm3 = xmm3[0],xmm4[0]
+; SSE2-NEXT: pxor %xmm3, %xmm3
+; SSE2-NEXT: pxor %xmm1, %xmm1
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; SSE2-NEXT: cvtps2pd %xmm1, %xmm4
+; SSE2-NEXT: punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7]
+; SSE2-NEXT: cvtps2pd %xmm3, %xmm2
+; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; SSE2-NEXT: cvtps2pd %xmm0, %xmm1
+; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm3[2,3,2,3]
+; SSE2-NEXT: cvtps2pd %xmm0, %xmm3
+; SSE2-NEXT: movaps %xmm4, %xmm0
; SSE2-NEXT: retq
;
; F16-LABEL: pr64460_4:
@@ -1874,45 +1740,11 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
;
; AVXNC-LABEL: pr64460_4:
; AVXNC: # %bb.0:
-; AVXNC-NEXT: vpextrw $3, %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm1
-; AVXNC-NEXT: vcvtss2sd %xmm1, %xmm1, %xmm1
-; AVXNC-NEXT: vpextrw $2, %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm2
-; AVXNC-NEXT: vcvtss2sd %xmm2, %xmm2, %xmm2
-; AVXNC-NEXT: vmovlhps {{.*#+}} xmm1 = xmm2[0],xmm1[0]
-; AVXNC-NEXT: vpextrw $1, %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm2
-; AVXNC-NEXT: vcvtss2sd %xmm2, %xmm2, %xmm2
-; AVXNC-NEXT: vmovd %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm3
-; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT: vmovlhps {{.*#+}} xmm2 = xmm3[0],xmm2[0]
-; AVXNC-NEXT: vinsertf128 $1, %xmm1, %ymm2, %ymm2
-; AVXNC-NEXT: vpextrw $7, %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm1
-; AVXNC-NEXT: vcvtss2sd %xmm1, %xmm1, %xmm1
-; AVXNC-NEXT: vpextrw $6, %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm3
-; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT: vmovlhps {{.*#+}} xmm1 = xmm3[0],xmm1[0]
-; AVXNC-NEXT: vpextrw $5, %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm3
-; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT: vpextrw $4, %xmm0, %eax
-; AVXNC-NEXT: shll $16, %eax
-; AVXNC-NEXT: vmovd %eax, %xmm0
-; AVXNC-NEXT: vcvtss2sd %xmm0, %xmm0, %xmm0
-; AVXNC-NEXT: vmovlhps {{.*#+}} xmm0 = xmm0[0],xmm3[0]
-; AVXNC-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm1
-; AVXNC-NEXT: vmovaps %ymm2, %ymm0
+; AVXNC-NEXT: vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVXNC-NEXT: vpslld $16, %ymm0, %ymm1
+; AVXNC-NEXT: vcvtps2pd %xmm1, %ymm0
+; AVXNC-NEXT: vextracti128 $1, %ymm1, %xmm1
+; AVXNC-NEXT: vcvtps2pd %xmm1, %ymm1
; AVXNC-NEXT: retq
%b = fpext <8 x bfloat> %a to <8 x double>
ret <8 x double> %b
diff --git a/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll b/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
index eff1937b593436..c079a44bc5efd5 100644
--- a/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
+++ b/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
@@ -10,11 +10,11 @@ define void @test(<2 x ptr> %ptr) {
; CHECK-NEXT: # %bb.2: # %loop.127.preheader
; CHECK-NEXT: retq
; CHECK-NEXT: .LBB0_1: # %ifmerge.89
-; CHECK-NEXT: movzwl (%rax), %eax
-; CHECK-NEXT: shll $16, %eax
-; CHECK-NEXT: vmovd %eax, %xmm0
-; CHECK-NEXT: vmulss %xmm0, %xmm0, %xmm0
; CHECK-NEXT: vbroadcastss %xmm0, %xmm0
+; CHECK-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; CHECK-NEXT: vpbroadcastw (%rax), %xmm2
+; CHECK-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1],xmm1[2],xmm2[2],xmm1[3],xmm2[3]
+; CHECK-NEXT: vmulps %xmm1, %xmm0, %xmm0
; CHECK-NEXT: vmovlps %xmm0, (%rax)
entry:
br label %then.13
✅ With the latest revision this PR passed the C/C++ code formatter.
      MVT TmpVT = VT.getSimpleVT().changeVectorElementType(MVT::f32);
      return DAG.getNode(ISD::FP_EXTEND, dl, VT,
                         DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
    }
Would this be OK to perform on scalars too?
We have handled it https://godbolt.org/z/r7cjPx9xY
LGTM - although I think eventually it'd be nice to get bf16->f32 "zext+shift" out of X86 and handle it in generic legalization
Thanks @RKSimon, the main point here is to combine "FP_ROUND+FP_EXTEND". It's a bit hard if we lower to "zext+shift" and then combine "FP_ROUND+zext+shift".
    SDLoc dl(N);
    if (SrcVT.getScalarType() == MVT::bf16) {
      if (!IsStrict && Src.getOpcode() == ISD::FP_ROUND &&
          Src.getOperand(0).getValueType() == VT)
Hi, several weeks late, but I don't think this optimization is correct even in the absence of some sort of strictness mode.
That is, extend(round(x)) isn't meaningfully equal to x to a reasonable level of precision - the code I set this off with is explicitly using
; This is probably vectorized by the time it hits you
%vRef = load float, ptr %p.iter ; in a loop, etc.
%vTrunc = fptrunc float %vRef to bfloat
%vReExt = fpext bfloat %vTrunc to float
store float %vReExt, ptr %p.iter
in order to "lose" extra precision from an f32 computation being used as a reference for a bfloat one.
(The high-level structure goes like this
bfloat[N] cTest;
gpu_run(kernel, ..., cTest);
float[N] cTestExt = bfloat_to_float(cTest);
float[N] cRef;
gpu_run(refKernel, ..., cRef);
cRef = float_to_bfloat(bfloat_to_float(cRef));
test_accuracy(cTestExt, cRef, N, ...);
I'll also note that this transformation isn't, from what I can tell, present for half
Please revert this conditional specifically or justify it, thanks!
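To see the size of error such a fold can introduce, here is a small standalone illustration (not from this PR; the bf16 helpers are hand-rolled for the example): 16.3125 stored through bf16 comes back as 16.25, the same per-element difference mentioned later in this thread.

#include <cstdint>
#include <cstring>
#include <iostream>

// Hand-rolled bf16 helpers for illustration only: round-to-nearest-even
// truncation of f32 to bf16 bits, then exact widening back to f32.
uint16_t f32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t bias = 0x7FFF + ((bits >> 16) & 1); // ties round to even
  return static_cast<uint16_t>((bits + bias) >> 16);
}

float bf16_to_f32(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float ref = 16.3125f;
  float masked = bf16_to_f32(f32_to_bf16_rne(ref));
  std::cout << masked << "\n"; // prints 16.25, not 16.3125
}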
There are some discussions in 3cf8535. bfloat is different from half in two ways:
1. bfloat has fewer fraction bits, so precision should not be a concern like other types (even half) by design;
2. half is an IEEE type, while bfloat is not. We don't necessarily follow it;
Discussion noted. However, I as the programmer want to explicitly perform that truncate-extend behavior in one spot in my input (because I'm testing a bfloat function whose result has been fpext'ed against a floating-point version that's had its lower bits masked off).
This rewrite has caused per-element result errors around 1e-2 (if I remember right 16.25 vs 16.3125 or the like)
I understand that this intermediate elimination improves performance and is numerically useful a lot of the time, so, given that ... what mechanism would you recommend for forcing this optimization to not fire for a particular pair of round and extend operations?
(I don't want to make things strict at the function level if possible - I want to protect a particular fptrunc / fpext pair from this optimization. Would canonicalize do it, or should I stick in a more interesting noop?)
what mechanism would you recommend for forcing this optimization to not fire for a particular pair of round and extend operations?
After #90836 landed, you can use __arithmetic_fence to achieve this, e.g., https://godbolt.org/z/vYr1z3h71
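A minimal usage sketch of that workaround (illustrative; it assumes a Clang recent enough that __arithmetic_fence accepts a __bf16 operand, as the godbolt link suggests, and the function name is made up):

// Sketch only, not from the PR: fence the truncated value so the
// fptrunc/fpext pair is not folded back into the original float.
float round_through_bf16(float x) {
  __bf16 t = static_cast<__bf16>(x); // fptrunc float -> bfloat
  t = __arithmetic_fence(t);         // keep the rounding observable
  return static_cast<float>(t);      // fpext bfloat -> float
}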