Skip to content

Commit

Permalink
Generalize strides for ReshapedArray
Browse files Browse the repository at this point in the history
Generalize `strides` for `ReshapedArray`
  • Loading branch information
N5N3 committed Feb 3, 2022
1 parent d8723b0 commit 96a9cd6
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
2 changes: 2 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ function _checked_strides(stp, N)
map(first, drs)
end

_checkcontiguous(::Type{Bool}, A::ReinterpretArray) = _checkcontiguous(parent(A))

similar(a::ReinterpretArray, T::Type, d::Dims) = similar(a.parent, T, d)

function check_readable(a::ReinterpretArray{T, N, S} where N) where {T,S}
Expand Down
13 changes: 13 additions & 0 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,16 @@ substrides(strds::NTuple{N,Int}, I::Tuple{ReshapedUnitRange, Vararg{Any}}) where
(size_to_strides(strds[1], size(I[1])...)..., substrides(tail(strds), tail(I))...)
unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex,ReshapedUnitRange}}}}) where {T,N,P} =
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)


_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
_checkcontiguous(::Type{Bool}, A::DenseArray) = true
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))

function strides(a::ReshapedArray)
# We can handle non-contiguous parent if it's a StridedVector
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
size_to_strides(1, size(a)...)
end
7 changes: 7 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1561,3 +1561,10 @@ end
r = Base.IdentityUnitRange(3:4)
@test reshape(r, :) === reshape(r, (:,)) === r
end

@testset "strides for ReshapedArray" begin
a = reshape(view(collect(1:100), 1:2:100), 5, 10)
@test strides(a) == (2, 10)
a = reshape(view(collect(1:100), 1:2:100, 1:1), 5, 10)
@test_throws ArgumentError strides(a)
end

0 comments on commit 96a9cd6

Please sign in to comment.