diff --git a/Project.toml b/Project.toml index 3f1e533419..08a8c4b331 100644 --- a/Project.toml +++ b/Project.toml @@ -37,11 +37,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" demumble_jll = "1e29f10c-031c-5a83-9565-69cddfc27673" [weakdeps] +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] +AtomixExt = "Atomix" ChainRulesCoreExt = "ChainRulesCore" EnzymeCoreExt = "EnzymeCore" SpecialFunctionsExt = "SpecialFunctions" @@ -49,6 +51,7 @@ SpecialFunctionsExt = "SpecialFunctions" [compat] AbstractFFTs = "0.4, 0.5, 1.0" Adapt = "4" +Atomix = "0.1, 1" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.2, 0.3, 0.4, 0.5" CUDA_Driver_jll = "0.10" diff --git a/ext/AtomixExt.jl b/ext/AtomixExt.jl new file mode 100644 index 0000000000..57495c7641 --- /dev/null +++ b/ext/AtomixExt.jl @@ -0,0 +1,83 @@ +#= +MIT License + +Copyright (c) 2022 Takafumi Arakaki and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+=#
+
+module AtomixExt
+
+# TODO: respect ordering
+
+using Atomix: Atomix, IndexableRef
+using CUDA: CUDA, CuDeviceArray
+
+# Atomix references whose indexable container is a `CuDeviceArray`.
+const CuIndexableRef{Indexable<:CuDeviceArray} = IndexableRef{Indexable}
+
+# Stub: `Atomix.get` on a `CuIndexableRef` always raises "not implemented".
+function Atomix.get(ref::CuIndexableRef, order)
+    error("not implemented")
+end
+
+# Stub: `Atomix.set!` on a `CuIndexableRef` always raises "not implemented".
+function Atomix.set!(ref::CuIndexableRef, v, order)
+    error("not implemented")
+end
+
+# Compare-and-swap through `CUDA.atomic_cas!` after converting both values to
+# `eltype(ref)`. Returns `(; old, success)`; the ordering arguments are unused.
+@inline function Atomix.replace!(
+    ref::CuIndexableRef, expected, desired, success_ordering, failure_ordering
+)
+    ptr = Atomix.pointer(ref)
+    expected = convert(eltype(ref), expected)
+    desired = convert(eltype(ref), desired)
+    old = CUDA.atomic_cas!(ptr, expected, desired)
+    return (; old = old, success = old === expected)
+end
+
+# Read-modify-write for `+`, `-`, `&`, `|`, `xor`, `min` and `max`, each mapped
+# to the matching `CUDA.atomic_*!` call; any other `op` raises. Returns
+# `old => op(old, x)`, recomputing the new value from `old`; `order` is unused.
+@inline function Atomix.modify!(ref::CuIndexableRef, op::OP, x, order) where {OP}
+    x = convert(eltype(ref), x)
+    ptr = Atomix.pointer(ref)
+    old = if op === (+)
+        CUDA.atomic_add!(ptr, x)
+    elseif op === (-)
+        CUDA.atomic_sub!(ptr, x)
+    elseif op === (&)
+        CUDA.atomic_and!(ptr, x)
+    elseif op === (|)
+        CUDA.atomic_or!(ptr, x)
+    elseif op === xor
+        CUDA.atomic_xor!(ptr, x)
+    elseif op === min
+        CUDA.atomic_min!(ptr, x)
+    elseif op === max
+        CUDA.atomic_max!(ptr, x)
+    else
+        error("not implemented")
+    end
+    return old => op(old, x)
+end
+
+end # module AtomixExt
diff --git a/test/Project.toml b/test/Project.toml
index 5d6ea83e88..728d7a8ed1 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,7 @@
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
 BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc"
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
diff --git a/test/base/atomix.jl b/test/base/atomix.jl
new file mode 100644
index 0000000000..47785d2190
--- /dev/null
+++ b/test/base/atomix.jl
@@ -0,0 +1,95 @@
+#=
+MIT License
+
+Copyright (c) 2022 Takafumi Arakaki and contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+=#
+
+using Atomix
+using Test
+using CUDA
+
+# Launch `f` on the device through `CUDA.@cuda`. `f` is wrapped in a
+# zero-argument function that evaluates to `nothing` regardless of what `f`
+# itself produces.
+function cuda(f)
+    function g()
+        f()
+        nothing
+    end
+    CUDA.@cuda g()
+end
+
+# `Atomix.replace!` on element 1 of a zeroed vector: swapping 0 -> 42 must report
+# success, and a second swap that still expects 0 must fail and report 42.
+@testset "cas" begin
+    # Named slots of `A`. `data` is never read below; the ref is built from the
+    # literal index 1, which is the value `data` carries.
+    idx = (
+        data = 1,
+        cas1_ok = 2,
+        cas2_ok = 3,
+        # ...
+    )
+    # Lowest slot number is at least 1 and the highest equals the slot count.
+    @assert minimum(idx) >= 1
+    @assert maximum(idx) == length(idx)
+
+    A = CUDA.zeros(Int, length(idx))
+    cuda() do
+        GC.@preserve A begin
+            ref = Atomix.IndexableRef(A, (1,))
+            (old, success) = Atomix.replace!(ref, 0, 42)
+            A[idx.cas1_ok] = old == 0 && success
+            (old, success) = Atomix.replace!(ref, 0, 43)
+            A[idx.cas2_ok] = old == 42 && !success
+        end
+    end
+    # Expected: slot 1 keeps 42 from the first swap; both flags were stored as 1.
+    @test collect(A) == [42, 1, 1]
+end
+
+# `Atomix.modify!` and the `Atomix.@atomic` macro, each adding 1 to element 1.
+@testset "inc" begin
+    @testset "core" begin
+        A = CUDA.CuVector(1:3)
+        cuda() do
+            GC.@preserve A begin
+                ref = Atomix.IndexableRef(A, (1,))
+                pre, post = Atomix.modify!(ref, +, 1)
+                # Store both halves of the result so the host can check them.
+                A[2] = pre
+                A[3] = post
+            end
+        end
+        # Expected: element 1 goes 1 -> 2, with `pre` stored as 1 and `post` as 2.
+        @test collect(A) == [2, 1, 2]
+    end
+
+    @testset "sugar" begin
+        A = CUDA.ones(Int, 3)
+        cuda() do
+            GC.@preserve A begin
+                Atomix.@atomic A[begin] += 1
+            end
+        end
+        @test collect(A) == [2, 1, 1]
+    end
+end