diff --git a/.gitignore b/.gitignore index 20bee6fd06..b9b34afe62 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,4 @@ lcov.info build/ /lib/**/Manifest.toml -/LocalPreferences.toml /lib/**/LocalPreferences.toml diff --git a/LocalPreferences.toml b/LocalPreferences.toml new file mode 100644 index 0000000000..36e2ff8141 --- /dev/null +++ b/LocalPreferences.toml @@ -0,0 +1,9 @@ +[CUDA_Runtime_jll] +# which CUDA runtime to use; can be set to a supported +# version number or to "local" for a local installation. +#version = "11.8" + +[CUDA] +# whether to use a nonblocking synchronization mechanism, +# making it possible to use cooperative multitasking. +#nonblocking_synchronization = true diff --git a/lib/cudadrv/CUDAdrv.jl b/lib/cudadrv/CUDAdrv.jl index f978a01441..0e4f0c0a92 100644 --- a/lib/cudadrv/CUDAdrv.jl +++ b/lib/cudadrv/CUDAdrv.jl @@ -29,3 +29,6 @@ include("graph.jl") # global state (CUDA.jl's driver wrappers behave like CUDA's runtime library) include("state.jl") + +# support for concurrent programming +include("synchronization.jl") diff --git a/lib/cudadrv/context.jl b/lib/cudadrv/context.jl index d8eec2420e..25a6df9158 100644 --- a/lib/cudadrv/context.jl +++ b/lib/cudadrv/context.jl @@ -299,7 +299,7 @@ associated with the current task. function synchronize(ctx::CuContext) push!(CuContext, ctx) try - nonblocking_synchronize() + device_synchronize() finally pop!(CuContext) end @@ -316,21 +316,9 @@ associated with the current task. On the device, `device_synchronize` acts as a synchronization point for child grids in the context of dynamic parallelism. """ -device_synchronize() = nonblocking_synchronize() +device_synchronize() # XXX: can we put the device docstring in dynamic_parallelism.jl? -@inline function nonblocking_synchronize() - # perform as much of the sync as possible without blocking in CUDA. - # XXX: remove this using a yield callback, or by synchronizing on a dedicated thread?
- nonblocking_synchronize(legacy_stream()) - - # even though the GPU should be idle now, CUDA hooks work to the actual API call. - # see NVIDIA bug #3383169 for more details. - cuCtxSynchronize() - - check_exceptions() -end - ## cache config diff --git a/lib/cudadrv/events.jl b/lib/cudadrv/events.jl index 8c4192d9c8..63b42fce69 100644 --- a/lib/cudadrv/events.jl +++ b/lib/cudadrv/events.jl @@ -49,36 +49,7 @@ record(e::CuEvent, stream::CuStream=stream()) = Waits for an event to complete. """ -function synchronize(e::CuEvent) - # perform as much of the sync as possible without blocking in CUDA. - # XXX: remove this using a yield callback, or by synchronizing on a dedicated thread? - nonblocking_synchronize(e) - - # even though the GPU should be idle now, CUDA hooks work to the actual API call. - # see NVIDIA bug #3383169 for more details. - cuEventSynchronize(e) -end - -@inline function nonblocking_synchronize(e::CuEvent) - # fast path - isdone(e) && return - - # spin (initially without yielding to minimize latency) - spins = 0 - while spins < 256 - if spins < 32 - ccall(:jl_cpu_pause, Cvoid, ()) - # Temporary solution before we have gc transition support in codegen. - ccall(:jl_gc_safepoint, Cvoid, ()) - else - yield() - end - isdone(e) && return - spins += 1 - end - - return -end +synchronize(e::CuEvent) """ isdone(e::CuEvent) diff --git a/lib/cudadrv/state.jl b/lib/cudadrv/state.jl index b7ec0d85cc..b99134596a 100644 --- a/lib/cudadrv/state.jl +++ b/lib/cudadrv/state.jl @@ -444,7 +444,7 @@ an array or a dictionary, use additional locks. 
""" struct PerDevice{T} lock::ReentrantLock - values::LazyInitialized{Vector{Union{Nothing,Tuple{CuContext,T}}}} + values::LazyInitialized{Vector{Union{Nothing,Tuple{CuContext,T}}},Nothing} end function PerDevice{T}() where {T} diff --git a/lib/cudadrv/stream.jl b/lib/cudadrv/stream.jl index 67f455a740..cea1499983 100644 --- a/lib/cudadrv/stream.jl +++ b/lib/cudadrv/stream.jl @@ -120,75 +120,7 @@ associated with the current Julia task. See also: [`device_synchronize`](@ref) """ -function synchronize(stream::CuStream=stream(); blocking=nothing) - if blocking !== nothing - Base.depwarn("the blocking keyword to synchronize() has been deprecated", :synchronize) - end - - # perform as much of the sync as possible without blocking in CUDA. - # XXX: remove this using a yield callback, or by synchronizing on a dedicated stream? - nonblocking_synchronize(stream) - - # even though the GPU should be idle now, CUDA hooks work to the actual API call. - # see NVIDIA bug #3383169 for more details. - cuStreamSynchronize(stream) - - check_exceptions() -end - -@inline function nonblocking_synchronize(stream::CuStream) - # fast path - isdone(stream) && return - - # minimize latency of short operations by busy-waiting, - # initially without even yielding to other tasks - spins = 0 - while spins < 256 - if spins < 32 - ccall(:jl_cpu_pause, Cvoid, ()) - # Temporary solution before we have gc transition support in codegen. 
- ccall(:jl_gc_safepoint, Cvoid, ()) - else - yield() - end - isdone(stream) && return - spins += 1 - end - - # minimize CPU usage of long-running kernels by waiting for an event signalled by CUDA - event = Base.Event() - launch(; stream) do - notify(event) - end - # if an error occurs, the callback may never fire, so use a timer to detect such cases - dev = device() - timer = Timer(0; interval=1) - Base.@sync begin - Threads.@spawn try - device!(dev) - while true - try - Base.wait(timer) - catch err - err isa EOFError && break - rethrow() - end - if unsafe_cuStreamQuery(stream) != ERROR_NOT_READY - break - end - end - finally - notify(event) - end - - Threads.@spawn begin - Base.wait(event) - close(timer) - end - end - - return -end +synchronize(stream::CuStream=stream()) """ priority_range() diff --git a/lib/cudadrv/synchronization.jl b/lib/cudadrv/synchronization.jl new file mode 100644 index 0000000000..a055a50ae0 --- /dev/null +++ b/lib/cudadrv/synchronization.jl @@ -0,0 +1,268 @@ +# support for nonblocking synchronization + +const use_nonblocking_synchronization = + Preferences.@load_preference("nonblocking_synchronization", true) + + +# +# bidirectional channel +# + +# custom, unbuffered channel that supports returning a value to the sender +# without the need for a second channel +struct BidirectionalChannel{T} <: AbstractChannel{T} + cond_take::Threads.Condition # waiting for data to become available + cond_put::Threads.Condition # waiting for a writeable slot + cond_ret::Threads.Condition # waiting for data to be returned + + function BidirectionalChannel{T}() where T + lock = ReentrantLock() + cond_put = Threads.Condition(lock) + cond_take = Threads.Condition(lock) + cond_ret = Threads.Condition(lock) + return new(cond_take, cond_put, cond_ret) + end +end + +Base.put!(c::BidirectionalChannel{T}, v) where T = put!(c, convert(T, v)) +function Base.put!(c::BidirectionalChannel{T}, v::T) where T + lock(c) + try + # wait for a slot to be available + while
isempty(c.cond_take) + Base.wait(c.cond_put) + end + + # pass a value to the consumer + notify(c.cond_take, v, false, false) + + # wait for a return value to be produced + Base.wait(c.cond_ret) + finally + unlock(c) + end +end + +function Base.take!(f, c::BidirectionalChannel{T}) where T + lock(c) + try + # notify the producer that we're ready to accept a value + notify(c.cond_put, nothing, false, false) + + # receive a value from the producer + v = Base.wait(c.cond_take)::T + + # return a value to the producer + ret = f(v) + notify(c.cond_ret, ret, false, false) + finally + unlock(c) + end +end + +Base.lock(c::BidirectionalChannel) = lock(c.cond_take) +Base.unlock(c::BidirectionalChannel) = unlock(c.cond_take) + + +# +# nonblocking sync +# + +@static if VERSION >= v"1.9.2" + +# if we support foreign threads, perform the synchronization on a separate thread. + +const MAX_SYNC_THREADS = 4 +const sync_channels = Array{BidirectionalChannel{Any}}(undef, MAX_SYNC_THREADS) +const sync_channel_cursor = Threads.Atomic{UInt32}(1) + +function synchronization_worker(data) + i = Int(data) + chan = sync_channels[i] + + while true + # wait for work + take!(chan) do v + if v isa CuContext + context!(v) + unsafe_cuCtxSynchronize() + elseif v isa CuStream + context!(v.ctx) + unsafe_cuStreamSynchronize(v) + elseif v isa CuEvent + context!(v.ctx) + unsafe_cuEventSynchronize(v) + end + end + end +end + +@noinline function create_synchronization_worker(i) + sync_channels[i] = BidirectionalChannel{Any}() + # should be safe to assign before threads are running; + # any user will just submit work that makes it block + + # we don't know what the size of uv_thread_t is, so reserve enough space + tid = Ref{NTuple{32, UInt8}}(ntuple(i -> 0, 32)) + + cb = @cfunction(synchronization_worker, Cvoid, (Ptr{Cvoid},)) + @ccall uv_thread_create(tid::Ptr{Cvoid}, cb::Ptr{Cvoid}, Ptr{Cvoid}(i)::Ptr{Cvoid})::Int32 + + return +end + +function nonblocking_synchronize(val) + # get the channel of a 
synchronization worker + i = mod1(Threads.atomic_add!(sync_channel_cursor, UInt32(1)), MAX_SYNC_THREADS) + if !isassigned(sync_channels, i) + # TODO: write lock, double check, etc + create_synchronization_worker(i) + end + chan = @inbounds sync_channels[i] + + # submit the object to synchronize + res = put!(chan, val) + # this `put!` blocks until the worker has finished processing and returned a value + # (which is different from regular channels) + if res != SUCCESS + throw_api_error(res) + end + + return +end + +function device_synchronize() + if use_nonblocking_synchronization + nonblocking_synchronize(context()) + else + cuCtxSynchronize() + end + check_exceptions() +end + +function synchronize(stream::CuStream=stream()) + if use_nonblocking_synchronization + if !isdone(stream) + # slow path + nonblocking_synchronize(stream) + end + else + cuStreamSynchronize(stream) + end + check_exceptions() +end + +function synchronize(event::CuEvent) + if use_nonblocking_synchronization + if !isdone(event) + # slow path + nonblocking_synchronize(event) + end + else + cuEventSynchronize(event) + end +end + +else + +# without thread adoption, have CUDA notify an async condition that wakes the libuv loop. +# this is not ideal: stream callbacks are deprecated, and do not fire in case of errors. +# furthermore, they do not trigger CUDA's synchronization hooks (see NVIDIA bug #3383169) +# requiring us to perform the actual API call again after nonblocking synchronization. + +function nonblocking_synchronize(stream::CuStream) + # fast path + isdone(stream) && return + + # minimize latency of short operations by busy-waiting, + # initially without even yielding to other tasks + spins = 0 + while spins < 256 + if spins < 32 + ccall(:jl_cpu_pause, Cvoid, ()) + # Temporary solution before we have gc transition support in codegen.
+ ccall(:jl_gc_safepoint, Cvoid, ()) + else + yield() + end + isdone(stream) && return + spins += 1 + end + + # minimize CPU usage of long-running kernels by waiting for an event signalled by CUDA + event = Base.Event() + launch(; stream) do + notify(event) + end + # if an error occurs, the callback may never fire, so use a timer to detect such cases + dev = device() + timer = Timer(0; interval=1) + Base.@sync begin + Threads.@spawn try + device!(dev) + while true + try + Base.wait(timer) + catch err + err isa EOFError && break + rethrow() + end + if unsafe_cuStreamQuery(stream) != ERROR_NOT_READY + break + end + end + finally + notify(event) + end + + Threads.@spawn begin + Base.wait(event) + close(timer) + end + end + + return +end + +function device_synchronize() + if use_nonblocking_synchronization + nonblocking_synchronize(legacy_stream()) + end + cuCtxSynchronize() + + check_exceptions() +end + +function synchronize(stream::CuStream=stream()) + if use_nonblocking_synchronization + nonblocking_synchronize(stream) + end + cuStreamSynchronize(stream) + + check_exceptions() +end + +function synchronize(e::CuEvent) + if use_nonblocking_synchronization + # fast path + isdone(e) && return + + # spin (initially without yielding to minimize latency) + spins = 0 + while spins < 256 + if spins < 32 + ccall(:jl_cpu_pause, Cvoid, ()) + # Temporary solution before we have gc transition support in codegen. + ccall(:jl_gc_safepoint, Cvoid, ()) + else + yield() + end + isdone(e) && return + spins += 1 + end + end + + cuEventSynchronize(e) +end + +end diff --git a/lib/utils/memoization.jl b/lib/utils/memoization.jl index 8f9f1f73f5..fff4b7fd7f 100644 --- a/lib/utils/memoization.jl +++ b/lib/utils/memoization.jl @@ -40,6 +40,9 @@ macro memoize(ex...) # anything, that entry will be the memoized new_value, or else a dictionary of values. @gensym global_cache + # in the presence of thread adoption, we need to use the maximum thread ID + nthreads = :( VERSION >= v"1.9" ? 
Threads.maxthreadid() : Threads.nthreads() ) + # generate code to access memoized values # (assuming the global_cache can be indexed with the thread ID) if key === nothing @@ -47,7 +50,7 @@ macro memoize(ex...) global_cache_eltyp = :(Union{Nothing,$rettyp}) ex = quote cache = get!($(esc(global_cache))) do - $global_cache_eltyp[nothing for _ in 1:Threads.nthreads()] + $global_cache_eltyp[nothing for _ in 1:$nthreads] end cached_value = @inbounds cache[Threads.threadid()] if cached_value !== nothing @@ -64,7 +67,7 @@ macro memoize(ex...) global_init = :(Union{Nothing,$rettyp}[nothing for _ in 1:$(esc(options[:maxlen]))]) ex = quote cache = get!($(esc(global_cache))) do - $global_cache_eltyp[$global_init for _ in 1:Threads.nthreads()] + $global_cache_eltyp[$global_init for _ in 1:$nthreads] end local_cache = @inbounds begin tid = Threads.threadid() @@ -86,7 +89,7 @@ macro memoize(ex...) global_init = :(Dict{$(key.typ),$rettyp}()) ex = quote cache = get!($(esc(global_cache))) do - $global_cache_eltyp[$global_init for _ in 1:Threads.nthreads()] + $global_cache_eltyp[$global_init for _ in 1:$nthreads] end local_cache = @inbounds begin tid = Threads.threadid() @@ -106,7 +109,9 @@ macro memoize(ex...) # define the per-thread cache @eval __module__ begin - const $global_cache = LazyInitialized{Vector{$(global_cache_eltyp)}}() + const $global_cache = LazyInitialized{Vector{$(global_cache_eltyp)}}() do cache + length(cache) == $nthreads + end end quote diff --git a/lib/utils/threading.jl b/lib/utils/threading.jl index 621d0877c3..4a0c06de80 100644 --- a/lib/utils/threading.jl +++ b/lib/utils/threading.jl @@ -10,26 +10,40 @@ This type is intended for lazy initialization of e.g. global structures, without `__init__`. It is similar to protecting accesses using a lock, but is much cheaper. 
""" -struct LazyInitialized{T} +struct LazyInitialized{T,F} # 0: uninitialized # 1: initializing # 2: initialized guard::Threads.Atomic{Int} value::Base.RefValue{T} - LazyInitialized{T}() where {T} = - new(Threads.Atomic{Int}(0), Ref{T}()) + validator::F end -@inline function Base.get!(constructor, x::LazyInitialized; hook=nothing) +LazyInitialized{T}(validator=nothing) where {T} = + LazyInitialized{T,typeof(validator)}(Threads.Atomic{Int}(0), Ref{T}(), validator) + +@inline function Base.get!(constructor, x::LazyInitialized) while x.guard[] != 2 - initialize!(x, constructor, hook) + initialize!(x, constructor) end assume(isassigned(x.value)) # to get rid of the check - x.value[] + val = x.value[] + + # check if the value is still valid + if x.validator !== nothing && !x.validator(val) + Threads.atomic_cas!(x.guard, 2, 0) + while x.guard[] != 2 + initialize!(x, constructor) + end + assume(isassigned(x.value)) + val = x.value[] + end + + return val end -@noinline function initialize!(x::LazyInitialized{T}, constructor::F1, hook::F2) where {T, F1, F2} +@noinline function initialize!(x::LazyInitialized{T}, constructor::F) where {T, F} status = Threads.atomic_cas!(x.guard, 0, 1) if status == 0 try @@ -39,10 +53,6 @@ end x.guard[] = 0 rethrow() end - - if hook !== nothing - hook() - end else yield() end diff --git a/src/pool.jl b/src/pool.jl index 38c111eb30..45d7ff586a 100644 --- a/src/pool.jl +++ b/src/pool.jl @@ -355,9 +355,6 @@ function retry_reclaim(f, isfailed) phase = 1 while true if is_stream_ordered - # NOTE: the stream-ordered allocator only releases memory on actual API calls, - # and not when our synchronization routines query the relevant streams. - # we do still call our routines to minimize the time we block in libcuda. if phase == 1 synchronize(state.stream) elseif phase == 2