From 0443c0860e98aef51125d5767d321f8a7b7c2106 Mon Sep 17 00:00:00 2001 From: Phoebe Wang Date: Sat, 7 Oct 2023 14:29:23 +0800 Subject: [PATCH 1/3] [X86] Enable bfloat type support in inline assembly constraints Similar to FP16 but we don't have native scalar instruction support, so limit it to vector types only. Fixes #68149 --- llvm/lib/Target/X86/X86ISelLowering.cpp | 24 +++++++++++++++++++ .../X86/inline-asm-avx512f-x-constraint.ll | 13 +++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index c4cd2a672fe7b..c0e93da877a8a 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -56904,6 +56904,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (!Subtarget.hasFP16()) break; [[fallthrough]]; + case MVT::v8bf16: + if (!Subtarget.hasBF16()) + break; + [[fallthrough]]; case MVT::f128: case MVT::v16i8: case MVT::v8i16: @@ -56919,6 +56923,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (!Subtarget.hasFP16()) break; [[fallthrough]]; + case MVT::v16bf16: + if (!Subtarget.hasBF16()) + break; + [[fallthrough]]; case MVT::v32i8: case MVT::v16i16: case MVT::v8i32: @@ -56934,6 +56942,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (!Subtarget.hasFP16()) break; [[fallthrough]]; + case MVT::v32bf16: + if (!Subtarget.hasBF16()) + break; + [[fallthrough]]; case MVT::v64i8: case MVT::v32i16: case MVT::v8f64: @@ -56977,6 +56989,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (!Subtarget.hasFP16()) break; [[fallthrough]]; + case MVT::v8bf16: + if (!Subtarget.hasBF16()) + break; + [[fallthrough]]; case MVT::f128: case MVT::v16i8: case MVT::v8i16: @@ -56990,6 +57006,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (!Subtarget.hasFP16()) break; [[fallthrough]]; + case MVT::v16bf16: + if (!Subtarget.hasBF16()) + break; + [[fallthrough]]; case MVT::v32i8: case MVT::v16i16: case MVT::v8i32: @@ -57003,6 +57023,10 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (!Subtarget.hasFP16()) break; [[fallthrough]]; + case MVT::v32bf16: + if (!Subtarget.hasBF16()) + break; + [[fallthrough]]; case MVT::v64i8: case MVT::v32i16: case MVT::v8f64: diff --git a/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll b/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll index fcea55c47cd3e..e153387d16e72 100644 --- a/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll +++ b/llvm/test/CodeGen/X86/inline-asm-avx512f-x-constraint.ll @@ -1,7 +1,7 @@ ; RUN: not llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512f -stop-after=finalize-isel > %t 2> %t.err ; RUN: FileCheck < %t %s ; RUN: FileCheck --check-prefix=CHECK-STDERR < %t.err %s -; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512fp16 -stop-after=finalize-isel | FileCheck --check-prefixes=CHECK,FP16 %s +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512bf16,avx512fp16 -stop-after=finalize-isel | FileCheck --check-prefixes=CHECK,FP16 %s ; CHECK-LABEL: name: mask_Yk_i8 ; CHECK: %[[REG1:.*]]:vr512_0_15 = COPY %1 @@ -24,3 +24,14 @@ entry: %0 = tail call <32 x half> asm "vaddph\09$3, $2, $0 {$1}", "=x,^Yk,x,x,~{dirflag},~{fpsr},~{flags}"(i8 %msk, <32 x half> %x, <32 x half> %y) ret <32 x half> %0 } + +; FP16-LABEL: name: mask_Yk_bf16 +; FP16: %[[REG1:.*]]:vr512_0_15 = COPY %1 +; FP16: %[[REG2:.*]]:vr512_0_15 = COPY %2 +; FP16: INLINEASM &"vaddph\09$3, $2, $0 {$1}", 0 /* attdialect */, {{.*}}, def %{{.*}}, {{.*}}, %{{.*}}, {{.*}}, %[[REG1]], {{.*}}, %[[REG2]], 12 /* clobber */, implicit-def early-clobber $df, 12 /* clobber */, implicit-def early-clobber $fpsw, 12 /* clobber */, implicit-def early-clobber $eflags +; CHECK-STDERR: couldn't allocate output register for constraint 'x' +define <32 x bfloat> @mask_Yk_bf16(i8 signext %msk, <32 x bfloat> %x, <32 x bfloat> %y) { +entry: + %0 = tail call <32 x bfloat> asm "vaddph\09$3, $2, $0 {$1}", "=x,^Yk,x,x,~{dirflag},~{fpsr},~{flags}"(i8 %msk, <32 x bfloat> %x, <32 x bfloat> %y) + ret <32 x bfloat> %0 +} From ee98b7dd9df4ab6d1afe187bc900cdb7f8ca5460 Mon Sep 17 00:00:00 2001 From: Phoebe Wang Date: Mon, 9 Oct 2023 19:50:50 +0800 Subject: [PATCH 2/3] Do not use [[fallthrough]] --- llvm/lib/Target/X86/X86ISelLowering.cpp | 36 ++++++++++++++++--------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index c0e93da877a8a..6a9f39ada651c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -56903,11 +56903,15 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, case MVT::v8f16: if (!Subtarget.hasFP16()) break; - [[fallthrough]]; + if (VConstraint) + return std::make_pair(0U, &X86::VR128XRegClass); + return std::make_pair(0U, &X86::VR128RegClass); case MVT::v8bf16: if (!Subtarget.hasBF16()) break; - [[fallthrough]]; + if (VConstraint) + return std::make_pair(0U, &X86::VR128XRegClass); + return std::make_pair(0U, &X86::VR128RegClass); case MVT::f128: case MVT::v16i8: case MVT::v8i16: @@ -56922,11 +56926,15 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, case MVT::v16f16: if (!Subtarget.hasFP16()) break; - [[fallthrough]]; + if (VConstraint) + return std::make_pair(0U, &X86::VR256XRegClass); + return std::make_pair(0U, &X86::VR256RegClass); case MVT::v16bf16: if (!Subtarget.hasBF16()) break; - [[fallthrough]]; + if (VConstraint) + return std::make_pair(0U, &X86::VR256XRegClass); + return std::make_pair(0U, &X86::VR256RegClass); case MVT::v32i8: case MVT::v16i16: case MVT::v8i32: @@ -56941,11 +56949,15 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, case MVT::v32f16: if (!Subtarget.hasFP16()) break; - [[fallthrough]]; + if (VConstraint) + return std::make_pair(0U, &X86::VR512RegClass); + return std::make_pair(0U, &X86::VR512_0_15RegClass); case MVT::v32bf16: if (!Subtarget.hasBF16()) break; - [[fallthrough]]; + if (VConstraint) + return std::make_pair(0U, &X86::VR512RegClass); + return std::make_pair(0U, &X86::VR512_0_15RegClass); case MVT::v64i8: case MVT::v32i16: case MVT::v8f64: @@ -56988,11 +57000,11 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, case MVT::v8f16: if (!Subtarget.hasFP16()) break; - [[fallthrough]]; + return std::make_pair(X86::XMM0, &X86::VR128RegClass); case MVT::v8bf16: if (!Subtarget.hasBF16()) break; - [[fallthrough]]; + return std::make_pair(X86::XMM0, &X86::VR128RegClass); case MVT::f128: case MVT::v16i8: case MVT::v8i16: @@ -57005,11 +57017,11 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, case MVT::v16f16: if (!Subtarget.hasFP16()) break; - [[fallthrough]]; + return std::make_pair(X86::YMM0, &X86::VR256RegClass); case MVT::v16bf16: if (!Subtarget.hasBF16()) break; - [[fallthrough]]; + return std::make_pair(X86::YMM0, &X86::VR256RegClass); case MVT::v32i8: case MVT::v16i16: case MVT::v8i32: @@ -57022,11 +57034,11 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, case MVT::v32f16: if (!Subtarget.hasFP16()) break; - [[fallthrough]]; + return std::make_pair(X86::ZMM0, &X86::VR512_0_15RegClass); case MVT::v32bf16: if (!Subtarget.hasBF16()) break; - [[fallthrough]]; + return std::make_pair(X86::ZMM0, &X86::VR512_0_15RegClass); case MVT::v64i8: case MVT::v32i16: case MVT::v8f64: From 37065c223ac8036155fb50386b980e45b2b8cfaf Mon Sep 17 00:00:00 2001 From: Phoebe Wang Date: Tue, 17 Oct 2023 11:38:04 +0800 Subject: [PATCH 3/3] Check hasVLX for BF16 --- llvm/lib/Target/X86/X86ISelLowering.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 6a9f39ada651c..bc9368f327d0c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -56907,7 +56907,7 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, return std::make_pair(0U, &X86::VR128XRegClass); return std::make_pair(0U, &X86::VR128RegClass); case MVT::v8bf16: - if (!Subtarget.hasBF16()) + if (!Subtarget.hasBF16() || !Subtarget.hasVLX()) break; if (VConstraint) return std::make_pair(0U, &X86::VR128XRegClass); @@ -56930,7 +56930,7 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, return std::make_pair(0U, &X86::VR256XRegClass); return std::make_pair(0U, &X86::VR256RegClass); case MVT::v16bf16: - if (!Subtarget.hasBF16()) + if (!Subtarget.hasBF16() || !Subtarget.hasVLX()) break; if (VConstraint) return std::make_pair(0U, &X86::VR256XRegClass); @@ -57002,7 +57002,7 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, break; return std::make_pair(X86::XMM0, &X86::VR128RegClass); case MVT::v8bf16: - if (!Subtarget.hasBF16()) + if (!Subtarget.hasBF16() || !Subtarget.hasVLX()) break; return std::make_pair(X86::XMM0, &X86::VR128RegClass); case MVT::f128: @@ -57019,7 +57019,7 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, break; return std::make_pair(X86::YMM0, &X86::VR256RegClass); case MVT::v16bf16: - if (!Subtarget.hasBF16()) + if (!Subtarget.hasBF16() || !Subtarget.hasVLX()) break; return std::make_pair(X86::YMM0, &X86::VR256RegClass); case MVT::v32i8: