Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support WrapperGPUArray nd indexing by fusing vectorized fallback. #512

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/host/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ end
struct ToGPU
array::AbstractGPUArray
end
ToGPU(A::WrappedArray) = ToGPU(parent(A))
function Adapt.adapt_storage(to::ToGPU, xs::Array)
arr = similar(to.array, eltype(xs), size(xs))
copyto!(arr, xs)
Expand Down
33 changes: 27 additions & 6 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,24 @@ end

## vectorized indexing

function vectorized_getindex(src::AbstractGPUArray, Is...)
shape = Base.index_shape(Is...)
dest = similar(src, shape)
function vectorized_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is...)
any(isempty, Is) && return dest # indexing with empty array
idims = map(length, Is)

# NOTE: we are pretty liberal here supporting non-GPU indices...
Is = map(x->adapt(ToGPU(src), x), Is)
Is = map(adapt(ToGPU(dest)), Is)
@boundscheck checkbounds(src, Is...)

gpu_call(getindex_kernel, dest, src, idims, Is...)
return dest
end

function vectorized_getindex(src::AbstractGPUArray, Is...)
shape = Base.index_shape(Is...)
dest = similar(src, shape)
return vectorized_getindex!(dest, src, Is...)
end

@generated function getindex_kernel(ctx::AbstractKernelContext, dest, src, idims,
Is::Vararg{Any,N}) where {N}
quote
Expand All @@ -87,7 +91,7 @@ end
end
end

function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
function vectorized_setindex!(dest::AbstractArray, src, Is...)
isempty(Is) && return dest
idims = length.(Is)
len = prod(idims)
Expand All @@ -101,7 +105,7 @@ function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
end

# NOTE: we are pretty liberal here supporting non-GPU indices...
Is = map(x->adapt(ToGPU(dest), x), Is)
Is = map(adapt(ToGPU(dest)), Is)
@boundscheck checkbounds(dest, Is...)

gpu_call(setindex_kernel, dest, adapt(ToGPU(dest), src), idims, len, Is...;
Expand Down Expand Up @@ -144,6 +148,23 @@ end
end)
end

## Vectorized index overloading for `WrappedGPUArray`
# We'd better not to overload `getindex`/`setindex!` directly as otherwise
# the ambiguities from the default scalar fallback become a mess.
# The default `getindex` for `AbstractArray` follows a `similar`-`copyto!` style.
# Thus we only dispatch the `copyto!` part (`Base._unsafe_getindex!`) to our implement.
function Base._unsafe_getindex!(dest::AbstractGPUArray, src::AbstractArray, Is::Vararg{Union{Real, AbstractArray}, N}) where {N}
return vectorized_getindex!(dest, src, Base.ensure_indexable(Is)...)
end
# Similar for `setindex!`, its default fallback is equivalent to `copyto!`.
# We only dispatch the `copyto!` part (`Base._unsafe_setindex!`) to our implement.
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end

# find*

Expand Down
17 changes: 17 additions & 0 deletions test/testsuite/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,23 @@ end
@test compare(a->a[1:1,1:1], AT, a)
@test compare(a->a[1:1,1:1,1:1], AT, a)
end

@testset "getindex for WrapperGPUArray" begin
a = rand(Float32, 5, 5)
@test compare(a->a'[:, 1], AT, a)
@test compare(a->Base.PermutedDimsArray(a, (2, 1))[2:-1:1, 1:2], AT, a)
@test compare(a->LowerTriangular(a)[:], AT, a) broken=(string(AT) in ["MtlArray", "oneArray"])
@test compare(a->Symmetric(a, :U)[a .> 0], AT, a)
end

@testset "setindex! for WrapperGPUArray" for T in eltypes
x = AT(zeros(T, (10, 10)))'
y = AT(rand(T, (5, 5)))
x[2:6, 2:6] = y
@test Array(parent(x)[2:6, 2:6]) == Array(y)'
x[2:6, 2:6] = 1:25
@test Array(parent(x)[2:6, 2:6]) == reshape(1:25, 5, 5)'
end
end

@testsuite "indexing find" (AT, eltypes)->begin
Expand Down
Loading