Skip to content

Commit

Permalink
Generalize indexing and reduce its ambiguity footprint
Browse files Browse the repository at this point in the history
This should support non-scalar static indexing for any dimensionality
(not just dimensions 1-4). By having fewer `getindex` and `setindex!`
methods it should also substantially reduce the likelihood of
ambiguities with other packages.
  • Loading branch information
timholy committed Dec 19, 2018
1 parent da4a1ed commit 93b7ea6
Showing 1 changed file with 23 additions and 92 deletions.
115 changes: 23 additions & 92 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ setindex!(a::StaticArray, value, i::Int) = error("setindex!(::$(typeof(a)), valu

# Note: all indexing behavior defaults to dense, linear indexing

@propagate_inbounds function getindex(a::StaticArray, inds::Int...)
@boundscheck checkbounds(a, inds...)
@propagate_inbounds function getindex(a::StaticArray{<:Tuple,<:Any,N}, inds::Vararg{Int,N}) where N
@boundscheck checkbounds(a, inds...)
_getindex_scalar(Size(a), a, inds...)
end

Expand All @@ -30,8 +30,8 @@ end
end
end

@propagate_inbounds function setindex!(a::StaticArray, value, inds::Int...)
@boundscheck checkbounds(a, inds...)
@propagate_inbounds function setindex!(a::StaticArray{<:Tuple,<:Any,N}, value, inds::Vararg{Int,N}) where N
@boundscheck checkbounds(a, inds...)
_setindex!_scalar(Size(a), a, value, inds...)
end

Expand Down Expand Up @@ -182,46 +182,32 @@ end
## Multidimensional non-scalar indexing ##
###########################################

# getindex

@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Any, Int}, Colon}...)
_getindex(a, index_sizes(Size(a), inds...), inds)
end
# To intercept `A[i1, ...]` where all `i` indexes have static sizes,
# create a wrapper used to mark non-scalar indexing operations.
# We insert this at a point in the dispatch hierarchy where we can intercept any
# `typeof(A)` (specifically, including dynamic arrays) without triggering ambiguities.

# Hard to describe "Union{Int, StaticArray{<:Any, Int}} with at least one StaticArray{<:Any, Int}"
# Here we require the first StaticArray{<:Any, Int} to be within the first four dimensions
@propagate_inbounds function getindex(a::AbstractArray, i1::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, inds...), (i1, inds...))
struct StaticIndexing{I}
ind::I
end
unwrap(i::StaticIndexing) = i.ind

@propagate_inbounds function getindex(a::AbstractArray, i1::Int, i2::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, i2, inds...), (i1, i2, inds...))
function Base.to_indices(A, I::Tuple{Vararg{Union{Integer, CartesianIndex, StaticArray{<:Tuple,Int}}}})
inds = to_indices(A, axes(A), I)
return map(StaticIndexing, inds)
end

@propagate_inbounds function getindex(a::AbstractArray, i1::Int, i2::Int, i3::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, i2, i3, inds...), (i1, i2, i3, inds...))
end

@propagate_inbounds function getindex(a::AbstractArray, i1::Int, i2::Int, i3::Int, i4::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, i2, i3, i4, inds...), (i1, i2, i3, i4, inds...))
end

# Disambuguity methods for the above
@propagate_inbounds function getindex(a::StaticArray, i1::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, inds...), (i1, inds...))
end
# getindex

@propagate_inbounds function getindex(a::StaticArray, i1::Int, i2::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, i2, inds...), (i1, i2, inds...))
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Any, Int}, Colon}...)
_getindex(a, index_sizes(Size(a), inds...), inds)
end

@propagate_inbounds function getindex(a::StaticArray, i1::Int, i2::Int, i3::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, i2, i3, inds...), (i1, i2, i3, inds...))
function Base._getindex(::IndexStyle, A::AbstractArray, i1::StaticIndexing, I::StaticIndexing...)
inds = (unwrap(i1), map(unwrap, I)...)
return StaticArrays._getindex(A, index_sizes(inds...), inds)
end

@propagate_inbounds function getindex(a::StaticArray, i1::Int, i2::Int, i3::Int, i4::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_getindex(a, index_sizes(i1, i2, i3, i4, inds...), (i1, i2, i3, i4, inds...))
end

@generated function _getindex(a::AbstractArray, ind_sizes::Tuple{Vararg{Size}}, inds)
newsize = out_index_size(ind_sizes.parameters...)
Expand Down Expand Up @@ -265,64 +251,9 @@ end
_setindex!(a, value, index_sizes(Size(a), inds...), inds)
end

# Hard to describe "Union{Int, StaticArray{<:Any, Int}} with at least one StaticArray{<:Any, Int}"
# Here we require the first StaticArray{<:Any, Int} to be within the first four dimensions
@propagate_inbounds function setindex!(a::AbstractArray, value, i1::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, inds...), (i1, inds...))
end

@propagate_inbounds function setindex!(a::AbstractArray, value, i1::Int, i2::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, inds...), (i1, i2, inds...))
end

@propagate_inbounds function setindex!(a::AbstractArray, value, i1::Int, i2::Int, i3::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, i3, inds...), (i1, i2, i3, inds...))
end

@propagate_inbounds function setindex!(a::AbstractArray, value, i1::Int, i2::Int, i3::Int, i4::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, i3, i4, inds...), (i1, i2, i3, i4, inds...))
end

# Disambiguity methods for the above
@propagate_inbounds function setindex!(a::StaticArray, value, i1::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, inds...), (i1, inds...))
end

@propagate_inbounds function setindex!(a::StaticArray, value, i1::Int, i2::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, inds...), (i1, i2, inds...))
end

@propagate_inbounds function setindex!(a::StaticArray, value, i1::Int, i2::Int, i3::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, i3, inds...), (i1, i2, i3, inds...))
end

@propagate_inbounds function setindex!(a::StaticArray, value, i1::Int, i2::Int, i3::Int, i4::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, i3, i4, inds...), (i1, i2, i3, i4, inds...))
end

# disambiguities from Base
@propagate_inbounds function setindex!(a::Array, value, i1::StaticVector{<:Any, Int})
_setindex!(a, value, index_sizes(i1), (i1,))
end

@propagate_inbounds function setindex!(a::Array, value::AbstractArray, i1::StaticVector{<:Any, Int})
_setindex!(a, value, index_sizes(i1), (i1,))
end

@propagate_inbounds function setindex!(a::Array, value::AbstractArray, i1::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, inds...), (i1, inds...))
end

@propagate_inbounds function setindex!(a::Array, value::AbstractArray, i1::Int, i2::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, inds...), (i1, i2, inds...))
end

@propagate_inbounds function setindex!(a::Array, value::AbstractArray, i1::Int, i2::Int, i3::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, i3, inds...), (i1, i2, i3, inds...))
end

@propagate_inbounds function setindex!(a::Array, value::AbstractArray, i1::Int, i2::Int, i3::Int, i4::StaticArray{<:Any, Int}, inds::Union{Int, StaticArray{<:Any, Int}}...)
_setindex!(a, value, index_sizes(i1, i2, i3, i4, inds...), (i1, i2, i3, i4, inds...))
function Base._setindex!(::IndexStyle, a::AbstractArray, value, i1::StaticIndexing, I::StaticIndexing...)
inds = (unwrap(i1), map(unwrap, I)...)
return StaticArrays._setindex!(a, value, index_sizes(inds...), inds)
end

# setindex! from a scalar
Expand Down

0 comments on commit 93b7ea6

Please sign in to comment.