Skip to content

Commit 8e8c02c

Browse files
authored
Support for SPV_INTEL_shader_atomic_bfloat16 extension (#3343)
Spec is available here: intel/llvm#20009 Author: "Ratajewski, Andrzej" <andrzej.ratajewski@intel.com> Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent 6fce8e4 commit 8e8c02c

File tree

10 files changed

+209
-3
lines changed

10 files changed

+209
-3
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,4 @@ EXT(SPV_INTEL_bfloat16_arithmetic)
8181
EXT(SPV_INTEL_ternary_bitwise_function)
8282
EXT(SPV_INTEL_int4)
8383
EXT(SPV_INTEL_function_variants)
84+
EXT(SPV_INTEL_shader_atomic_bfloat16)

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3020,40 +3020,48 @@ class SPIRVAtomicFAddEXTInst : public SPIRVAtomicInstBase {
30203020
public:
30213021
std::optional<ExtensionID> getRequiredExtension() const override {
30223022
assert(hasType());
3023+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3024+
return ExtensionID::SPV_INTEL_shader_atomic_bfloat16;
30233025
if (getType()->isTypeFloat(16))
30243026
return ExtensionID::SPV_EXT_shader_atomic_float16_add;
30253027
return ExtensionID::SPV_EXT_shader_atomic_float_add;
30263028
}
30273029

30283030
SPIRVCapVec getRequiredCapability() const override {
30293031
assert(hasType());
3032+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3033+
return {internal::CapabilityAtomicBFloat16AddINTEL};
30303034
if (getType()->isTypeFloat(16))
30313035
return {CapabilityAtomicFloat16AddEXT};
30323036
if (getType()->isTypeFloat(32))
30333037
return {CapabilityAtomicFloat32AddEXT};
30343038
if (getType()->isTypeFloat(64))
30353039
return {CapabilityAtomicFloat64AddEXT};
30363040
llvm_unreachable(
3037-
"AtomicFAddEXT can only be generated for f16, f32, f64 types");
3041+
"AtomicFAddEXT can only be generated for bf16, f16, f32, f64 types");
30383042
}
30393043
};
30403044

30413045
class SPIRVAtomicFMinMaxEXTBase : public SPIRVAtomicInstBase {
30423046
public:
30433047
std::optional<ExtensionID> getRequiredExtension() const override {
3048+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3049+
return ExtensionID::SPV_INTEL_shader_atomic_bfloat16;
30443050
return ExtensionID::SPV_EXT_shader_atomic_float_min_max;
30453051
}
30463052

30473053
SPIRVCapVec getRequiredCapability() const override {
30483054
assert(hasType());
3055+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3056+
return {internal::CapabilityAtomicBFloat16MinMaxINTEL};
30493057
if (getType()->isTypeFloat(16))
30503058
return {CapabilityAtomicFloat16MinMaxEXT};
30513059
if (getType()->isTypeFloat(32))
30523060
return {CapabilityAtomicFloat32MinMaxEXT};
30533061
if (getType()->isTypeFloat(64))
30543062
return {CapabilityAtomicFloat64MinMaxEXT};
3055-
llvm_unreachable(
3056-
"AtomicF(Min|Max)EXT can only be generated for f16, f32, f64 types");
3063+
llvm_unreachable("AtomicF(Min|Max)EXT can only be generated for bf16, f16, "
3064+
"f32, f64 types");
30573065
}
30583066
};
30593067

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,9 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
625625
add(CapabilityLongCompositesINTEL, "LongCompositesINTEL");
626626
add(CapabilityOptNoneEXT, "OptNoneEXT");
627627
add(CapabilityAtomicFloat16AddEXT, "AtomicFloat16AddEXT");
628+
add(internal::CapabilityAtomicBFloat16AddINTEL, "AtomicBFloat16AddINTEL");
629+
add(internal::CapabilityAtomicBFloat16MinMaxINTEL,
630+
"AtomicBFloat16MinMaxINTEL");
628631
add(CapabilityDebugInfoModuleINTEL, "DebugInfoModuleINTEL");
629632
add(CapabilityBFloat16ConversionINTEL, "Bfloat16ConversionINTEL");
630633
add(CapabilitySplitBarrierINTEL, "SplitBarrierINTEL");

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ enum InternalCapability {
108108
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
109109
ICapabilityBFloat16ArithmeticINTEL = 6226,
110110
ICapabilityCooperativeMatrixOffsetInstructionsINTEL = 6238,
111+
ICapabilityAtomicBFloat16AddINTEL = 6255,
112+
ICapabilityAtomicBFloat16MinMaxINTEL = 6256,
111113
ICapabilityCooperativeMatrixPrefetchINTEL = 6411,
112114
ICapabilityMaskedGatherScatterINTEL = 6427,
113115
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
@@ -203,6 +205,9 @@ _SPIRV_OP(Capability, BindlessImagesINTEL)
203205
_SPIRV_OP(Op, ConvertHandleToImageINTEL)
204206
_SPIRV_OP(Op, ConvertHandleToSamplerINTEL)
205207
_SPIRV_OP(Op, ConvertHandleToSampledImageINTEL)
208+
209+
_SPIRV_OP(Capability, AtomicBFloat16AddINTEL)
210+
_SPIRV_OP(Capability, AtomicBFloat16MinMaxINTEL)
206211
#undef _SPIRV_OP
207212

208213
constexpr SourceLanguage SourceLanguagePython =
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV
8+
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spir64-unknown-unknown"
11+
12+
; CHECK-SPIRV-DAG: Capability AtomicBFloat16AddINTEL
13+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
14+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
15+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
16+
17+
; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0
18+
19+
; Function Attrs: convergent norecurse nounwind
20+
define dso_local spir_func bfloat @test_AtomicFAddEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
21+
entry:
22+
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
23+
; CHECK-SPIRV: AtomicFAddEXT [[BFLOAT]]
24+
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
25+
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
26+
ret bfloat %ret
27+
}
28+
29+
; Function Attrs: convergent
30+
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV
8+
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spir64-unknown-unknown"
11+
12+
; CHECK-SPIRV-DAG: Capability AtomicBFloat16MinMaxINTEL
13+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
14+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
15+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
16+
17+
; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0
18+
19+
; Function Attrs: convergent norecurse nounwind
20+
define dso_local spir_func bfloat @test_AtomicFMaxEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
21+
entry:
22+
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
23+
; CHECK-SPIRV: AtomicFMaxEXT [[BFLOAT]]
24+
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
25+
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
26+
ret bfloat %ret
27+
}
28+
29+
; Function Attrs: convergent
30+
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV
8+
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spir64-unknown-unknown"
11+
12+
; CHECK-SPIRV-DAG: Capability AtomicBFloat16MinMaxINTEL
13+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
14+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
15+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
16+
17+
; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0
18+
19+
; Function Attrs: convergent norecurse nounwind
20+
define dso_local spir_func bfloat @test_AtomicFMinEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
21+
entry:
22+
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
23+
; CHECK-SPIRV: AtomicFMinEXT [[BFLOAT]]
24+
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
25+
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
26+
ret bfloat %ret
27+
}
28+
29+
; Function Attrs: convergent
30+
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 %t.bc -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s
4+
5+
; CHECK-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
6+
; CHECK-DAG: Extension "SPV_KHR_bfloat16"
7+
; CHECK-DAG: Capability AtomicBFloat16AddINTEL
8+
; CHECK-DAG: Capability BFloat16TypeKHR
9+
; CHECK: TypeInt [[Int:[0-9]+]] 32 0
10+
; CHECK-DAG: Constant [[Int]] [[Scope_CrossDevice:[0-9]+]] 0 {{$}}
11+
; CHECK-DAG: Constant [[Int]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
12+
; CHECK: TypeFloat [[BFloat:[0-9]+]] 16 0
13+
; CHECK: Variable {{[0-9]+}} [[BFloatPointer:[0-9]+]]
14+
; CHECK: Constant [[BFloat]] [[BFloatValue:[0-9]+]] 16936
15+
16+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
17+
target triple = "spir64"
18+
19+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
20+
21+
; Function Attrs: nounwind
22+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
23+
entry:
24+
%0 = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
25+
; CHECK: AtomicFAddEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Scope_CrossDevice]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]
26+
27+
ret void
28+
}
29+
30+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
31+
32+
!llvm.module.flags = !{!0}
33+
34+
!0 = !{i32 1, !"wchar_size", i32 4}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 %t.bc -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s
4+
5+
; CHECK-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
6+
; CHECK-DAG: Extension "SPV_KHR_bfloat16"
7+
; CHECK-DAG: AtomicBFloat16MinMaxINTEL
8+
; CHECK-DAG: Capability BFloat16TypeKHR
9+
; CHECK: TypeInt [[Int:[0-9]+]] 32 0
10+
; CHECK-DAG: Constant [[Int]] [[Scope_CrossDevice:[0-9]+]] 0 {{$}}
11+
; CHECK-DAG: Constant [[Int]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
12+
; CHECK: TypeFloat [[BFloat:[0-9]+]] 16 0
13+
; CHECK: Variable {{[0-9]+}} [[BFloatPointer:[0-9]+]]
14+
; CHECK: Constant [[BFloat]] [[BFloatValue:[0-9]+]] 16936
15+
16+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
17+
target triple = "spir64"
18+
19+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 4
20+
21+
; Function Attrs: nounwind
22+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
23+
entry:
24+
%0 = atomicrmw fmin ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
25+
; CHECK: AtomicFMinEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Scope_CrossDevice]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]
26+
%1 = atomicrmw fmax ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
27+
; CHECK: AtomicFMaxEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Scope_CrossDevice]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]
28+
29+
ret void
30+
}
31+
32+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
33+
34+
!llvm.module.flags = !{!0}
35+
36+
!0 = !{i32 1, !"wchar_size", i32 4}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: not llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16 %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-NO-BF
3+
; RUN: not llvm-spirv --spirv-ext=+SPV_KHR_bfloat16 %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ATOM
4+
5+
; CHECK-NO-BF: RequiresExtension: Feature requires the following SPIR-V extension:
6+
; CHECK-NO-BF-NEXT: SPV_KHR_bfloat16
7+
; CHECK-NO-BF-NEXT: NOTE: LLVM module contains bfloat type, translation of which requires this extension
8+
9+
; CHECK-NO-ATOM: RequiresExtension: Feature requires the following SPIR-V extension:
10+
; CHECK-NO-ATOM-NEXT: SPV_INTEL_shader_atomic_bfloat16
11+
12+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
13+
target triple = "spir64"
14+
15+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
16+
17+
; Function Attrs: nounwind
18+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
19+
entry:
20+
%0 = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
21+
22+
ret void
23+
}
24+
25+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
26+
27+
!llvm.module.flags = !{!0}
28+
29+
!0 = !{i32 1, !"wchar_size", i32 4}

0 commit comments

Comments
 (0)