diff --git a/base/abstractarray.jl b/base/abstractarray.jl index b55d3c0e773e5..c215b8f8aad7a 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -392,7 +392,11 @@ julia> stride(A,3) 12 ``` """ -stride(A::AbstractArray, k::Integer) = strides(A)[k] +function stride(A::AbstractArray, k::Integer) + st = strides(A) + k ≤ ndims(A) && return st[k] + return sum(st .* size(A)) +end @inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...) size_to_strides(s, d) = (s,) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index 458beea92604e..0fd339c5d05f3 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -199,17 +199,17 @@ convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S, convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent)) # Strides and pointer for transposed strided arrays — but only if the elements are actually stored in memory -Base.strides(A::Adjoint{<:Real, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1)) -Base.strides(A::Transpose{<:Any, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1)) +Base.strides(A::Adjoint{<:Real, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1)) +Base.strides(A::Transpose{<:Any, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1)) # For matrices it's slightly faster to use reverse and avoid calling stride twice -Base.strides(A::Adjoint{<:Real, <:StridedMatrix}) = reverse(strides(A.parent)) -Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent)) +Base.strides(A::Adjoint{<:Real, <:AbstractMatrix}) = reverse(strides(A.parent)) +Base.strides(A::Transpose{<:Any, <:AbstractMatrix}) = reverse(strides(A.parent)) -Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent) -Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent) +Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent) +Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent) -Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:StridedVecOrMat} = Base.elsize(P) -Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:StridedVecOrMat} = Base.elsize(P) +Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P) +Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P) # for vectors, the semantics of the wrapped and unwrapped types differ # so attempt to maintain both the parent and wrapper type insofar as possible diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index 23c6d68cdc997..e34e721675c61 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -458,7 +458,44 @@ Base.setindex!(A::WrappedArray, v, i::Int) = setindex!(A.A, v, i) Base.setindex!(A::WrappedArray{T, N}, v, I::Vararg{Int, N}) where {T, N} = setindex!(A.A, v, I...) Base.unsafe_convert(::Type{Ptr{T}}, A::WrappedArray{T}) where T = Base.unsafe_convert(Ptr{T}, A.A) -Base.stride(A::WrappedArray, i::Int) = stride(A.A, i) +Base.strides(A::WrappedArray) = strides(A.A) + +@testset "strided interface adjtrans" begin + x = WrappedArray([1, 2, 3, 4]) + @test stride(x,1) == 1 + @test stride(x,2) == stride(x,3) == 4 + @test strides(x') == strides(transpose(x)) == (4,1) + @test pointer(x') == pointer(transpose(x)) == pointer(x) + @test_throws BoundsError stride(x,0) + + A = WrappedArray([1 2; 3 4; 5 6]) + @test stride(A,1) == 1 + @test stride(A,2) == 3 + @test stride(A,3) == stride(A,4) >= 6 + @test strides(A') == strides(transpose(A)) == (3,1) + @test pointer(A') == pointer(transpose(A)) == pointer(A) + @test_throws BoundsError stride(A,0) + + y = WrappedArray([1+im, 2, 3, 4]) + @test strides(transpose(y)) == (4,1) + @test pointer(transpose(y)) == pointer(y) + @test_throws MethodError strides(y') + @test_throws ErrorException pointer(y') + + B = WrappedArray([1+im 2; 3 4; 5 6]) + @test strides(transpose(B)) == (3,1) + @test pointer(transpose(B)) == pointer(B) + @test_throws MethodError strides(B') + @test_throws ErrorException pointer(B') + + @test_throws MethodError stride(1:5,0) + @test_throws MethodError stride(1:5,1) + @test_throws MethodError stride(1:5,2) + @test_throws MethodError strides(transpose(1:5)) + @test_throws MethodError strides((1:5)') + @test_throws ErrorException pointer(transpose(1:5)) + @test_throws ErrorException pointer((1:5)') +end @testset "strided interface blas" begin for elty in (Float32, Float64, ComplexF32, ComplexF64) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 837cfe62bf6ae..04fee20085c54 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1082,36 +1082,43 @@ end Ap = Base.PermutedDimsArray(A, perm) At = transpose(A) Aa = adjoint(A) + St = transpose(A) + Sa = adjoint(A) Sp = Base.PermutedDimsArray(S, perm) Ps = Strider{Int, 2}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)]) @test pointer(Ap) == pointer(Sp) == pointer(Ps) == pointer(At) == pointer(Aa) for i in 1:length(Ap) # This is intentionally disabled due to ambiguity - @test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) - @test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) - @test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i] + @test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) == pointer(St, i) == pointer(Sa, i) + @test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) == pointer(St, i) == pointer(Sa, i) + @test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i] == St[i] == Sa[i] end Pv = view(P, idxs[collect(perm)]...) Apv = view(Ap, idxs[collect(perm)]...) Atv = view(At, idxs[collect(perm)]...) Ata = view(Aa, idxs[collect(perm)]...) + Stv = view(St, idxs[collect(perm)]...) + Sta = view(Sa, idxs[collect(perm)]...) Spv = view(Sp, idxs[collect(perm)]...) Pvs = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv)) @test pointer(Apv) == pointer(Spv) == pointer(Pvs) == pointer(Atv) == pointer(Ata) for i in 1:length(Apv) - @test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i) - @test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i] + @test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i) == pointer(Stv, i) == pointer(Sta, i) + @test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i] == Stv[i] == Sta[i] end Vp = permutedims(Av, perm) Avp = Base.PermutedDimsArray(Av, perm) Avt = transpose(Av) Ava = adjoint(Av) + Svt = transpose(Sv) + Sva = adjoint(Sv) Svp = Base.PermutedDimsArray(Sv, perm) @test pointer(Avp) == pointer(Svp) == pointer(Avt) == pointer(Ava) for i in 1:length(Avp) # This is intentionally disabled due to ambiguity - @test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i) - @test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i] + @test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i) == pointer(Svt, i) == pointer(Sva, i) + @test pointer(Avt, i) == pointer(Ava, i) == pointer(Svt, i) == pointer(Sva, i) + @test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i] == Svt[i] == Sva[i] end end end