[X86][BF16] Improve vectorization of BF16 #88486

Merged 2 commits on Apr 13, 2024
53 changes: 29 additions & 24 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21433,25 +21433,9 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
return Res;
}

if (!SVT.isVector())
if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
return Op;

if (SVT.getVectorElementType() == MVT::bf16) {
// FIXME: Do we need to support strict FP?
assert(!IsStrict && "Strict FP doesn't support BF16");
if (VT.getVectorElementType() == MVT::f64) {
MVT TmpVT = VT.changeVectorElementType(MVT::f32);
return DAG.getNode(ISD::FP_EXTEND, DL, VT,
DAG.getNode(ISD::FP_EXTEND, DL, TmpVT, In));
}
assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
MVT NVT = SVT.changeVectorElementType(MVT::i32);
In = DAG.getBitcast(SVT.changeTypeToInteger(), In);
In = DAG.getNode(ISD::ZERO_EXTEND, DL, NVT, In);
In = DAG.getNode(ISD::SHL, DL, NVT, In, DAG.getConstant(16, DL, NVT));
return DAG.getBitcast(VT, In);
}

if (SVT.getVectorElementType() == MVT::f16) {
if (Subtarget.hasFP16() && isTypeLegal(SVT))
return Op;
@@ -56517,17 +56501,40 @@ static SDValue combineFP16_TO_FP(SDNode *N, SelectionDAG &DAG,

static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
bool IsStrict = N->isStrictFPOpcode();
SDValue Src = N->getOperand(IsStrict ? 1 : 0);
EVT SrcVT = Src.getValueType();

SDLoc dl(N);
if (SrcVT.getScalarType() == MVT::bf16) {
assert(!IsStrict && "Strict FP doesn't support BF16");
if (Src.getOpcode() == ISD::FP_ROUND &&
@krzysz00 (Contributor) commented on May 1, 2024:

Hi, several weeks late, but I don't think this optimization is correct even in the absence of some sort of strictness mode.

That is, extend(round(x)) isn't equal to x to any reasonable level of precision - the code that triggered this for me is explicitly using

; This is probably vectorized by the time it hits you
%vRef = load float, ptr %p.iter ; in a loop, etc.
%vTrunc = fptrunc float %vRef to bfloat
%vReExt = fpext bfloat %vTrunc to float
store float %vReExt, ptr %p.iter

in order to "lose" extra precision from a f32 computation being used as a reference for a bfloat one.

The high-level structure goes like this:

bfloat[N] cTest;
gpu_run(kernel, ..., cTest);
float[N] cTestExt = bfloat_to_float(cTest);
float[N] cRef;
gpu_run(refKernel, ..., cRef);
cRef = bfloat_to_float(float_to_bfloat(cRef));
test_accuracy(cTestExt, cRef, N, ...);

I'll also note that this transformation isn't, from what I can tell, present for half.

Please revert this conditional specifically or justify it, thanks!
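To make the round-trip's intent concrete, a minimal standalone sketch (the helper name is illustrative; it masks the low bits rather than rounding to nearest-even the way fptrunc does, but both send 16.3125 to 16.25):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep only the top 16 bits of the f32 pattern, i.e. bfloat precision:
// sign, exponent, and 7 fraction bits survive; the rest are dropped.
static float roundTripThroughBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof bits);
  return x;
}

int main() {
  float ref = 16.3125f; // representable in f32 but not in bf16
  // If a compiler folds extend(round(x)) back to x, the second value
  // silently stays 16.3125 instead of the intended 16.25.
  std::printf("%g -> %g\n", ref, roundTripThroughBF16(ref));
}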

The PR author (Contributor) replied:

There is some discussion in 3cf8535.

bfloat differs from half in two ways:

  • bfloat has fewer fraction bits, so by design precision should not be a concern for it the way it is for other types (even half);
  • half is an IEEE type, while bfloat is not, so we don't necessarily follow IEEE semantics for it.
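For a sense of scale (an illustrative calculation, not from the thread): bf16 stores 7 fraction bits and IEEE half stores 10, so near 16.0 one bf16 ulp is 2^(4-7) = 0.125 while one half ulp is 2^(4-10) = 0.015625.

#include <cmath>
#include <cstdio>

// Rounding step (ulp) at x for a format with fracBits fraction bits.
static double ulpAt(double x, int fracBits) {
  int exp;
  std::frexp(x, &exp);                       // x = m * 2^exp, m in [0.5, 1)
  return std::ldexp(1.0, exp - 1 - fracBits);
}

int main() {
  std::printf("bf16 ulp near 16: %g\n", ulpAt(16.3125, 7));  // 0.125
  std::printf("f16  ulp near 16: %g\n", ulpAt(16.3125, 10)); // 0.015625
}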

@krzysz00 (Contributor) replied:

Discussion noted. However, I as the programmer want to explicitly perform that truncate-extend behavior in one spot in my input (because I'm testing a bfloat function whose result has been fpexted, against a floating-point version that's had its lower bits masked off).

This rewrite has caused per-element result errors of around 1e-2 (if I remember right, 16.25 vs 16.3125 or the like).

I understand that this intermediate elimination improves performance and is numerically useful a lot of the time, so, given that, what mechanism would you recommend for preventing this optimization from firing on a particular pair of round and extend operations?

(I don't want to make things strict at the function level if possible - I want to protect a particular fptrunc/fpext pair from this optimization. Would canonicalize do it, or should I stick in a more interesting no-op?)

The PR author replied:

> what mechanism would you recommend for forcing this optimization to not fire for a particular pair of round and extend operations?

After #90836 landed, you can use __arithmetic_fence to achieve this, e.g., https://godbolt.org/z/vYr1z3h71
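A rough sketch of that approach (assuming Clang's __arithmetic_fence builtin accepts __bf16 once #90836 is in; the godbolt link above is the authoritative example):

// Compile with a Clang new enough to have __bf16 and __arithmetic_fence.
static float roundToBF16Precision(float x) {
  __bf16 t = (__bf16)x;        // fptrunc float -> bfloat
  t = __arithmetic_fence(t);   // opaque to the optimizer: blocks the fold
  return (float)t;             // fpext bfloat -> float now survives
}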

Src.getOperand(0).getValueType() == VT)
return Src.getOperand(0);

if (!SrcVT.isVector())
return SDValue();

if (VT.getVectorElementType() == MVT::f64) {
MVT TmpVT = VT.getSimpleVT().changeVectorElementType(MVT::f32);
return DAG.getNode(ISD::FP_EXTEND, dl, VT,
DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
}
A Collaborator commented:

Would this be OK to perform on scalars too?

The PR author replied:

We have handled it: https://godbolt.org/z/r7cjPx9xY

assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
MVT NVT = SrcVT.getSimpleVT().changeVectorElementType(MVT::i32);
Src = DAG.getBitcast(SrcVT.changeTypeToInteger(), Src);
Src = DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, Src);
Src = DAG.getNode(ISD::SHL, dl, NVT, Src, DAG.getConstant(16, dl, NVT));
return DAG.getBitcast(VT, Src);
}

if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
return SDValue();

if (Subtarget.hasFP16())
return SDValue();

bool IsStrict = N->isStrictFPOpcode();
EVT VT = N->getValueType(0);
SDValue Src = N->getOperand(IsStrict ? 1 : 0);
EVT SrcVT = Src.getValueType();

if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
return SDValue();

@@ -56539,8 +56546,6 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
if (NumElts == 1 || !isPowerOf2_32(NumElts))
return SDValue();

SDLoc dl(N);

// Convert the input to vXi16.
EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
Src = DAG.getBitcast(IntVT, Src);
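For reference, a scalar C++ sketch of the bit trick the combine above emits for bf16 sources (names are illustrative): bf16 is exactly the high 16 bits of an f32, so extension is a zero-extend plus a 16-bit left shift, and f64 results take a second, exact fpext through f32.

#include <cstdint>
#include <cstring>

// Model of the DAG sequence: bitcast bf16 -> i16, ZERO_EXTEND to i32,
// SHL by 16, then bitcast to f32. Lossless by construction.
static float extendBF16Bits(uint16_t bf16Bits) {
  uint32_t wide = (uint32_t)bf16Bits << 16;
  float f;
  std::memcpy(&f, &wide, sizeof f);
  return f;
}

// bf16 -> f64 goes through f32 first, matching the f64 path above.
static double extendBF16ToF64(uint16_t bf16Bits) {
  return (double)extendBF16Bits(bf16Bits);
}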
234 changes: 33 additions & 201 deletions llvm/test/CodeGen/X86/bfloat.ll
@@ -1629,22 +1629,8 @@ define <4 x float> @pr64460_1(<4 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_1:
; SSE2: # %bb.0:
; SSE2-NEXT: pextrw $1, %xmm0, %eax
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm2
; SSE2-NEXT: movd %xmm0, %eax
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm1
; SSE2-NEXT: pextrw $3, %xmm0, %eax
; SSE2-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,1,1,1]
; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm2
; SSE2-NEXT: movd %xmm0, %eax
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm0
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
; SSE2-NEXT: pxor %xmm1, %xmm1
; SSE2-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
; SSE2-NEXT: movdqa %xmm1, %xmm0
; SSE2-NEXT: retq
;
@@ -1666,41 +1652,11 @@ define <8 x float> @pr64460_2(<8 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_2:
; SSE2: # %bb.0:
; SSE2-NEXT: movq %xmm0, %rdx
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %rcx
; SSE2-NEXT: movq %rcx, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %rdx, %rsi
; SSE2-NEXT: shrq $32, %rsi
; SSE2-NEXT: movl %edx, %edi
; SSE2-NEXT: andl $-65536, %edi # imm = 0xFFFF0000
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: movl %edx, %edi
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm0
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: shrq $48, %rdx
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm1
; SSE2-NEXT: shll $16, %esi
; SSE2-NEXT: movd %esi, %xmm2
; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: andl $-65536, %edx # imm = 0xFFFF0000
; SSE2-NEXT: movd %edx, %xmm2
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm1
; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
; SSE2-NEXT: shrq $48, %rcx
; SSE2-NEXT: shll $16, %ecx
; SSE2-NEXT: movd %ecx, %xmm2
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm3
; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm2[0],xmm3[1],xmm2[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm3[0]
; SSE2-NEXT: pxor %xmm1, %xmm1
; SSE2-NEXT: pxor %xmm2, %xmm2
; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
; SSE2-NEXT: punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
; SSE2-NEXT: movdqa %xmm2, %xmm0
; SSE2-NEXT: retq
;
; AVX-LABEL: pr64460_2:
@@ -1721,76 +1677,16 @@ define <16 x float> @pr64460_3(<16 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_3:
; SSE2: # %bb.0:
; SSE2-NEXT: movq %xmm1, %rdi
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
; SSE2-NEXT: movq %xmm1, %rcx
; SSE2-NEXT: movq %rcx, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %xmm0, %r9
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %rsi
; SSE2-NEXT: movq %rsi, %rdx
; SSE2-NEXT: shrq $32, %rdx
; SSE2-NEXT: movq %rdi, %r8
; SSE2-NEXT: shrq $32, %r8
; SSE2-NEXT: movq %r9, %r10
; SSE2-NEXT: shrq $32, %r10
; SSE2-NEXT: movl %r9d, %r11d
; SSE2-NEXT: andl $-65536, %r11d # imm = 0xFFFF0000
; SSE2-NEXT: movd %r11d, %xmm1
; SSE2-NEXT: movl %r9d, %r11d
; SSE2-NEXT: shll $16, %r11d
; SSE2-NEXT: movd %r11d, %xmm0
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: shrq $48, %r9
; SSE2-NEXT: shll $16, %r9d
; SSE2-NEXT: movd %r9d, %xmm1
; SSE2-NEXT: shll $16, %r10d
; SSE2-NEXT: movd %r10d, %xmm2
; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
; SSE2-NEXT: movl %edi, %r9d
; SSE2-NEXT: andl $-65536, %r9d # imm = 0xFFFF0000
; SSE2-NEXT: movd %r9d, %xmm1
; SSE2-NEXT: movl %edi, %r9d
; SSE2-NEXT: shll $16, %r9d
; SSE2-NEXT: movd %r9d, %xmm2
; SSE2-NEXT: punpckldq {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1]
; SSE2-NEXT: shrq $48, %rdi
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: shll $16, %r8d
; SSE2-NEXT: movd %r8d, %xmm3
; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
; SSE2-NEXT: movl %esi, %edi
; SSE2-NEXT: andl $-65536, %edi # imm = 0xFFFF0000
; SSE2-NEXT: movd %edi, %xmm3
; SSE2-NEXT: movl %esi, %edi
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: punpckldq {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1]
; SSE2-NEXT: shrq $48, %rsi
; SSE2-NEXT: shll $16, %esi
; SSE2-NEXT: movd %esi, %xmm3
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm4
; SSE2-NEXT: punpckldq {{.*#+}} xmm4 = xmm4[0],xmm3[0],xmm4[1],xmm3[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm4[0]
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: andl $-65536, %edx # imm = 0xFFFF0000
; SSE2-NEXT: movd %edx, %xmm4
; SSE2-NEXT: movl %ecx, %edx
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm3
; SSE2-NEXT: punpckldq {{.*#+}} xmm3 = xmm3[0],xmm4[0],xmm3[1],xmm4[1]
; SSE2-NEXT: shrq $48, %rcx
; SSE2-NEXT: shll $16, %ecx
; SSE2-NEXT: movd %ecx, %xmm4
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm5
; SSE2-NEXT: punpckldq {{.*#+}} xmm5 = xmm5[0],xmm4[0],xmm5[1],xmm4[1]
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm5[0]
; SSE2-NEXT: pxor %xmm3, %xmm3
; SSE2-NEXT: pxor %xmm5, %xmm5
; SSE2-NEXT: punpcklwd {{.*#+}} xmm5 = xmm5[0],xmm0[0],xmm5[1],xmm0[1],xmm5[2],xmm0[2],xmm5[3],xmm0[3]
; SSE2-NEXT: pxor %xmm4, %xmm4
; SSE2-NEXT: punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm0[4],xmm4[5],xmm0[5],xmm4[6],xmm0[6],xmm4[7],xmm0[7]
; SSE2-NEXT: pxor %xmm2, %xmm2
; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
; SSE2-NEXT: punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm1[4],xmm3[5],xmm1[5],xmm3[6],xmm1[6],xmm3[7],xmm1[7]
; SSE2-NEXT: movdqa %xmm5, %xmm0
; SSE2-NEXT: movdqa %xmm4, %xmm1
; SSE2-NEXT: retq
;
; F16-LABEL: pr64460_3:
Expand Down Expand Up @@ -1822,47 +1718,17 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
;
; SSE2-LABEL: pr64460_4:
; SSE2: # %bb.0:
; SSE2-NEXT: movq %xmm0, %rsi
; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %rdx
; SSE2-NEXT: movq %rdx, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %rdx, %rcx
; SSE2-NEXT: shrq $48, %rcx
; SSE2-NEXT: movq %rsi, %rdi
; SSE2-NEXT: shrq $32, %rdi
; SSE2-NEXT: movq %rsi, %r8
; SSE2-NEXT: shrq $48, %r8
; SSE2-NEXT: movl %esi, %r9d
; SSE2-NEXT: andl $-65536, %r9d # imm = 0xFFFF0000
; SSE2-NEXT: movd %r9d, %xmm0
; SSE2-NEXT: cvtss2sd %xmm0, %xmm1
; SSE2-NEXT: shll $16, %esi
; SSE2-NEXT: movd %esi, %xmm0
; SSE2-NEXT: cvtss2sd %xmm0, %xmm0
; SSE2-NEXT: movlhps {{.*#+}} xmm0 = xmm0[0],xmm1[0]
; SSE2-NEXT: shll $16, %r8d
; SSE2-NEXT: movd %r8d, %xmm1
; SSE2-NEXT: cvtss2sd %xmm1, %xmm2
; SSE2-NEXT: shll $16, %edi
; SSE2-NEXT: movd %edi, %xmm1
; SSE2-NEXT: cvtss2sd %xmm1, %xmm1
; SSE2-NEXT: movlhps {{.*#+}} xmm1 = xmm1[0],xmm2[0]
; SSE2-NEXT: movl %edx, %esi
; SSE2-NEXT: andl $-65536, %esi # imm = 0xFFFF0000
; SSE2-NEXT: movd %esi, %xmm2
; SSE2-NEXT: cvtss2sd %xmm2, %xmm3
; SSE2-NEXT: shll $16, %edx
; SSE2-NEXT: movd %edx, %xmm2
; SSE2-NEXT: cvtss2sd %xmm2, %xmm2
; SSE2-NEXT: movlhps {{.*#+}} xmm2 = xmm2[0],xmm3[0]
; SSE2-NEXT: shll $16, %ecx
; SSE2-NEXT: movd %ecx, %xmm3
; SSE2-NEXT: cvtss2sd %xmm3, %xmm4
; SSE2-NEXT: shll $16, %eax
; SSE2-NEXT: movd %eax, %xmm3
; SSE2-NEXT: cvtss2sd %xmm3, %xmm3
; SSE2-NEXT: movlhps {{.*#+}} xmm3 = xmm3[0],xmm4[0]
; SSE2-NEXT: pxor %xmm3, %xmm3
; SSE2-NEXT: pxor %xmm1, %xmm1
; SSE2-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
; SSE2-NEXT: cvtps2pd %xmm1, %xmm4
; SSE2-NEXT: punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7]
; SSE2-NEXT: cvtps2pd %xmm3, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
; SSE2-NEXT: cvtps2pd %xmm0, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm3[2,3,2,3]
; SSE2-NEXT: cvtps2pd %xmm0, %xmm3
; SSE2-NEXT: movaps %xmm4, %xmm0
; SSE2-NEXT: retq
;
; F16-LABEL: pr64460_4:
@@ -1874,45 +1740,11 @@ define <8 x double> @pr64460_4(<8 x bfloat> %a) {
;
; AVXNC-LABEL: pr64460_4:
; AVXNC: # %bb.0:
; AVXNC-NEXT: vpextrw $3, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm1
; AVXNC-NEXT: vcvtss2sd %xmm1, %xmm1, %xmm1
; AVXNC-NEXT: vpextrw $2, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm2
; AVXNC-NEXT: vcvtss2sd %xmm2, %xmm2, %xmm2
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm1 = xmm2[0],xmm1[0]
; AVXNC-NEXT: vpextrw $1, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm2
; AVXNC-NEXT: vcvtss2sd %xmm2, %xmm2, %xmm2
; AVXNC-NEXT: vmovd %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm3
; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm2 = xmm3[0],xmm2[0]
; AVXNC-NEXT: vinsertf128 $1, %xmm1, %ymm2, %ymm2
; AVXNC-NEXT: vpextrw $7, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm1
; AVXNC-NEXT: vcvtss2sd %xmm1, %xmm1, %xmm1
; AVXNC-NEXT: vpextrw $6, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm3
; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm1 = xmm3[0],xmm1[0]
; AVXNC-NEXT: vpextrw $5, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm3
; AVXNC-NEXT: vcvtss2sd %xmm3, %xmm3, %xmm3
; AVXNC-NEXT: vpextrw $4, %xmm0, %eax
; AVXNC-NEXT: shll $16, %eax
; AVXNC-NEXT: vmovd %eax, %xmm0
; AVXNC-NEXT: vcvtss2sd %xmm0, %xmm0, %xmm0
; AVXNC-NEXT: vmovlhps {{.*#+}} xmm0 = xmm0[0],xmm3[0]
; AVXNC-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm1
; AVXNC-NEXT: vmovaps %ymm2, %ymm0
; AVXNC-NEXT: vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
; AVXNC-NEXT: vpslld $16, %ymm0, %ymm1
; AVXNC-NEXT: vcvtps2pd %xmm1, %ymm0
; AVXNC-NEXT: vextracti128 $1, %ymm1, %xmm1
; AVXNC-NEXT: vcvtps2pd %xmm1, %ymm1
; AVXNC-NEXT: retq
%b = fpext <8 x bfloat> %a to <8 x double>
ret <8 x double> %b
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/X86/concat-fpext-v2bf16.ll
@@ -10,11 +10,11 @@ define void @test(<2 x ptr> %ptr) {
; CHECK-NEXT: # %bb.2: # %loop.127.preheader
; CHECK-NEXT: retq
; CHECK-NEXT: .LBB0_1: # %ifmerge.89
; CHECK-NEXT: movzwl (%rax), %eax
; CHECK-NEXT: shll $16, %eax
; CHECK-NEXT: vmovd %eax, %xmm0
; CHECK-NEXT: vmulss %xmm0, %xmm0, %xmm0
; CHECK-NEXT: vbroadcastss %xmm0, %xmm0
; CHECK-NEXT: vpxor %xmm1, %xmm1, %xmm1
; CHECK-NEXT: vpbroadcastw (%rax), %xmm2
; CHECK-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1],xmm1[2],xmm2[2],xmm1[3],xmm2[3]
; CHECK-NEXT: vmulps %xmm1, %xmm0, %xmm0
; CHECK-NEXT: vmovlps %xmm0, (%rax)
entry:
br label %then.13