Skip to content

Commit

Permalink
Some fix for ReshapedReinterpretArray
Browse files Browse the repository at this point in the history
and add more test
  • Loading branch information
N5N3 committed Feb 3, 2022
1 parent 96a9cd6 commit a262868
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 22 deletions.
1 change: 1 addition & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ function strides(a::ReshapedReinterpretArray)
stp = strides(ap)
els == elp && return stp
els < elp && return (1, map(Fix2(*, elp ÷ els), stp)...)
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
return _checked_strides(stp, els ÷ elp)
end

Expand Down
62 changes: 40 additions & 22 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,33 +157,51 @@ let A = collect(reshape(1:20, 5, 4))
@test reshape(R, :) isa StridedArray
end

function check_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == strides(A) || return false
# Test strides via value check.
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
return true
end

@testset "strides for NonReshapedReinterpretArray" begin
A = Matrix{Int16}(reshape(1:80, 20, 4))
for (T, st) in ((Int8, (1, 40)), (Int32, (1, 10)))
R = reinterpret(T, view(A, :, 1:2))
@test (stride(R, 1), stride(R, 2)) == strides(R) == st
R = reinterpret(T, view(A, 1:18, :))
@test (stride(R, 1), stride(R, 2)) == strides(R) == st
A = Array{Int32}(reshape(1:72, 9, 8))
for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1)
# dim1 is contiguous
for T in (Int16, Float32)
@test check_strides(reinterpret(T, view(A, 1:8, viewax2)))
end
if mod(step(viewax2), 2) == 0
@test check_strides(reinterpret(Int64, view(A, 1:8, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(Int64, view(A, 1:8, viewax2)))
end
# dim1 is not contiguous
for T in (Int16, Int64)
@test_throws "Parent must" strides(reinterpret(T, view(A, 8:-1:1, viewax2)))
end
@test check_strides(reinterpret(Float32, view(A, 8:-1:1, viewax2)))
end
A = Matrix{Int16}(reshape(1:76, 19, 4))
R = reinterpret(Int8, view(A, 1:18, :))
@test (stride(R, 1), stride(R, 2)) == strides(R) == (1, 38)
R = reinterpret(Int32, view(A, 1:18, :))
@test_throws ArgumentError strides(R)
R = reinterpret(Int8, view(A, 18:-1:1, :))
@test_throws ArgumentError strides(R)
R = reinterpret(Int8, view(A, 1:2:18, :))
@test_throws ArgumentError strides(R)
end

@testset "strides for ReshapedReinterpretArray" begin
A = Matrix{Int16}(reshape(1:12, 3, 4))
R = reinterpret(reshape, Int8, view(A, 1:2, 1:2))
@test (stride(R, 1), stride(R, 2), stride(R, 3)) == strides(R) == (1, 2, 6)
R = reinterpret(reshape, NTuple{3,Int16}, view(A, 1:3, 1:2))
@test (stride(R, 1),) == strides(R) == (1,)
R = reinterpret(reshape, Int32, view(A, 1:2, 1:2))
@test_throws ArgumentError strides(R)
A = Array{Int32}(reshape(1:192, 3, 8, 8))
for viewax1 in (1:8, 1:2:8, 8:-1:1, 8:-2:1), viewax2 in (1:2, 4:-1:1)
for T in (Int16, Float32)
@test check_strides(reinterpret(reshape, T, view(A, 1:2, viewax1, viewax2)))
@test check_strides(reinterpret(reshape, T, view(A, 1:2:3, viewax1, viewax2)))
end
if mod(step(viewax1), 2) == 0
@test check_strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2)))
end
@test_throws "Parent must" strides(reinterpret(reshape, Int64, view(A, 1:2:3, viewax1, viewax2)))
end
end

@testset "strides" begin
Expand Down

0 comments on commit a262868

Please sign in to comment.