diff --git a/base/reinterpretarray.jl b/base/reinterpretarray.jl index c5b2f0db735c9..ee3002b6eb2a3 100644 --- a/base/reinterpretarray.jl +++ b/base/reinterpretarray.jl @@ -156,6 +156,7 @@ function strides(a::ReshapedReinterpretArray) stp = strides(ap) els == elp && return stp els < elp && return (1, map(Fix2(*, elp ÷ els), stp)...) + stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!")) return _checked_strides(stp, els ÷ elp) end diff --git a/test/reinterpretarray.jl b/test/reinterpretarray.jl index 37b707b8f7508..de246fd3c5844 100644 --- a/test/reinterpretarray.jl +++ b/test/reinterpretarray.jl @@ -157,33 +157,51 @@ let A = collect(reshape(1:20, 5, 4)) @test reshape(R, :) isa StridedArray end +function check_strides(A::AbstractArray) + # Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A)) + dims = ntuple(identity, ndims(A)) + map(i -> stride(A, i), dims) == strides(A) || return false + # Test strides via value check. + for i in eachindex(IndexLinear(), A) + A[i] === Base.unsafe_load(pointer(A, i)) || return false + end + return true +end + @testset "strides for NonReshapedReinterpretArray" begin - A = Matrix{Int16}(reshape(1:80, 20, 4)) - for (T, st) in ((Int8, (1, 40)), (Int32, (1, 10))) - R = reinterpret(T, view(A, :, 1:2)) - @test (stride(R, 1), stride(R, 2)) == strides(R) == st - R = reinterpret(T, view(A, 1:18, :)) - @test (stride(R, 1), stride(R, 2)) == strides(R) == st + A = Array{Int32}(reshape(1:72, 9, 8)) + for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1) + # dim1 is contiguous + for T in (Int16, Float32) + @test check_strides(reinterpret(T, view(A, 1:8, viewax2))) + end + if mod(step(viewax2), 2) == 0 + @test check_strides(reinterpret(Int64, view(A, 1:8, viewax2))) + else + @test_throws "Parent's strides" strides(reinterpret(Int64, view(A, 1:8, viewax2))) + end + # dim1 is not contiguous + for T in (Int16, Int64) + @test_throws "Parent must" strides(reinterpret(T, view(A, 8:-1:1, viewax2))) + end + @test check_strides(reinterpret(Float32, view(A, 8:-1:1, viewax2))) end - A = Matrix{Int16}(reshape(1:76, 19, 4)) - R = reinterpret(Int8, view(A, 1:18, :)) - @test (stride(R, 1), stride(R, 2)) == strides(R) == (1, 38) - R = reinterpret(Int32, view(A, 1:18, :)) - @test_throws ArgumentError strides(R) - R = reinterpret(Int8, view(A, 18:-1:1, :)) - @test_throws ArgumentError strides(R) - R = reinterpret(Int8, view(A, 1:2:18, :)) - @test_throws ArgumentError strides(R) end @testset "strides for ReshapedReinterpretArray" begin - A = Matrix{Int16}(reshape(1:12, 3, 4)) - R = reinterpret(reshape, Int8, view(A, 1:2, 1:2)) - @test (stride(R, 1), stride(R, 2), stride(R, 3)) == strides(R) == (1, 2, 6) - R = reinterpret(reshape, NTuple{3,Int16}, view(A, 1:3, 1:2)) - @test (stride(R, 1),) == strides(R) == (1,) - R = reinterpret(reshape, Int32, view(A, 1:2, 1:2)) - @test_throws ArgumentError strides(R) + A = Array{Int32}(reshape(1:192, 3, 8, 8)) + for viewax1 in (1:8, 1:2:8, 8:-1:1, 8:-2:1), viewax2 in (1:2, 4:-1:1) + for T in (Int16, Float32) + @test check_strides(reinterpret(reshape, T, view(A, 1:2, viewax1, viewax2))) + @test check_strides(reinterpret(reshape, T, view(A, 1:2:3, viewax1, viewax2))) + end + if mod(step(viewax1), 2) == 0 + @test check_strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2))) + else + @test_throws "Parent's strides" strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2))) + end + @test_throws "Parent must" strides(reinterpret(reshape, Int64, view(A, 1:2:3, viewax1, viewax2))) + end end @testset "strides" begin