Skip to content

Commit 84d92de

Browse files
VyacheslavLevytskyy authored and sys-ce-bb committed
add initial f16 type support for atomicrmw in llvm-spirv translator (#2210)
This PR aims to add f16 type support for atomicrmw in the llvm-spirv translator, with reference to the extension documented in [1].

There are two concerns related to the subject:

1. SPIRVAtomicFAddEXTInst::getRequiredExtension() should return a list of required extensions, to support the requirement to list both the SPV_EXT_shader_atomic_float16_add and SPV_EXT_shader_atomic_float_add extensions in the module (see the "Extension Name" section of ref [1]). However, the return type is std::optional<ExtensionID>, and returning a vector would need a bigger rework.

2. Including SPV_EXT_shader_atomic_float16_add in the --spirv-ext argument of llvm-spirv doesn't result in producing the corresponding capability (AtomicFloat16AddEXT) and extension in the SPIR-V output:

$ llvm-spirv AtomicFAddEXT.ll.tmp.bc --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_EXT_shader_atomic_float16_add -o AtomicFAddEXT.ll.tmp.spv
$ llvm-spirv -to-text AtomicFAddEXT.ll.tmp.spv -o /dev/stdout
...
2 Capability AtomicFloat32AddEXT
2 Capability AtomicFloat64AddEXT
9 Extension "SPV_EXT_shader_atomic_float_add"
...

This prevents extending the test case of AtomicFAddEXT.ll in EXT/SPV_EXT_shader_atomic_float.

References:
[1] https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_shader_atomic_float16_add.asciidoc

Original commit: KhronosGroup/SPIRV-LLVM-Translator@1aae8db
1 parent 9757b85 commit 84d92de

File tree

5 files changed

+112
-15
lines changed

5 files changed

+112
-15
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define EXT(X)
33
#endif
44

5+
EXT(SPV_EXT_shader_atomic_float16_add)
56
EXT(SPV_EXT_shader_atomic_float_add)
67
EXT(SPV_EXT_shader_atomic_float_min_max)
78
EXT(SPV_EXT_image_raw10_raw12)

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2746,16 +2746,22 @@ class SPIRVAtomicStoreInst : public SPIRVAtomicInstBase {
27462746
class SPIRVAtomicFAddEXTInst : public SPIRVAtomicInstBase {
27472747
public:
27482748
std::optional<ExtensionID> getRequiredExtension() const override {
2749+
assert(hasType());
2750+
if (getType()->isTypeFloat(16))
2751+
return ExtensionID::SPV_EXT_shader_atomic_float16_add;
27492752
return ExtensionID::SPV_EXT_shader_atomic_float_add;
27502753
}
27512754

27522755
SPIRVCapVec getRequiredCapability() const override {
27532756
assert(hasType());
2757+
if (getType()->isTypeFloat(16))
2758+
return {CapabilityAtomicFloat16AddEXT};
27542759
if (getType()->isTypeFloat(32))
27552760
return {CapabilityAtomicFloat32AddEXT};
2756-
assert(getType()->isTypeFloat(64) &&
2757-
"AtomicFAddEXT can only be generated for f32 or f64 types");
2758-
return {CapabilityAtomicFloat64AddEXT};
2761+
if (getType()->isTypeFloat(64))
2762+
return {CapabilityAtomicFloat64AddEXT};
2763+
llvm_unreachable(
2764+
"AtomicFAddEXT can only be generated for f16, f32, f64 types");
27592765
}
27602766
};
27612767

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,17 @@ void SPIRVModuleImpl::addExtension(ExtensionID Ext) {
655655
return;
656656
}
657657
SPIRVExt.insert(ExtName);
658+
659+
// SPV_EXT_shader_atomic_float16_add extends the
660+
// SPV_EXT_shader_atomic_float_add extension.
661+
// The specification requires both extensions to be added to use
662+
// AtomicFloat16AddEXT capability whereas getRequiredExtension()
663+
// is able to return a single extensionID.
664+
if (Ext == ExtensionID::SPV_EXT_shader_atomic_float16_add) {
665+
SPIRVMap<ExtensionID, std::string>::find(
666+
ExtensionID::SPV_EXT_shader_atomic_float_add, &ExtName);
667+
SPIRVExt.insert(ExtName);
668+
}
658669
}
659670

660671
void SPIRVModuleImpl::addCapability(SPIRVCapabilityKind Cap) {
@@ -690,7 +701,8 @@ SPIRVConstant *SPIRVModuleImpl::getLiteralAsConstant(unsigned Literal) {
690701
if (Loc != LiteralMap.end())
691702
return Loc->second;
692703
auto *Ty = addIntegerType(32);
693-
auto *V = new SPIRVConstant(this, Ty, getId(), static_cast<uint64_t>(Literal));
704+
auto *V =
705+
new SPIRVConstant(this, Ty, getId(), static_cast<uint64_t>(Literal));
694706
LiteralMap[Literal] = V;
695707
addConstant(V);
696708
return V;
@@ -1227,8 +1239,9 @@ SPIRVValue *SPIRVModuleImpl::addSpecConstantComposite(
12271239
End = ((Elements.end() - End) > MaxNumElements) ? End + MaxNumElements
12281240
: Elements.end();
12291241
Slice.assign(Start, End);
1230-
auto *Continued = static_cast<SPIRVSpecConstantComposite::ContinuedInstType>(
1231-
addSpecConstantCompositeContinuedINTEL(Slice));
1242+
auto *Continued =
1243+
static_cast<SPIRVSpecConstantComposite::ContinuedInstType>(
1244+
addSpecConstantCompositeContinuedINTEL(Slice));
12321245
Res->addContinuedInstruction(Continued);
12331246
}
12341247
return Res;
@@ -1440,7 +1453,7 @@ SPIRVValue *SPIRVModuleImpl::addAsmINTEL(SPIRVTypeFunction *TheType,
14401453
const std::string &TheInstructions,
14411454
const std::string &TheConstraints) {
14421455
auto *Asm = new SPIRVAsmINTEL(this, TheType, getId(), TheTarget,
1443-
TheInstructions, TheConstraints);
1456+
TheInstructions, TheConstraints);
14441457
return add(Asm);
14451458
}
14461459

@@ -1725,8 +1738,9 @@ SPIRVInstruction *SPIRVModuleImpl::addExpectKHRInst(SPIRVType *ResultTy,
17251738
// Create AliasDomainDeclINTEL/AliasScopeDeclINTEL/AliasScopeListDeclINTEL
17261739
// instructions
17271740
template <typename AliasingInstType>
1728-
SPIRVEntry *SPIRVModuleImpl::getOrAddMemAliasingINTELInst(
1729-
std::vector<SPIRVId> Args, llvm::MDNode *MD) {
1741+
SPIRVEntry *
1742+
SPIRVModuleImpl::getOrAddMemAliasingINTELInst(std::vector<SPIRVId> Args,
1743+
llvm::MDNode *MD) {
17301744
assert(MD && "noalias/alias.scope metadata can't be null");
17311745
// Don't duplicate aliasing instruction. For that use a map with a MDNode key
17321746
if (AliasInstMDMap.find(MD) != AliasInstMDMap.end())
@@ -1737,20 +1751,23 @@ SPIRVEntry *SPIRVModuleImpl::getOrAddMemAliasingINTELInst(
17371751
}
17381752

17391753
// Create AliasDomainDeclINTEL instruction
1740-
SPIRVEntry *SPIRVModuleImpl::getOrAddAliasDomainDeclINTELInst(
1741-
std::vector<SPIRVId> Args, llvm::MDNode *MD) {
1754+
SPIRVEntry *
1755+
SPIRVModuleImpl::getOrAddAliasDomainDeclINTELInst(std::vector<SPIRVId> Args,
1756+
llvm::MDNode *MD) {
17421757
return getOrAddMemAliasingINTELInst<SPIRVAliasDomainDeclINTEL>(Args, MD);
17431758
}
17441759

17451760
// Create AliasScopeDeclINTEL instruction
1746-
SPIRVEntry *SPIRVModuleImpl::getOrAddAliasScopeDeclINTELInst(
1747-
std::vector<SPIRVId> Args, llvm::MDNode *MD) {
1761+
SPIRVEntry *
1762+
SPIRVModuleImpl::getOrAddAliasScopeDeclINTELInst(std::vector<SPIRVId> Args,
1763+
llvm::MDNode *MD) {
17481764
return getOrAddMemAliasingINTELInst<SPIRVAliasScopeDeclINTEL>(Args, MD);
17491765
}
17501766

17511767
// Create AliasScopeListDeclINTEL instruction
1752-
SPIRVEntry *SPIRVModuleImpl::getOrAddAliasScopeListDeclINTELInst(
1753-
std::vector<SPIRVId> Args, llvm::MDNode *MD) {
1768+
SPIRVEntry *
1769+
SPIRVModuleImpl::getOrAddAliasScopeListDeclINTELInst(std::vector<SPIRVId> Args,
1770+
llvm::MDNode *MD) {
17541771
return getOrAddMemAliasingINTELInst<SPIRVAliasScopeListDeclINTEL>(Args, MD);
17551772
}
17561773

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv --spirv-ext=+SPV_EXT_shader_atomic_float16_add %t.bc -o %t.spv
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s
5+
6+
; CHECK-DAG: Extension "SPV_EXT_shader_atomic_float16_add"
7+
; CHECK-DAG: Extension "SPV_EXT_shader_atomic_float_add"
8+
; CHECK-DAG: Capability AtomicFloat16AddEXT
9+
; CHECK: TypeInt [[TypeIntID:[0-9]+]] 32 0
10+
; CHECK-DAG: Constant [[TypeIntID]] [[ScopeDevice:[0-9]+]] 1 {{$}}
11+
; CHECK-DAG: Constant [[TypeIntID]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
12+
; CHECK: TypeFloat [[TypeFloatHalfID:[0-9]+]] 16
13+
; CHECK: Variable {{[0-9]+}} [[HalfPointer:[0-9]+]]
14+
; CHECK: Constant [[TypeFloatHalfID]] [[HalfValue:[0-9]+]] 20800
15+
16+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
17+
target triple = "spir64"
18+
19+
@f = common dso_local local_unnamed_addr addrspace(1) global half 0.000000e+00, align 4
20+
21+
; Function Attrs: nounwind
22+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
23+
entry:
24+
%0 = atomicrmw fadd ptr addrspace(1) @f, half 42.000000e+00 seq_cst
25+
; CHECK: AtomicFAddEXT [[TypeFloatHalfID]] {{[0-9]+}} [[HalfPointer]] [[ScopeDevice]] [[MemSem_SequentiallyConsistent]] [[HalfValue]]
26+
27+
ret void
28+
}
29+
30+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
31+
32+
!llvm.module.flags = !{!0}
33+
34+
!0 = !{i32 1, !"wchar_size", i32 4}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv --spirv-ext=+SPV_EXT_shader_atomic_float16_add %t.bc -o %t.spv
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck --check-prefix=CHECK-SPIRV %s
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM
8+
9+
; CHECK-SPIRV-DAG: Extension "SPV_EXT_shader_atomic_float16_add"
10+
; CHECK-SPIRV-DAG: Extension "SPV_EXT_shader_atomic_float_add"
11+
; CHECK-SPIRV-DAG: Capability AtomicFloat16AddEXT
12+
; CHECK-SPIRV: TypeInt [[Int:[0-9]+]] 32 0
13+
; CHECK-SPIRV-DAG: Constant [[Int]] [[ScopeDevice:[0-9]+]] 1 {{$}}
14+
; CHECK-SPIRV-DAG: Constant [[Int]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
15+
; CHECK-SPIRV: TypeFloat [[Half:[0-9]+]] 16
16+
; CHECK-SPIRV: Variable {{[0-9]+}} [[HalfPointer:[0-9]+]]
17+
; CHECK-SPIRV: Constant [[Half]] [[HalfValue:[0-9]+]] 15360
18+
19+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
20+
target triple = "spir64"
21+
22+
@f = common dso_local local_unnamed_addr addrspace(1) global half 0.000000e+00, align 4
23+
24+
; Function Attrs: nounwind
25+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
26+
entry:
27+
%0 = atomicrmw fsub ptr addrspace(1) @f, half 1.0e+00 seq_cst
28+
; CHECK-SPIRV: FNegate [[Half]] [[NegateValue:[0-9]+]] [[HalfValue]]
29+
; CHECK-SPIRV: AtomicFAddEXT [[Half]] {{[0-9]+}} [[HalfPointer]] [[ScopeDevice]] [[MemSem_SequentiallyConsistent]] [[NegateValue]]
30+
; CHECK-LLVM: [[FNegateLLVM:%[0-9]+]] = fneg half 0xH3C00
31+
; CHECK-LLVM: call spir_func half {{.*}}atomic_add{{.*}}(ptr addrspace(1) @f, half [[FNegateLLVM]])
32+
ret void
33+
}
34+
35+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
36+
37+
!llvm.module.flags = !{!0}
38+
39+
!0 = !{i32 1, !"wchar_size", i32 4}

0 commit comments

Comments
 (0)