Skip to content

Commit

Permalink
Make host to device memory copies blocking. (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Oct 1, 2024
1 parent 8aca9d7 commit 3621255
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ function Base.copyto!(dest::CLArray{T}, doffs::Int, src::Array{T}, soffs::Int,
@boundscheck checkbounds(src, soffs)
@boundscheck checkbounds(src, soffs+n-1)
unsafe_copyto!(dest, doffs, src, soffs, n)
# device->host copies need to be blocking, because the user will expect the
# values to be available
return dest
end

Expand All @@ -186,7 +188,15 @@ function Base.copyto!(dest::Array{T}, doffs::Int, src::CLArray{T}, soffs::Int,
@boundscheck checkbounds(dest, doffs+n-1)
@boundscheck checkbounds(src, soffs)
@boundscheck checkbounds(src, soffs+n-1)
unsafe_copyto!(dest, doffs, src, soffs, n; blocking=true)
unsafe_copyto!(dest, doffs, src, soffs, n)
# host->device copies need to be blocking, because otherwise the host memory
# can be modified or even freed before the asynchronous copy finishes
#
# TODO: this is bad for performance, so we should probably:
# - expose `blocking=false`/`async=true` to the user, so that
# they can promise the buffer won't be freed or mutated behind our back
# - use a staging buffer to perform a host->host copy first;
# probably only for small buffers.
return dest
end
Base.copyto!(dest::Array{T}, src::CLArray{T}) where {T} =
Expand All @@ -200,7 +210,8 @@ function Base.copyto!(dest::CLArray{T}, doffs::Int, src::CLArray{T}, soffs::Int,
@boundscheck checkbounds(src, soffs)
@boundscheck checkbounds(src, soffs+n-1)
@assert context(dest) == context(src)
unsafe_copyto!(dest, doffs, src, soffs, n)
unsafe_copyto!(dest, doffs, src, soffs, n; blocking=false)
# device->device copies can be asynchronous
return dest
end
Base.copyto!(dest::CLArray{T}, src::CLArray{T}) where {T} =
Expand All @@ -210,7 +221,7 @@ for (srcty, dstty) in [(:Array, :CLArray), (:CLArray, :Array), (:CLArray, :CLArr
@eval begin
function Base.unsafe_copyto!(dst::$dstty{T}, dst_off::Int,
src::$srcty{T}, src_off::Int,
N::Int; blocking::Bool=false) where T
N::Int; blocking::Bool=true) where T
nbytes = N * sizeof(T)
cl.enqueue_svm_memcpy(pointer(dst, dst_off), pointer(src, src_off), nbytes;
blocking)
Expand Down

0 comments on commit 3621255

Please sign in to comment.