Skip to content

Commit

Permalink
Fix pointer to no longer assume contiguity (#36405)
Browse files Browse the repository at this point in the history
* Fix pointer to no longer assume contiguity
  • Loading branch information
mbauman authored Jun 26, 2020
1 parent 29e1454 commit 59b8dde
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 22 deletions.
23 changes: 15 additions & 8 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,14 @@ end
pointer(x::AbstractArray{T}) where {T} = unsafe_convert(Ptr{T}, x)
function pointer(x::AbstractArray{T}, i::Integer) where T
@_inline_meta
unsafe_convert(Ptr{T}, x) + (i - first(LinearIndices(x)))*elsize(x)
unsafe_convert(Ptr{T}, x) + _memory_offset(x, i)
end

# The distance from pointer(x) to the element at x[I...] in bytes
_memory_offset(x::DenseArray, I...) = (_to_linear_index(x, I...) - first(LinearIndices(x)))*elsize(x)
function _memory_offset(x::AbstractArray, I...)
J = _to_subscript_indices(x, I...)
return sum(map((i, s, o)->s*(i-o), J, strides(x), Tuple(first(CartesianIndices(x)))))*elsize(x)
end

## Approach:
Expand Down Expand Up @@ -1078,10 +1085,10 @@ function _getindex(::IndexLinear, A::AbstractArray, I::Vararg{Int,M}) where M
@inbounds r = getindex(A, _to_linear_index(A, I...))
r
end
_to_linear_index(A::AbstractArray, i::Int) = i
_to_linear_index(A::AbstractVector, i::Int, I::Int...) = i
_to_linear_index(A::AbstractArray, i::Integer) = i
_to_linear_index(A::AbstractVector, i::Integer, I::Integer...) = i
_to_linear_index(A::AbstractArray) = 1
_to_linear_index(A::AbstractArray, I::Int...) = (@_inline_meta; _sub2ind(A, I...))
_to_linear_index(A::AbstractArray, I::Integer...) = (@_inline_meta; _sub2ind(A, I...))

## IndexCartesian Scalar indexing: Canonical method is full dimensionality of Ints
function _getindex(::IndexCartesian, A::AbstractArray, I::Vararg{Int,M}) where M
Expand All @@ -1094,12 +1101,12 @@ function _getindex(::IndexCartesian, A::AbstractArray{T,N}, I::Vararg{Int, N}) w
@_propagate_inbounds_meta
getindex(A, I...)
end
_to_subscript_indices(A::AbstractArray, i::Int) = (@_inline_meta; _unsafe_ind2sub(A, i))
_to_subscript_indices(A::AbstractArray, i::Integer) = (@_inline_meta; _unsafe_ind2sub(A, i))
_to_subscript_indices(A::AbstractArray{T,N}) where {T,N} = (@_inline_meta; fill_to_length((), 1, Val(N)))
_to_subscript_indices(A::AbstractArray{T,0}) where {T} = ()
_to_subscript_indices(A::AbstractArray{T,0}, i::Int) where {T} = ()
_to_subscript_indices(A::AbstractArray{T,0}, I::Int...) where {T} = ()
function _to_subscript_indices(A::AbstractArray{T,N}, I::Int...) where {T,N}
_to_subscript_indices(A::AbstractArray{T,0}, i::Integer) where {T} = ()
_to_subscript_indices(A::AbstractArray{T,0}, I::Integer...) where {T} = ()
function _to_subscript_indices(A::AbstractArray{T,N}, I::Integer...) where {T,N}
@_inline_meta
J, Jrem = IteratorsMD.split(I, Val(N))
_to_subscript_indices(A, J, Jrem)
Expand Down
1 change: 1 addition & 0 deletions base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ function Base.strides(A::PermutedDimsArray{T,N,perm}) where {T,N,perm}
s = strides(parent(A))
ntuple(d->s[perm[d]], Val(N))
end
Base.elsize(::Type{<:PermutedDimsArray{<:Any, <:Any, <:Any, <:Any, P}}) where {P} = Base.elsize(P)

@inline function Base.getindex(A::PermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}) where {T,N,perm,iperm}
@boundscheck checkbounds(A, I...)
Expand Down
17 changes: 3 additions & 14 deletions base/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -398,23 +398,12 @@ find_extended_inds(::ScalarIndex, I...) = (@_inline_meta; find_extended_inds(I..
find_extended_inds(i1, I...) = (@_inline_meta; (i1, find_extended_inds(I...)...))
find_extended_inds() = ()

unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P} =
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)
function unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P}
return unsafe_convert(Ptr{T}, V.parent) + _memory_offset(V.parent, map(first, V.indices)...)
end

pointer(V::FastSubArray, i::Int) = pointer(V.parent, V.offset1 + V.stride1*i)
pointer(V::FastContiguousSubArray, i::Int) = pointer(V.parent, V.offset1 + i)
pointer(V::SubArray, i::Int) = _pointer(V, i)
_pointer(V::SubArray{<:Any,1}, i::Int) = pointer(V, (i,))
_pointer(V::SubArray, i::Int) = pointer(V, Base._ind2sub(axes(V), i))

function pointer(V::SubArray{T,N,<:Array,<:Tuple{Vararg{RangeIndex}}}, is::Tuple{Vararg{Int}}) where {T,N}
index = first_index(V)
strds = strides(V)
for d = 1:length(is)
index += (is[d]-1)*strds[d]
end
return pointer(V.parent, index)
end

# indices are taken from the range/vector
# Since bounds-checking is performance-critical and uses
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent))
Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)

Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)
Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)

# for vectors, the semantics of the wrapped and unwrapped types differ
# so attempt to maintain both the parent and wrapper type insofar as possible
similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent))
Expand Down
141 changes: 141 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -978,3 +978,144 @@ end
@test Core.sizeof(arrayOfUInt48) == 24
end
end

struct Strider{T,N} <: AbstractArray{T,N}
data::Vector{T}
offset::Int
strides::NTuple{N,Int}
size::NTuple{N,Int}
end
function Strider{T}(strides::NTuple{N}, size::NTuple{N}) where {T,N}
offset = 1-sum(strides .* (strides .< 0) .* (size .- 1))
data = Array{T}(undef, sum(abs.(strides) .* (size .- 1)) + 1)
return Strider{T, N, Vector{T}}(data, offset, strides, size)
end
function Strider(vec::AbstractArray{T}, strides::NTuple{N}, size::NTuple{N}) where {T,N}
offset = 1-sum(strides .* (strides .< 0) .* (size .- 1))
@assert length(vec) >= sum(abs.(strides) .* (size .- 1)) + 1
return Strider{T, N}(vec, offset, strides, size)
end
Base.size(S::Strider) = S.size
function Base.getindex(S::Strider{<:Any,N}, I::Vararg{Int,N}) where {N}
return S.data[sum(S.strides .* (I .- 1)) + S.offset]
end
Base.strides(S::Strider) = S.strides
Base.elsize(::Type{<:Strider{T}}) where {T} = Base.elsize(Vector{T})
Base.unsafe_convert(::Type{Ptr{T}}, S::Strider{T}) where {T} = pointer(S.data, S.offset)

@testset "Simple 3d strided views and permutes" for sz in ((5, 3, 2), (7, 11, 13))
A = collect(reshape(1:prod(sz), sz))
S = Strider(vec(A), strides(A), sz)
@test pointer(A) == pointer(S)
for i in 1:prod(sz)
@test pointer(A, i) == pointer(S, i)
@test A[i] == S[i]
end
for idxs in ((1:sz[1], 1:sz[2], 1:sz[3]),
(1:sz[1], 2:2:sz[2], sz[3]:-1:1),
(2:2:sz[1]-1, sz[2]:-1:1, sz[3]:-2:2),
(sz[1]:-1:1, sz[2]:-1:1, sz[3]:-1:1),
(sz[1]-1:-3:1, sz[2]:-2:3, 1:sz[3]),)
Ai = A[idxs...]
Av = view(A, idxs...)
Sv = view(S, idxs...)
Ss = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs))
@test pointer(Av) == pointer(Sv) == pointer(Ss)
for i in 1:length(Av)
@test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i)
@test Ai[i] == Av[i] == Sv[i] == Ss[i]
end
for perm in ((3, 2, 1), (2, 1, 3), (3, 1, 2))
P = permutedims(A, perm)
Ap = Base.PermutedDimsArray(A, perm)
Sp = Base.PermutedDimsArray(S, perm)
Ps = Strider{Int, 3}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
@test pointer(Ap) == pointer(Sp) == pointer(Ps)
for i in 1:length(Ap)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i)
@test P[i] == Ap[i] == Sp[i] == Ps[i]
end
Pv = view(P, idxs[collect(perm)]...)
Pi = P[idxs[collect(perm)]...]
Apv = view(Ap, idxs[collect(perm)]...)
Spv = view(Sp, idxs[collect(perm)]...)
Pvs = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
@test pointer(Apv) == pointer(Spv) == pointer(Pvs)
for i in 1:length(Apv)
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i)
@test Pi[i] == Pv[i] == Apv[i] == Spv[i] == Pvs[i]
end
Vp = permutedims(Av, perm)
Ip = permutedims(Ai, perm)
Avp = Base.PermutedDimsArray(Av, perm)
Svp = Base.PermutedDimsArray(Sv, perm)
@test pointer(Avp) == pointer(Svp)
for i in 1:length(Avp)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Avp, i) == pointer(Svp, i)
@test Ip[i] == Vp[i] == Avp[i] == Svp[i]
end
end
end
end

@testset "simple 2d strided views, permutes, transposes" for sz in ((5, 3), (7, 11))
A = collect(reshape(1:prod(sz), sz))
S = Strider(vec(A), strides(A), sz)
@test pointer(A) == pointer(S)
for i in 1:prod(sz)
@test pointer(A, i) == pointer(S, i)
@test A[i] == S[i]
end
for idxs in ((1:sz[1], 1:sz[2]),
(1:sz[1], 2:2:sz[2]),
(2:2:sz[1]-1, sz[2]:-1:1),
(sz[1]:-1:1, sz[2]:-1:1),
(sz[1]-1:-3:1, sz[2]:-2:3),)
Av = view(A, idxs...)
Sv = view(S, idxs...)
Ss = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs))
@test pointer(Av) == pointer(Sv) == pointer(Ss)
for i in 1:length(Av)
@test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i)
@test Av[i] == Sv[i] == Ss[i]
end
perm = (2, 1)
P = permutedims(A, perm)
Ap = Base.PermutedDimsArray(A, perm)
At = transpose(A)
Aa = adjoint(A)
Sp = Base.PermutedDimsArray(S, perm)
Ps = Strider{Int, 2}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
@test pointer(Ap) == pointer(Sp) == pointer(Ps) == pointer(At) == pointer(Aa)
for i in 1:length(Ap)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
@test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
@test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i]
end
Pv = view(P, idxs[collect(perm)]...)
Apv = view(Ap, idxs[collect(perm)]...)
Atv = view(At, idxs[collect(perm)]...)
Ata = view(Aa, idxs[collect(perm)]...)
Spv = view(Sp, idxs[collect(perm)]...)
Pvs = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
@test pointer(Apv) == pointer(Spv) == pointer(Pvs) == pointer(Atv) == pointer(Ata)
for i in 1:length(Apv)
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i)
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i]
end
Vp = permutedims(Av, perm)
Avp = Base.PermutedDimsArray(Av, perm)
Avt = transpose(Av)
Ava = adjoint(Av)
Svp = Base.PermutedDimsArray(Sv, perm)
@test pointer(Avp) == pointer(Svp) == pointer(Avt) == pointer(Ava)
for i in 1:length(Avp)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i)
@test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i]
end
end
end

0 comments on commit 59b8dde

Please sign in to comment.