
[X86][BF16] Improve vectorization of BF16 #88486

Merged — 2 commits merged into llvm:main on Apr 13, 2024

Conversation

phoebewang
Contributor

  1. Move expansion to combineFP_EXTEND to help with small vectors;
  2. Combine FP_ROUND to reduce fptrunc then fpextend after promotion;

@llvmbot
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-backend-x86

Author: Phoebe Wang (phoebewang)

Changes
  1. Move expansion to combineFP_EXTEND to help with small vectors;
  2. Combine FP_ROUND to reduce fptrunc then fpextend after promotion;

Full diff: https://github.com/llvm/llvm-project/pull/88486.diff

3 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+28-23)
  • (modified) llvm/test/CodeGen/X86/bfloat.ll (+33-201)
  • (modified) llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll (+4-4)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b7cb4b7dafeb69..f66e6dcf9e8f63 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21433,25 +21433,9 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
     return Res;
   }
 
-  if (!SVT.isVector())
+  if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
     return Op;
 
-  if (SVT.getVectorElementType() == MVT::bf16) {
-    // FIXME: Do we need to support strict FP?
-    assert(!IsStrict && "Strict FP doesn't support BF16");
-    if (VT.getVectorElementType() == MVT::f64) {
-      MVT TmpVT = VT.changeVectorElementType(MVT::f32);
-      return DAG.getNode(ISD::FP_EXTEND, DL, VT,
-                         DAG.getNode(ISD::FP_EXTEND, DL, TmpVT, In));
-    }
-    assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
-    MVT NVT = SVT.changeVectorElementType(MVT::i32);
-    In = DAG.getBitcast(SVT.changeTypeToInteger(), In);
-    In = DAG.getNode(ISD::ZERO_EXTEND, DL, NVT, In);
-    In = DAG.getNode(ISD::SHL, DL, NVT, In, DAG.getConstant(16, DL, NVT));
-    return DAG.getBitcast(VT, In);
-  }
-
   if (SVT.getVectorElementType() == MVT::f16) {
     if (Subtarget.hasFP16() && isTypeLegal(SVT))
       return Op;
@@ -56517,16 +56501,39 @@ static SDValue combineFP16_TO_FP(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+  bool IsStrict = N->isStrictFPOpcode();
+  SDValue Src = N->getOperand(IsStrict ? 1 : 0);
+  EVT SrcVT = Src.getValueType();
+
+  SDLoc dl(N);
+  if (SrcVT.getScalarType() == MVT::bf16) {
+    if (Src.getOpcode() == ISD::FP_ROUND &&
+        Src.getOperand(0).getValueType() == VT)
+      return Src.getOperand(0);
+
+    if (!SrcVT.isVector())
+      return SDValue();
+
+    if (VT.getVectorElementType() == MVT::f64) {
+      MVT TmpVT = VT.getSimpleVT().changeVectorElementType(MVT::f32);
+      return DAG.getNode(ISD::FP_EXTEND, dl, VT,
+                         DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
+    }
+
+    MVT NVT = SrcVT.getSimpleVT().changeVectorElementType(MVT::i32);
+    Src = DAG.getBitcast(SrcVT.changeTypeToInteger(), Src);
+    Src = DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, Src);
+    Src = DAG.getNode(ISD::SHL, dl, NVT, Src, DAG.getConstant(16, dl, NVT));
+    return DAG.getBitcast(VT, Src);
+  }
+
   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
     return SDValue();
 
   if (Subtarget.hasFP16())
     return SDValue();
 
-  bool IsStrict = N->isStrictFPOpcode();
-  EVT VT = N->getValueType(0);
-  SDValue Src = N->getOperand(IsStrict ? 1 : 0);
-  EVT SrcVT = Src.getValueType();
 
   if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
     return SDValue();
@@ -56539,8 +56546,6 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
   if (NumElts == 1 || !isPowerOf2_32(NumElts))
     return SDValue();
 
-  SDLoc dl(N);
-
   // Convert the input to vXi16.
   EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
   Src = DAG.getBitcast(IntVT, Src);
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 8a2109a1c78df9..39d8e2d50c91ea 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -1629,22 +1629,8 @@ define <4 x float> @pr64460_1(<4 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_1:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    pextrw $1, %xmm0, %eax
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm2
-; SSE2-NEXT:    movd %xmm0, %eax
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm1
-; SSE2-NEXT:    pextrw $3, %xmm0, %eax
-; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm2
-; SSE2-NEXT:    movd %xmm0, %eax
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm0
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
 ; SSE2-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
@@ -1666,41 +1652,11 @@ define <8 x float> @pr64460_2(<8 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_2:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movq %xmm0, %rdx
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT:    movq %xmm0, %rcx
-; SSE2-NEXT:    movq %rcx, %rax
-; SSE2-NEXT:    shrq $32, %rax
-; SSE2-NEXT:    movq %rdx, %rsi
-; SSE2-NEXT:    shrq $32, %rsi
-; SSE2-NEXT:    movl %edx, %edi
-; SSE2-NEXT:    andl $-65536, %edi # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    movl %edx, %edi
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm0
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
-; SSE2-NEXT:    shrq $48, %rdx
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm1
-; SSE2-NEXT:    shll $16, %esi
-; SSE2-NEXT:    movd %esi, %xmm2
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    andl $-65536, %edx # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edx, %xmm2
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm1
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
-; SSE2-NEXT:    shrq $48, %rcx
-; SSE2-NEXT:    shll $16, %ecx
-; SSE2-NEXT:    movd %ecx, %xmm2
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm3
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm3 = xmm3[0],xmm2[0],xmm3[1],xmm2[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm3[0]
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    pxor %xmm2, %xmm2
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
+; SSE2-NEXT:    movdqa %xmm2, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: pr64460_2:
@@ -1721,76 +1677,16 @@ define <16 x float> @pr64460_3(<16 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_3:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movq %xmm1, %rdi
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
-; SSE2-NEXT:    movq %xmm1, %rcx
-; SSE2-NEXT:    movq %rcx, %rax
-; SSE2-NEXT:    shrq $32, %rax
-; SSE2-NEXT:    movq %xmm0, %r9
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT:    movq %xmm0, %rsi
-; SSE2-NEXT:    movq %rsi, %rdx
-; SSE2-NEXT:    shrq $32, %rdx
-; SSE2-NEXT:    movq %rdi, %r8
-; SSE2-NEXT:    shrq $32, %r8
-; SSE2-NEXT:    movq %r9, %r10
-; SSE2-NEXT:    shrq $32, %r10
-; SSE2-NEXT:    movl %r9d, %r11d
-; SSE2-NEXT:    andl $-65536, %r11d # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %r11d, %xmm1
-; SSE2-NEXT:    movl %r9d, %r11d
-; SSE2-NEXT:    shll $16, %r11d
-; SSE2-NEXT:    movd %r11d, %xmm0
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
-; SSE2-NEXT:    shrq $48, %r9
-; SSE2-NEXT:    shll $16, %r9d
-; SSE2-NEXT:    movd %r9d, %xmm1
-; SSE2-NEXT:    shll $16, %r10d
-; SSE2-NEXT:    movd %r10d, %xmm2
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; SSE2-NEXT:    movl %edi, %r9d
-; SSE2-NEXT:    andl $-65536, %r9d # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %r9d, %xmm1
-; SSE2-NEXT:    movl %edi, %r9d
-; SSE2-NEXT:    shll $16, %r9d
-; SSE2-NEXT:    movd %r9d, %xmm2
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
-; SSE2-NEXT:    shrq $48, %rdi
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    shll $16, %r8d
-; SSE2-NEXT:    movd %r8d, %xmm3
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; SSE2-NEXT:    movl %esi, %edi
-; SSE2-NEXT:    andl $-65536, %edi # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edi, %xmm3
-; SSE2-NEXT:    movl %esi, %edi
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1]
-; SSE2-NEXT:    shrq $48, %rsi
-; SSE2-NEXT:    shll $16, %esi
-; SSE2-NEXT:    movd %esi, %xmm3
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm4
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm4 = xmm4[0],xmm3[0],xmm4[1],xmm3[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm4[0]
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    andl $-65536, %edx # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %edx, %xmm4
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm3
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm3 = xmm3[0],xmm4[0],xmm3[1],xmm4[1]
-; SSE2-NEXT:    shrq $48, %rcx
-; SSE2-NEXT:    shll $16, %ecx
-; SSE2-NEXT:    movd %ecx, %xmm4
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm5
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm5 = xmm5[0],xmm4[0],xmm5[1],xmm4[1]
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm5[0]
+; SSE2-NEXT:    pxor %xmm3, %xmm3
+; SSE2-NEXT:    pxor %xmm5, %xmm5
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm5 = xmm5[0],xmm0[0],xmm5[1],xmm0[1],xmm5[2],xmm0[2],xmm5[3],xmm0[3]
+; SSE2-NEXT:    pxor %xmm4, %xmm4
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm0[4],xmm4[5],xmm0[5],xmm4[6],xmm0[6],xmm4[7],xmm0[7]
+; SSE2-NEXT:    pxor %xmm2, %xmm2
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm1[4],xmm3[5],xmm1[5],xmm3[6],xmm1[6],xmm3[7],xmm1[7]
+; SSE2-NEXT:    movdqa %xmm5, %xmm0
+; SSE2-NEXT:    movdqa %xmm4, %xmm1
 ; SSE2-NEXT:    retq
 ;
 ; F16-LABEL: pr64460_3:
@@ -1822,47 +1718,17 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
 ;
 ; SSE2-LABEL: pr64460_4:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movq %xmm0, %rsi
-; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
-; SSE2-NEXT:    movq %xmm0, %rdx
-; SSE2-NEXT:    movq %rdx, %rax
-; SSE2-NEXT:    shrq $32, %rax
-; SSE2-NEXT:    movq %rdx, %rcx
-; SSE2-NEXT:    shrq $48, %rcx
-; SSE2-NEXT:    movq %rsi, %rdi
-; SSE2-NEXT:    shrq $32, %rdi
-; SSE2-NEXT:    movq %rsi, %r8
-; SSE2-NEXT:    shrq $48, %r8
-; SSE2-NEXT:    movl %esi, %r9d
-; SSE2-NEXT:    andl $-65536, %r9d # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %r9d, %xmm0
-; SSE2-NEXT:    cvtss2sd %xmm0, %xmm1
-; SSE2-NEXT:    shll $16, %esi
-; SSE2-NEXT:    movd %esi, %xmm0
-; SSE2-NEXT:    cvtss2sd %xmm0, %xmm0
-; SSE2-NEXT:    movlhps {{.*#+}} xmm0 = xmm0[0],xmm1[0]
-; SSE2-NEXT:    shll $16, %r8d
-; SSE2-NEXT:    movd %r8d, %xmm1
-; SSE2-NEXT:    cvtss2sd %xmm1, %xmm2
-; SSE2-NEXT:    shll $16, %edi
-; SSE2-NEXT:    movd %edi, %xmm1
-; SSE2-NEXT:    cvtss2sd %xmm1, %xmm1
-; SSE2-NEXT:    movlhps {{.*#+}} xmm1 = xmm1[0],xmm2[0]
-; SSE2-NEXT:    movl %edx, %esi
-; SSE2-NEXT:    andl $-65536, %esi # imm = 0xFFFF0000
-; SSE2-NEXT:    movd %esi, %xmm2
-; SSE2-NEXT:    cvtss2sd %xmm2, %xmm3
-; SSE2-NEXT:    shll $16, %edx
-; SSE2-NEXT:    movd %edx, %xmm2
-; SSE2-NEXT:    cvtss2sd %xmm2, %xmm2
-; SSE2-NEXT:    movlhps {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; SSE2-NEXT:    shll $16, %ecx
-; SSE2-NEXT:    movd %ecx, %xmm3
-; SSE2-NEXT:    cvtss2sd %xmm3, %xmm4
-; SSE2-NEXT:    shll $16, %eax
-; SSE2-NEXT:    movd %eax, %xmm3
-; SSE2-NEXT:    cvtss2sd %xmm3, %xmm3
-; SSE2-NEXT:    movlhps {{.*#+}} xmm3 = xmm3[0],xmm4[0]
+; SSE2-NEXT:    pxor %xmm3, %xmm3
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; SSE2-NEXT:    cvtps2pd %xmm1, %xmm4
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7]
+; SSE2-NEXT:    cvtps2pd %xmm3, %xmm2
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; SSE2-NEXT:    cvtps2pd %xmm0, %xmm1
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,3,2,3]
+; SSE2-NEXT:    cvtps2pd %xmm0, %xmm3
+; SSE2-NEXT:    movaps %xmm4, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; F16-LABEL: pr64460_4:
@@ -1874,45 +1740,11 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
 ;
 ; AVXNC-LABEL: pr64460_4:
 ; AVXNC:       # %bb.0:
-; AVXNC-NEXT:    vpextrw $3, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm1
-; AVXNC-NEXT:    vcvtss2sd %xmm1, %xmm1, %xmm1
-; AVXNC-NEXT:    vpextrw $2, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm2
-; AVXNC-NEXT:    vcvtss2sd %xmm2, %xmm2, %xmm2
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm1 = xmm2[0],xmm1[0]
-; AVXNC-NEXT:    vpextrw $1, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm2
-; AVXNC-NEXT:    vcvtss2sd %xmm2, %xmm2, %xmm2
-; AVXNC-NEXT:    vmovd %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm3
-; AVXNC-NEXT:    vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm2 = xmm3[0],xmm2[0]
-; AVXNC-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm2
-; AVXNC-NEXT:    vpextrw $7, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm1
-; AVXNC-NEXT:    vcvtss2sd %xmm1, %xmm1, %xmm1
-; AVXNC-NEXT:    vpextrw $6, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm3
-; AVXNC-NEXT:    vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm1 = xmm3[0],xmm1[0]
-; AVXNC-NEXT:    vpextrw $5, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm3
-; AVXNC-NEXT:    vcvtss2sd %xmm3, %xmm3, %xmm3
-; AVXNC-NEXT:    vpextrw $4, %xmm0, %eax
-; AVXNC-NEXT:    shll $16, %eax
-; AVXNC-NEXT:    vmovd %eax, %xmm0
-; AVXNC-NEXT:    vcvtss2sd %xmm0, %xmm0, %xmm0
-; AVXNC-NEXT:    vmovlhps {{.*#+}} xmm0 = xmm0[0],xmm3[0]
-; AVXNC-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm1
-; AVXNC-NEXT:    vmovaps %ymm2, %ymm0
+; AVXNC-NEXT:    vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVXNC-NEXT:    vpslld $16, %ymm0, %ymm1
+; AVXNC-NEXT:    vcvtps2pd %xmm1, %ymm0
+; AVXNC-NEXT:    vextracti128 $1, %ymm1, %xmm1
+; AVXNC-NEXT:    vcvtps2pd %xmm1, %ymm1
 ; AVXNC-NEXT:    retq
   %b = fpext <8 x bfloat> %a to <8 x double>
   ret <8 x double> %b
diff --git a/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll b/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
index eff1937b593436..c079a44bc5efd5 100644
--- a/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
+++ b/llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
@@ -10,11 +10,11 @@ define void @test(<2 x ptr> %ptr) {
 ; CHECK-NEXT:  # %bb.2: # %loop.127.preheader
 ; CHECK-NEXT:    retq
 ; CHECK-NEXT:  .LBB0_1: # %ifmerge.89
-; CHECK-NEXT:    movzwl (%rax), %eax
-; CHECK-NEXT:    shll $16, %eax
-; CHECK-NEXT:    vmovd %eax, %xmm0
-; CHECK-NEXT:    vmulss %xmm0, %xmm0, %xmm0
 ; CHECK-NEXT:    vbroadcastss %xmm0, %xmm0
+; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vpbroadcastw (%rax), %xmm2
+; CHECK-NEXT:    vpunpcklwd {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1],xmm1[2],xmm2[2],xmm1[3],xmm2[3]
+; CHECK-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; CHECK-NEXT:    vmovlps %xmm0, (%rax)
 entry:
   br label %then.13


github-actions bot commented Apr 12, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

1. Move expansion to combineFP_EXTEND to help with small vectors;
2. Combine FP_ROUND to reduce fptrunc then fpextend after promotion;
      MVT TmpVT = VT.getSimpleVT().changeVectorElementType(MVT::f32);
      return DAG.getNode(ISD::FP_EXTEND, dl, VT,
                         DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
    }
Collaborator

Would this be OK to perform on scalars too?

Contributor Author

We have handled it: https://godbolt.org/z/r7cjPx9xY

@RKSimon (Collaborator) left a comment

LGTM - although I think eventually it'd be nice to get bf16->f32 "zext+shift" out of X86 and handle it in generic legalization
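For context, the "zext+shift" lowering works because bf16 is simply the upper 16 bits of an f32. A minimal scalar C++ sketch of the idea (an illustration only, not the lowering code itself):

#include <cstdint>
#include <cstring>

// bf16 shares f32's sign/exponent layout; its 16 bits are the top half of the
// corresponding f32, so extension is zero-extend + shift-left-by-16 + bitcast.
float bf16_bits_to_f32(uint16_t bits) {
  uint32_t widened = uint32_t(bits) << 16; // ZERO_EXTEND + SHL 16 in DAG terms
  float f;
  std::memcpy(&f, &widened, sizeof(f));    // bitcast i32 -> f32
  return f;
}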

@phoebewang (Contributor Author)

> LGTM - although I think eventually it'd be nice to get bf16->f32 "zext+shift" out of X86 and handle it in generic legalization

Thanks @RKSimon, the main point here is to combine "FP_ROUND+FP_EXTEND". It's a bit hard if we lower to "zext+shift" and then combine "FP_ROUND+zext+shift".
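The pattern being combined is an FP_EXTEND whose source is an FP_ROUND back from the same type. In source terms it corresponds roughly to the round-trip below (a sketch, assuming a toolchain where __bf16 supports scalar conversions); the combine returns x directly and thereby skips the intermediate rounding:

// fpext(fptrunc x) with matching types: the combine replaces the whole
// expression with x, eliminating the bf16 round-trip.
float round_trip(float x) {
  __bf16 t = (__bf16)x; // FP_ROUND: float -> bfloat
  return (float)t;      // FP_EXTEND: bfloat -> float
}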

@phoebewang merged commit 3cf8535 into llvm:main on Apr 13, 2024
3 of 4 checks passed
@phoebewang deleted the bf16 branch on April 13, 2024 at 06:51
bazuzi pushed a commit to bazuzi/llvm-project that referenced this pull request Apr 15, 2024
1. Move expansion to combineFP_EXTEND to help with small vectors;
2. Combine FP_ROUND to reduce fptrunc then fpextend after promotion;
  SDLoc dl(N);
  if (SrcVT.getScalarType() == MVT::bf16) {
    if (!IsStrict && Src.getOpcode() == ISD::FP_ROUND &&
        Src.getOperand(0).getValueType() == VT)
@krzysz00 (Contributor) commented on May 1, 2024

Hi, several weeks late, but I don't think this optimization is correct even in the absence of some sort of strictness mode.

That is, extend(round(x)) isn't equal to x to a reasonable level of precision - the code I set this off with explicitly uses

; This is probably vectorized by the time it hits you
%vRef = load float, ptr %p.iter ; in a loop, etc.
%vTrunc = fptrunc float %vRef to bfloat
%vReExt = fpext bfloat %vTrunc to float
store float %vReExt, ptr %p.iter

in order to "lose" the extra precision from an f32 computation that is being used as a reference for a bfloat one.

The high-level structure goes like this:

bfloat[N] cTest;
gpu_run(kernel, ..., cTest);
float[N] cTestExt = bfloat_to_float(cTest);
float[N] cRef;
gpu_run(refKernel, ..., cRef);
cRef = float_to_bfloat(bfloat_to_float(cRef));
test_accuracy(cTestExt, cRef, N, ...);

I'll also note that this transformation isn't, from what I can tell, present for half.

Please revert this conditional specifically or justify it, thanks!

Contributor Author

There are some discussions in 3cf8535.

bfloat is different from half in two ways:

  • bfloat has fewer fraction bits, so by design precision is not expected to the same degree as with other types (even half);
  • half is an IEEE type, while bfloat is not, so we don't necessarily have to follow IEEE semantics for it.
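To make the fraction-bit point concrete, here is a small self-contained C++ sketch of the rounding step (round-to-nearest-even truncation of an f32 to bf16, done by hand on the bit pattern; NaNs are ignored for brevity). With only 7 fraction bits, a value such as 16.3125, which needs 8, lands on 16.25 — the size of discrepancy reported later in this thread:

#include <cstdint>
#include <cstring>
#include <cstdio>

// Truncate f32 -> bf16 (keep the top 16 bits) with round-to-nearest-even.
uint16_t f32_to_bf16_rne(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  uint32_t bias = 0x7FFF + ((u >> 16) & 1); // RNE bias (ties go to even)
  return uint16_t((u + bias) >> 16);
}

// Extend bf16 -> f32 by placing the 16 bits in the top half of an f32.
float bf16_to_f32(uint16_t h) {
  uint32_t u = uint32_t(h) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  float x = 16.3125f;                        // 1.00000101 * 2^4: needs 8 fraction bits
  float y = bf16_to_f32(f32_to_bf16_rne(x)); // bf16 keeps 7, so this rounds to 16.25
  std::printf("%g -> %g\n", x, y);           // prints: 16.3125 -> 16.25
}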

Contributor

Discussion noted. However, I as the programmer want to explicitly perform that truncate-extend behavior in one spot in my input (because I'm testing a bfloat function, whose result has been fpext'ed to float, against a floating-point reference that has had its lower bits masked off).

This rewrite has caused per-element result errors of around 1e-2 (if I remember right, 16.25 vs 16.3125 or the like).

I understand that this intermediate elimination improves performance and is numerically useful a lot of the time, so, given that ... what mechanism would you recommend for forcing this optimization to not fire for a particular pair of round and extend operations?

(I don't want to make things strict at the function level if possible - I want to protect a particular fptrunc / fpext pair from this optimization. Would canonicalize do it, or should I stick in a more interesting noop?)

Contributor Author

> what mechanism would you recommend for forcing this optimization to not fire for a particular pair of round and extend operations?

After #90836 landed, you can use __arithmetic_fence to achieve this, e.g., https://godbolt.org/z/vYr1z3h71
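A rough sketch of that approach (assuming a Clang recent enough to have __arithmetic_fence and scalar __bf16 conversions; see the godbolt link for the exact form the author has in mind):

// Keep an intentional f32 -> bf16 -> f32 round-trip: fencing the re-extended
// value stops the optimizer from folding fpext(fptrunc x) back into x.
float quantize_to_bf16(float x) {
  __bf16 t = (__bf16)x;                // fptrunc: drop the low fraction bits
  return __arithmetic_fence((float)t); // fpext, fenced so the pair survives
}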

