Skip to content

Commit

Permalink
Support the SPV_EXT_shader_atomic_float_min_max extension (#916)
Browse files Browse the repository at this point in the history
* Support the SPV_EXT_shader_atomic_float_min_max extension

As with other atomic instructions, we're looking to translate
an LLVM IR call akin to `__spirv_AtomicFMaxEXT()` into a
corresponding SPIR-V OpCode akin to `AtomicFMaxEXT`.

Extension specification can be found at
https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/EXT/SPV_EXT_shader_atomic_float_min_max.asciidoc

Support in SPIR-V Header was added with
KhronosGroup/SPIRV-Headers#189.

Signed-off-by: Artem Gindinson <artem.gindinson@intel.com>
  • Loading branch information
Artem Gindinson authored Feb 16, 2021
1 parent 20ba782 commit d1213fe
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#endif

EXT(SPV_EXT_shader_atomic_float_add)
EXT(SPV_EXT_shader_atomic_float_min_max)
EXT(SPV_KHR_no_integer_wrap_decoration)
EXT(SPV_KHR_float_controls)
EXT(SPV_INTEL_subgroups)
Expand Down
21 changes: 21 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2773,6 +2773,25 @@ class SPIRVAtomicFAddEXTInst : public SPIRVAtomicInstBase {
}
};

class SPIRVAtomicFMinMaxEXTBase : public SPIRVAtomicInstBase {
public:
llvm::Optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_EXT_shader_atomic_float_min_max;
}

SPIRVCapVec getRequiredCapability() const override {
assert(hasType());
if (getType()->isTypeFloat(16))
return {CapabilityAtomicFloat16MinMaxEXT};
if (getType()->isTypeFloat(32))
return {CapabilityAtomicFloat32MinMaxEXT};
if (getType()->isTypeFloat(64))
return {CapabilityAtomicFloat64MinMaxEXT};
llvm_unreachable(
"AtomicF(Min|Max)EXT can only be generated for f16, f32, f64 types");
}
};

#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate<SPIRVAtomicInstBase, Op##x, __VA_ARGS__> SPIRV##x;
// Atomic builtins
Expand Down Expand Up @@ -2800,6 +2819,8 @@ _SPIRV_OP(MemoryBarrier, false, 3)
// Specialized atomic builtins
_SPIRV_OP(AtomicStore, AtomicStoreInst, false, 5)
_SPIRV_OP(AtomicFAddEXT, AtomicFAddEXTInst, true, 7)
_SPIRV_OP(AtomicFMinEXT, AtomicFMinMaxEXTBase, true, 7)
_SPIRV_OP(AtomicFMaxEXT, AtomicFMinMaxEXTBase, true, 7)
#undef _SPIRV_OP

class SPIRVImageInstBase : public SPIRVInstTemplateBase {
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
"PhysicalStorageBufferAddressesEXT");
add(CapabilityAtomicFloat32AddEXT, "AtomicFloat32AddEXT");
add(CapabilityAtomicFloat64AddEXT, "AtomicFloat64AddEXT");
add(CapabilityAtomicFloat32MinMaxEXT, "AtomicFloat32MinMaxEXT");
add(CapabilityAtomicFloat64MinMaxEXT, "AtomicFloat64MinMaxEXT");
add(CapabilityAtomicFloat16MinMaxEXT, "AtomicFloat16MinMaxEXT");
add(CapabilityComputeDerivativeGroupLinearNV,
"ComputeDerivativeGroupLinearNV");
add(CapabilityCooperativeMatrixNV, "CooperativeMatrixNV");
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ _SPIRV_OP(FunctionPointerCallINTEL, 5601)
_SPIRV_OP(AsmTargetINTEL, 5609)
_SPIRV_OP(AsmINTEL, 5610)
_SPIRV_OP(AsmCallINTEL, 5611)
_SPIRV_OP(AtomicFMinEXT, 5614)
_SPIRV_OP(AtomicFMaxEXT, 5615)
_SPIRV_OP(VmeImageINTEL, 5699)
_SPIRV_OP(TypeVmeImageINTEL, 5700)
_SPIRV_OP(TypeAvcImePayloadINTEL, 5701)
Expand Down
7 changes: 7 additions & 0 deletions lib/SPIRV/libSPIRV/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,9 @@ enum Capability {
CapabilityFunctionPointersINTEL = 5603,
CapabilityIndirectReferencesINTEL = 5604,
CapabilityAsmINTEL = 5606,
CapabilityAtomicFloat32MinMaxEXT = 5612,
CapabilityAtomicFloat64MinMaxEXT = 5613,
CapabilityAtomicFloat16MinMaxEXT = 5616,
CapabilityVectorComputeINTEL = 5617,
CapabilityVectorAnyINTEL = 5619,
CapabilitySubgroupAvcMotionEstimationINTEL = 5696,
Expand Down Expand Up @@ -1538,6 +1541,8 @@ enum Op {
OpAsmTargetINTEL = 5609,
OpAsmINTEL = 5610,
OpAsmCallINTEL = 5611,
OpAtomicFMinEXT = 5614,
OpAtomicFMaxEXT = 5615,
OpDecorateString = 5632,
OpDecorateStringGOOGLE = 5632,
OpMemberDecorateString = 5633,
Expand Down Expand Up @@ -2175,6 +2180,8 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
case OpAsmTargetINTEL: *hasResult = true; *hasResultType = true; break;
case OpAsmINTEL: *hasResult = true; *hasResultType = true; break;
case OpAsmCallINTEL: *hasResult = true; *hasResultType = true; break;
case OpAtomicFMinEXT: *hasResult = true; *hasResultType = true; break;
case OpAtomicFMaxEXT: *hasResult = true; *hasResultType = true; break;
case OpDecorateString: *hasResult = false; *hasResultType = false; break;
case OpMemberDecorateString: *hasResult = false; *hasResultType = false; break;
case OpVmeImageINTEL: *hasResult = true; *hasResultType = true; break;
Expand Down
123 changes: 123 additions & 0 deletions test/AtomicFMaxEXT.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_EXT_shader_atomic_float_min_max -o %t.spv
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown-sycldevice"

%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range" = type { %"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" }
%"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" = type { [1 x i64] }
%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" = type { %"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" }

$_ZTSZZ8max_testIfEvN2cl4sycl5queueEmENKUlRNS1_7handlerEE16_14clES4_EUlNS1_4itemILi1ELb1EEEE19_37 = comdat any

$_ZTSZZ8max_testIdEvN2cl4sycl5queueEmENKUlRNS1_7handlerEE16_14clES4_EUlNS1_4itemILi1ELb1EEEE19_37 = comdat any

@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

; CHECK-SPIRV: Capability AtomicFloat32MinMaxEXT
; CHECK-SPIRV: Capability AtomicFloat64MinMaxEXT
; CHECK-SPIRV: Extension "SPV_EXT_shader_atomic_float_min_max"
; CHECK-SPIRV: TypeFloat [[TYPE_FLOAT_32:[0-9]+]] 32
; CHECK-SPIRV: TypeFloat [[TYPE_FLOAT_64:[0-9]+]] 64

; Function Attrs: convergent norecurse
define weak_odr dso_local spir_kernel void @_ZTSZZ8max_testIfEvN2cl4sycl5queueEmENKUlRNS1_7handlerEE16_14clES4_EUlNS1_4itemILi1ELb1EEEE19_37(float addrspace(1)* %_arg_, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_1, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_2, %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* byval(%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id") align 8 %_arg_3, float addrspace(1)* %_arg_4, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_6, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_7, %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* byval(%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id") align 8 %_arg_8) local_unnamed_addr #0 comdat !kernel_arg_buffer_location !4 {
entry:
%0 = getelementptr inbounds %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %_arg_3, i64 0, i32 0, i32 0, i64 0
%1 = load i64, i64* %0, align 8
%add.ptr.i29 = getelementptr inbounds float, float addrspace(1)* %_arg_, i64 %1
%2 = getelementptr inbounds %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %_arg_8, i64 0, i32 0, i32 0, i64 0
%3 = load i64, i64* %2, align 8
%add.ptr.i = getelementptr inbounds float, float addrspace(1)* %_arg_4, i64 %3
%4 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId to <3 x i64> addrspace(4)*), align 32, !noalias !5
%5 = extractelement <3 x i64> %4, i64 0
%conv.i = trunc i64 %5 to i32
%conv3.i = sitofp i32 %conv.i to float
%add.i = fadd float %conv3.i, 1.000000e+00
; CHECK-SPIRV: 7 AtomicFMaxEXT [[TYPE_FLOAT_32]]
; CHECK-LLVM: call spir_func float @[[FLOAT_FUNC_NAME:_Z21__spirv_AtomicFMaxEXT[[:alnum:]]+]]({{.*}})
%call3.i.i.i = tail call spir_func float @_Z21__spirv_AtomicFMaxEXTPU3AS1fN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagEf(float addrspace(1)* %add.ptr.i29, i32 1, i32 896, float %add.i) #2
%sext.i = shl i64 %5, 32
%conv6.i = ashr exact i64 %sext.i, 32
%ptridx.i.i = getelementptr inbounds float, float addrspace(1)* %add.ptr.i, i64 %conv6.i
%ptridx.ascast.i.i = addrspacecast float addrspace(1)* %ptridx.i.i to float addrspace(4)*
store float %call3.i.i.i, float addrspace(4)* %ptridx.ascast.i.i, align 4, !tbaa !14
ret void
}

; Function Attrs: convergent
; CHECK-LLVM: declare {{.*}}spir_func float @[[FLOAT_FUNC_NAME]](float addrspace(1)*, i32, i32, float)
declare dso_local spir_func float @_Z21__spirv_AtomicFMaxEXTPU3AS1fN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagEf(float addrspace(1)*, i32, i32, float) local_unnamed_addr #1

; Function Attrs: convergent norecurse
define weak_odr dso_local spir_kernel void @_ZTSZZ8max_testIdEvN2cl4sycl5queueEmENKUlRNS1_7handlerEE16_14clES4_EUlNS1_4itemILi1ELb1EEEE19_37(double addrspace(1)* %_arg_, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_1, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_2, %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* byval(%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id") align 8 %_arg_3, double addrspace(1)* %_arg_4, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_6, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_7, %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* byval(%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id") align 8 %_arg_8) local_unnamed_addr #0 comdat !kernel_arg_buffer_location !4 {
entry:
%0 = getelementptr inbounds %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %_arg_3, i64 0, i32 0, i32 0, i64 0
%1 = load i64, i64* %0, align 8
%add.ptr.i29 = getelementptr inbounds double, double addrspace(1)* %_arg_, i64 %1
%2 = getelementptr inbounds %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %_arg_8, i64 0, i32 0, i32 0, i64 0
%3 = load i64, i64* %2, align 8
%add.ptr.i = getelementptr inbounds double, double addrspace(1)* %_arg_4, i64 %3
%4 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId to <3 x i64> addrspace(4)*), align 32, !noalias !18
%5 = extractelement <3 x i64> %4, i64 0
%conv.i = trunc i64 %5 to i32
%conv3.i = sitofp i32 %conv.i to double
%add.i = fadd double %conv3.i, 1.000000e+00
; CHECK-SPIRV: 7 AtomicFMaxEXT [[TYPE_FLOAT_64]]
; CHECK-LLVM: call spir_func double @[[DOUBLE_FUNC_NAME:_Z21__spirv_AtomicFMaxEXT[[:alnum:]]+]]({{.*}})
%call3.i.i.i = tail call spir_func double @_Z21__spirv_AtomicFMaxEXTPU3AS1dN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagEd(double addrspace(1)* %add.ptr.i29, i32 1, i32 896, double %add.i) #2
%sext.i = shl i64 %5, 32
%conv6.i = ashr exact i64 %sext.i, 32
%ptridx.i.i = getelementptr inbounds double, double addrspace(1)* %add.ptr.i, i64 %conv6.i
%ptridx.ascast.i.i = addrspacecast double addrspace(1)* %ptridx.i.i to double addrspace(4)*
store double %call3.i.i.i, double addrspace(4)* %ptridx.ascast.i.i, align 8, !tbaa !27
ret void
}

; Function Attrs: convergent
; CHECK-LLVM: declare {{.*}}spir_func double @[[DOUBLE_FUNC_NAME]](double addrspace(1)*, i32, i32, double)
declare dso_local spir_func double @_Z21__spirv_AtomicFMaxEXTPU3AS1dN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagEd(double addrspace(1)*, i32, i32, double) local_unnamed_addr #1

attributes #0 = { convergent norecurse "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="true" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { convergent "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { convergent nounwind }

!llvm.module.flags = !{!0}
!opencl.spir.version = !{!1}
!spirv.Source = !{!2}
!llvm.ident = !{!3}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 1, i32 2}
!2 = !{i32 4, i32 100000}
!3 = !{!"clang version 12.0.0"}
!4 = !{i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1}
!5 = !{!6, !8, !10, !12}
!6 = distinct !{!6, !7, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEE8initSizeEv: %agg.result"}
!7 = distinct !{!7, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEE8initSizeEv"}
!8 = distinct !{!8, !9, !"_ZN7__spirvL22initGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEEET0_v: %agg.result"}
!9 = distinct !{!9, !"_ZN7__spirvL22initGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEEET0_v"}
!10 = distinct !{!10, !11, !"_ZN2cl4sycl6detail7Builder7getItemILi1ELb1EEENSt9enable_ifIXT0_EKNS0_4itemIXT_EXT0_EEEE4typeEv: %agg.result"}
!11 = distinct !{!11, !"_ZN2cl4sycl6detail7Builder7getItemILi1ELb1EEENSt9enable_ifIXT0_EKNS0_4itemIXT_EXT0_EEEE4typeEv"}
!12 = distinct !{!12, !13, !"_ZN2cl4sycl6detail7Builder10getElementILi1ELb1EEEDTcl7getItemIXT_EXT0_EEEEPNS0_4itemIXT_EXT0_EEE: %agg.result"}
!13 = distinct !{!13, !"_ZN2cl4sycl6detail7Builder10getElementILi1ELb1EEEDTcl7getItemIXT_EXT0_EEEEPNS0_4itemIXT_EXT0_EEE"}
!14 = !{!15, !15, i64 0}
!15 = !{!"float", !16, i64 0}
!16 = !{!"omnipotent char", !17, i64 0}
!17 = !{!"Simple C++ TBAA"}
!18 = !{!19, !21, !23, !25}
!19 = distinct !{!19, !20, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEE8initSizeEv: %agg.result"}
!20 = distinct !{!20, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEE8initSizeEv"}
!21 = distinct !{!21, !22, !"_ZN7__spirvL22initGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEEET0_v: %agg.result"}
!22 = distinct !{!22, !"_ZN7__spirvL22initGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEEET0_v"}
!23 = distinct !{!23, !24, !"_ZN2cl4sycl6detail7Builder7getItemILi1ELb1EEENSt9enable_ifIXT0_EKNS0_4itemIXT_EXT0_EEEE4typeEv: %agg.result"}
!24 = distinct !{!24, !"_ZN2cl4sycl6detail7Builder7getItemILi1ELb1EEENSt9enable_ifIXT0_EKNS0_4itemIXT_EXT0_EEEE4typeEv"}
!25 = distinct !{!25, !26, !"_ZN2cl4sycl6detail7Builder10getElementILi1ELb1EEEDTcl7getItemIXT_EXT0_EEEEPNS0_4itemIXT_EXT0_EEE: %agg.result"}
!26 = distinct !{!26, !"_ZN2cl4sycl6detail7Builder10getElementILi1ELb1EEEDTcl7getItemIXT_EXT0_EEEEPNS0_4itemIXT_EXT0_EEE"}
!27 = !{!28, !28, i64 0}
!28 = !{!"double", !16, i64 0}
Loading

0 comments on commit d1213fe

Please sign in to comment.