Skip to content

Commit

Permalink
Remove unnecessary restriction to StridedVecOrMat (#35929)
Browse files Browse the repository at this point in the history
* Remove unnecessary restriction to `StridedVecOrMat`

The "Strided array interface" https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays-1 means that this is useful beyond these types

* Update adjtrans.jl

* Add tests for adj/trans strides

* Add tests, change strides(::Adjoint{<:Any,<:AbstractVector}) definition

* stride(::AbstractrArray, k) for all k, add ConjPtr

* Remove ConjPtr

* Always throw an error if strides is not implemented

* Update abstractarray.jl

* Update blas.jl

* Remove k < 1 special case

* Also widen elsize to AbstractVecOrMat

* Use strides for dim > ndims

* Update stdlib/LinearAlgebra/test/blas.jl

Co-authored-by: Matt Bauman <mbauman@gmail.com>
(cherry picked from commit 6b2c7f1)
  • Loading branch information
dlfivefifty authored and KristofferC committed Jul 8, 2020
1 parent 24f033c commit 7ad6877
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 17 deletions.
6 changes: 5 additions & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,11 @@ julia> stride(A,3)
12
```
"""
stride(A::AbstractArray, k::Integer) = strides(A)[k]
function stride(A::AbstractArray, k::Integer)
st = strides(A)
k ndims(A) && return st[k]
return sum(st .* size(A))
end

@inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)
size_to_strides(s, d) = (s,)
Expand Down
16 changes: 8 additions & 8 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,17 @@ convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S,
convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent))

# Strides and pointer for transposed strided arrays — but only if the elements are actually stored in memory
Base.strides(A::Adjoint{<:Real, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
Base.strides(A::Transpose{<:Any, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
Base.strides(A::Adjoint{<:Real, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1))
Base.strides(A::Transpose{<:Any, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1))
# For matrices it's slightly faster to use reverse and avoid calling stride twice
Base.strides(A::Adjoint{<:Real, <:StridedMatrix}) = reverse(strides(A.parent))
Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent))
Base.strides(A::Adjoint{<:Real, <:AbstractMatrix}) = reverse(strides(A.parent))
Base.strides(A::Transpose{<:Any, <:AbstractMatrix}) = reverse(strides(A.parent))

Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)

Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)
Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)
Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P)
Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P)

# for vectors, the semantics of the wrapped and unwrapped types differ
# so attempt to maintain both the parent and wrapper type insofar as possible
Expand Down
39 changes: 38 additions & 1 deletion stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,44 @@ Base.setindex!(A::WrappedArray, v, i::Int) = setindex!(A.A, v, i)
Base.setindex!(A::WrappedArray{T, N}, v, I::Vararg{Int, N}) where {T, N} = setindex!(A.A, v, I...)
Base.unsafe_convert(::Type{Ptr{T}}, A::WrappedArray{T}) where T = Base.unsafe_convert(Ptr{T}, A.A)

Base.stride(A::WrappedArray, i::Int) = stride(A.A, i)
Base.strides(A::WrappedArray) = strides(A.A)

@testset "strided interface adjtrans" begin
x = WrappedArray([1, 2, 3, 4])
@test stride(x,1) == 1
@test stride(x,2) == stride(x,3) == 4
@test strides(x') == strides(transpose(x)) == (4,1)
@test pointer(x') == pointer(transpose(x)) == pointer(x)
@test_throws BoundsError stride(x,0)

A = WrappedArray([1 2; 3 4; 5 6])
@test stride(A,1) == 1
@test stride(A,2) == 3
@test stride(A,3) == stride(A,4) >= 6
@test strides(A') == strides(transpose(A)) == (3,1)
@test pointer(A') == pointer(transpose(A)) == pointer(A)
@test_throws BoundsError stride(A,0)

y = WrappedArray([1+im, 2, 3, 4])
@test strides(transpose(y)) == (4,1)
@test pointer(transpose(y)) == pointer(y)
@test_throws MethodError strides(y')
@test_throws ErrorException pointer(y')

B = WrappedArray([1+im 2; 3 4; 5 6])
@test strides(transpose(B)) == (3,1)
@test pointer(transpose(B)) == pointer(B)
@test_throws MethodError strides(B')
@test_throws ErrorException pointer(B')

@test_throws MethodError stride(1:5,0)
@test_throws MethodError stride(1:5,1)
@test_throws MethodError stride(1:5,2)
@test_throws MethodError strides(transpose(1:5))
@test_throws MethodError strides((1:5)')
@test_throws ErrorException pointer(transpose(1:5))
@test_throws ErrorException pointer((1:5)')
end

@testset "strided interface blas" begin
for elty in (Float32, Float64, ComplexF32, ComplexF64)
Expand Down
21 changes: 14 additions & 7 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1082,36 +1082,43 @@ end
Ap = Base.PermutedDimsArray(A, perm)
At = transpose(A)
Aa = adjoint(A)
St = transpose(A)
Sa = adjoint(A)
Sp = Base.PermutedDimsArray(S, perm)
Ps = Strider{Int, 2}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
@test pointer(Ap) == pointer(Sp) == pointer(Ps) == pointer(At) == pointer(Aa)
for i in 1:length(Ap)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
@test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
@test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i]
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) == pointer(St, i) == pointer(Sa, i)
@test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) == pointer(St, i) == pointer(Sa, i)
@test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i] == St[i] == Sa[i]
end
Pv = view(P, idxs[collect(perm)]...)
Apv = view(Ap, idxs[collect(perm)]...)
Atv = view(At, idxs[collect(perm)]...)
Ata = view(Aa, idxs[collect(perm)]...)
Stv = view(St, idxs[collect(perm)]...)
Sta = view(Sa, idxs[collect(perm)]...)
Spv = view(Sp, idxs[collect(perm)]...)
Pvs = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
@test pointer(Apv) == pointer(Spv) == pointer(Pvs) == pointer(Atv) == pointer(Ata)
for i in 1:length(Apv)
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i)
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i]
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i) == pointer(Stv, i) == pointer(Sta, i)
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i] == Stv[i] == Sta[i]
end
Vp = permutedims(Av, perm)
Avp = Base.PermutedDimsArray(Av, perm)
Avt = transpose(Av)
Ava = adjoint(Av)
Svt = transpose(Sv)
Sva = adjoint(Sv)
Svp = Base.PermutedDimsArray(Sv, perm)
@test pointer(Avp) == pointer(Svp) == pointer(Avt) == pointer(Ava)
for i in 1:length(Avp)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i)
@test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i]
@test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i) == pointer(Svt, i) == pointer(Sva, i)
@test pointer(Avt, i) == pointer(Ava, i) == pointer(Svt, i) == pointer(Sva, i)
@test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i] == Svt[i] == Sva[i]
end
end
end

0 comments on commit 7ad6877

Please sign in to comment.