diff --git a/base/reinterpretarray.jl b/base/reinterpretarray.jl index 1072dad1ed81d..c5b2f0db735c9 100644 --- a/base/reinterpretarray.jl +++ b/base/reinterpretarray.jl @@ -175,6 +175,8 @@ function _checked_strides(stp, N) map(first, drs) end +_checkcontiguous(::Type{Bool}, A::ReinterpretArray) = _checkcontiguous(parent(A)) + similar(a::ReinterpretArray, T::Type, d::Dims) = similar(a.parent, T, d) function check_readable(a::ReinterpretArray{T, N, S} where N) where {T,S} diff --git a/base/reshapedarray.jl b/base/reshapedarray.jl index cabe3c9d10a58..26f07972da951 100644 --- a/base/reshapedarray.jl +++ b/base/reshapedarray.jl @@ -292,3 +292,16 @@ substrides(strds::NTuple{N,Int}, I::Tuple{ReshapedUnitRange, Vararg{Any}}) where (size_to_strides(strds[1], size(I[1])...)..., substrides(tail(strds), tail(I))...) unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex,ReshapedUnitRange}}}}) where {T,N,P} = unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T) + + +_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A) +_checkcontiguous(::Type{Bool}, A::DenseArray) = true +_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A)) +_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A)) + +function strides(a::ReshapedArray) + # We can handle non-contiguous parent if it's a StridedVector + ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...) + _checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous.")) + size_to_strides(1, size(a)...) +end diff --git a/test/abstractarray.jl b/test/abstractarray.jl index a2f104fd9ec2f..a948af6ba16a3 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1561,3 +1561,10 @@ end r = Base.IdentityUnitRange(3:4) @test reshape(r, :) === reshape(r, (:,)) === r end + +@testset "strides for ReshapedArray" begin + a = reshape(view(collect(1:100), 1:2:100), 5, 10) + @test strides(a) == (2, 10) + a = reshape(view(collect(1:100), 1:2:100, 1:1), 5, 10) + @test_throws ArgumentError strides(a) +end