Skip to content

Commit

Permalink
Rework host indexing. (#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Oct 31, 2023
1 parent 5f40711 commit 4c13a99
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 61 deletions.
13 changes: 1 addition & 12 deletions src/host/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,21 +308,10 @@ function Adapt.adapt_storage(to::ToGPU, xs::Array)
arr
end

# we don't really want an array, so don't call `adapt(Array, ...)`,
# but just want GPUArray indices to get downloaded back to the CPU.
# this makes sure we preserve array-like containers, like Base.Slice.
struct BackToCPU end
Adapt.adapt_storage(::BackToCPU, xs::AbstractGPUArray) = convert(Array, xs)

@inline function Base.view(A::AbstractGPUArray, I::Vararg{Any,N}) where {N}
J = to_indices(A, I)
@boundscheck begin
# Base's boundscheck accesses the indices, so make sure they reside on the CPU.
# this is expensive, but it's a bounds check after all.
J_cpu = map(j->adapt(BackToCPU(), j), J)
checkbounds(A, J_cpu...)
end
J_gpu = map(j->adapt(ToGPU(A), j), J)
@boundscheck checkbounds(A, J...)
unsafe_view(A, J_gpu, GPUIndexStyle(I...))
end

Expand Down
147 changes: 98 additions & 49 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,77 @@
# host-level indexing


# basic indexing with integers
# indexing operators

Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear()

function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T
vectorized_indices(Is::Union{Integer,CartesianIndex}...) = Val{false}()
vectorized_indices(Is...) = Val{true}()

# TODO: re-use Base functionality for the conversion of indices to a linear index,
# by only implementing `getindex(A, ::Int)` etc. this is difficult due to
# ambiguities with the vectorized method that can take any index type.

Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, Is...) =
_getindex(vectorized_indices(Is...), A, to_indices(A, Is)...)
Base.@propagate_inbounds _getindex(::Val{false}, A::AbstractGPUArray, Is...) =
scalar_getindex(A, to_indices(A, Is)...)
Base.@propagate_inbounds _getindex(::Val{true}, A::AbstractGPUArray, Is...) =
vectorized_getindex(A, to_indices(A, Is)...)

Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, Is...) =
_setindex!(vectorized_indices(Is...), A, v, to_indices(A, Is)...)
Base.@propagate_inbounds _setindex!(::Val{false}, A::AbstractGPUArray, v, Is...) =
scalar_setindex!(A, v, to_indices(A, Is)...)
Base.@propagate_inbounds _setindex!(::Val{true}, A::AbstractGPUArray, v, Is...) =
vectorized_setindex!(A, v, to_indices(A, Is)...)

## scalar indexing

function scalar_getindex(A::AbstractGPUArray{T}, Is...) where T
@boundscheck checkbounds(A, Is...)
I = Base._to_linear_index(A, Is...)
getindex(A, I)
end

function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T
@boundscheck checkbounds(A, Is...)
I = Base._to_linear_index(A, Is...)
setindex!(A, v, I)
end

# we still dispatch to `Base.getindex(a, ::Int)` etc so that there's a single method to
# override when a back-end (e.g. with unified memory) wants to allow scalar indexing.

function Base.getindex(A::AbstractGPUArray{T}, I::Int) where T
@boundscheck checkbounds(A, I)
assertscalar("getindex")
i = Base._to_linear_index(xs, I...)
x = Array{T}(undef, 1)
copyto!(x, 1, xs, i, 1)
copyto!(x, 1, A, I, 1)
return x[1]
end

function Base.setindex!(xs::AbstractGPUArray{T}, v::T, I::Integer...) where T
function Base.setindex!(A::AbstractGPUArray{T}, v, I::Int) where T
@boundscheck checkbounds(A, I)
assertscalar("setindex!")
i = Base._to_linear_index(xs, I...)
x = T[v]
copyto!(xs, i, x, 1, 1)
return xs
copyto!(A, I, x, 1, 1)
return A
end

Base.setindex!(xs::AbstractGPUArray, v, I::Integer...) =
setindex!(xs, convert(eltype(xs), v), I...)

## vectorized indexing

# basic indexing with cartesian indices

Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, I::Union{Integer, CartesianIndex}...) =
A[Base.to_indices(A, I)...]
Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer, CartesianIndex}...) =
(A[Base.to_indices(A, I)...] = v; A)


# generalized multidimensional indexing

Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...)

function _getindex(src::AbstractGPUArray, Is...)
function vectorized_getindex(src::AbstractGPUArray, Is...)
shape = Base.index_shape(Is...)
dest = similar(src, shape)
any(isempty, Is) && return dest # indexing with empty array
idims = map(length, Is)

AT = typeof(src).name.wrapper
# NOTE: we are pretty liberal here supporting non-GPU indices...
gpu_call(getindex_kernel, dest, src, idims, adapt(AT, Is)...)
Is = map(x->adapt(ToGPU(src), x), Is)
@boundscheck checkbounds(src, Is...)

gpu_call(getindex_kernel, dest, src, idims, Is...)
return dest
end

Expand All @@ -61,9 +87,7 @@ end
end
end

Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...)

function _setindex!(dest::AbstractGPUArray, src, Is...)
function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
isempty(Is) && return dest
idims = length.(Is)
len = prod(idims)
Expand All @@ -76,9 +100,11 @@ function _setindex!(dest::AbstractGPUArray, src, Is...)
end
end

AT = typeof(dest).name.wrapper
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
gpu_call(setindex_kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is)...;
# NOTE: we are pretty liberal here supporting non-GPU indices...
Is = map(x->adapt(ToGPU(dest), x), Is)
@boundscheck checkbounds(dest, Is...)

gpu_call(setindex_kernel, dest, adapt(ToGPU(dest), src), idims, len, Is...;
elements=len)
return dest
end
Expand All @@ -96,7 +122,30 @@ end
end


## find*
# bounds checking

# indices residing on the GPU should be bounds-checked on the GPU to avoid iteration.

# not all wrapped GPU arrays make sense as indices, so we use a subset of `AnyGPUArray`
const IndexGPUArray{T} = Union{AbstractGPUArray{T},
SubArray{T, <:Any, <:AbstractGPUArray},
LinearAlgebra.Adjoint{T}}

@inline function Base.checkindex(::Type{Bool}, inds::AbstractUnitRange, I::IndexGPUArray)
all(broadcast(I) do i
Base.checkindex(Bool, inds, i)
end)
end

@inline function Base.checkindex(::Type{Bool}, inds::Tuple,
I::IndexGPUArray{<:CartesianIndex})
all(broadcast(I) do i
Base.checkbounds_indices(Bool, inds, (i,))
end)
end


# find*

# simple array type that returns the index used to access an element, while
# retaining the dimensionality of the original array. this can be used to
Expand All @@ -107,15 +156,15 @@ struct EachIndex{T,N,IS} <: AbstractArray{T,N}
dims::NTuple{N,Int}
indices::IS
end
EachIndex(xs::AbstractArray) =
EachIndex{typeof(firstindex(xs)), ndims(xs), typeof(eachindex(xs))}(
size(xs), eachindex(xs))
EachIndex(A::AbstractArray) =
EachIndex{typeof(firstindex(A)), ndims(A), typeof(eachindex(A))}(
size(A), eachindex(A))
Base.size(ei::EachIndex) = ei.dims
Base.getindex(ei::EachIndex, i::Int) = ei.indices[i]
Base.IndexStyle(::Type{<:EachIndex}) = Base.IndexLinear()

function Base.findfirst(f::Function, xs::AnyGPUArray)
indices = EachIndex(xs)
function Base.findfirst(f::Function, A::AnyGPUArray)
indices = EachIndex(A)
dummy_index = first(indices)

# given two pairs of (istrue, index), return the one with the smallest index
Expand All @@ -130,23 +179,23 @@ function Base.findfirst(f::Function, xs::AnyGPUArray)
return (false, dummy_index)
end

res = mapreduce((x, y)->(f(x), y), reduction, xs, indices;
res = mapreduce((x, y)->(f(x), y), reduction, A, indices;
init = (false, dummy_index))
if res[1]
# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
ndims(xs) == 1 && return res[2]
return CartesianIndices(xs)[res[2]]
ndims(A) == 1 && return res[2]
return CartesianIndices(A)[res[2]]
else
return nothing
end
end

Base.findfirst(xs::AnyGPUArray{Bool}) = findfirst(identity, xs)
Base.findfirst(A::AnyGPUArray{Bool}) = findfirst(identity, A)

function findminmax(binop, xs::AnyGPUArray; init, dims)
indices = EachIndex(xs)
dummy_index = firstindex(xs)
function findminmax(binop, A::AnyGPUArray; init, dims)
indices = EachIndex(A)
dummy_index = firstindex(A)

function reduction(t1, t2)
(x, i), (y, j) = t1, t2
Expand All @@ -157,16 +206,16 @@ function findminmax(binop, xs::AnyGPUArray; init, dims)
end

if dims == Colon()
res = mapreduce(tuple, reduction, xs, indices; init = (init, dummy_index))
res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index))

# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
return (res[1], ndims(xs) == 1 ? res[2] : CartesianIndices(xs)[res[2]])
return (res[1], ndims(A) == 1 ? res[2] : CartesianIndices(A)[res[2]])
else
res = mapreduce(tuple, reduction, xs, indices;
res = mapreduce(tuple, reduction, A, indices;
init = (init, dummy_index), dims=dims)
vals = map(x->x[1], res)
inds = map(x->ndims(xs) == 1 ? x[2] : CartesianIndices(xs)[x[2]], res)
inds = map(x->ndims(A) == 1 ? x[2] : CartesianIndices(A)[x[2]], res)
return (vals, inds)
end
end
Expand Down
6 changes: 6 additions & 0 deletions test/testsuite/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ end
@test_throws DimensionMismatch x[1:9,1:9,:,:] = y
end

@testset "mismatching axes/indices" begin
a = rand(Float32, 1,1)
@test compare(a->a[1:1], AT, a)
@test compare(a->a[1:1,1:1], AT, a)
@test compare(a->a[1:1,1:1,1:1], AT, a)
end
end

@testsuite "indexing find" (AT, eltypes)->begin
Expand Down

0 comments on commit 4c13a99

Please sign in to comment.