Skip to content

Commit

Permalink
Some vcat / hcat overloads for Number + StaticArray (#768)
Browse files Browse the repository at this point in the history
These extra overloads should be helpful for users though it's not really
possible to fix this in generality until the dispatch mechanism in Base
is improved.

Co-authored-by: Mateusz Baran <mateuszbaran89@gmail.com>
  • Loading branch information
c42f and mateuszbaran authored Oct 23, 2020
1 parent b95d07e commit b6f58ad
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,15 @@ end

#--------------------------------------------------
# Concatenation
@inline vcat(a::StaticVecOrMatLike) = a
@inline vcat(a::StaticMatrixLike) = a
@inline vcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike) = _vcat(Size(a), Size(b), a, b)
@inline vcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike, c::StaticVecOrMatLike...) = vcat(vcat(a,b), vcat(c...))
# A couple of hacky overloads to avoid some vcat surprises.
# We can't really make this work a lot better in general without Base providing
# a dispatch mechanism for output container type.
@inline vcat(a::StaticVector) = a
@inline vcat(a::StaticVector, bs::Number...) = vcat(a, SVector(bs))
@inline vcat(a::Number, b::StaticVector) = vcat(similar_type(b, typeof(a), Size(1))((a,)), b)

@generated function _vcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMatLike, b::StaticVecOrMatLike) where {Sa, Sb}
if Size(Sa)[2] != Size(Sb)[2]
Expand Down Expand Up @@ -261,6 +267,9 @@ end
@inline hcat(a::StaticMatrixLike) = a
@inline hcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike) = _hcat(Size(a), Size(b), a, b)
@inline hcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike, c::StaticVecOrMatLike...) = hcat(hcat(a,b), hcat(c...))
@inline hcat(a::StaticMatrix{1}) = a # disambiguation
@inline hcat(a::StaticMatrix{1}, bs::Number...) = hcat(a, SMatrix{1,length(bs)}(bs))
@inline hcat(a::Number, b::StaticMatrix{1}) = hcat(similar_type(b, typeof(a), Size(1))((a,)), b)

@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMatLike, b::StaticVecOrMatLike) where {Sa, Sb}
if Sa[1] != Sb[1]
Expand Down
12 changes: 12 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,16 @@ end
let A = Transpose(@SMatrix [1 2; 3 4]), B = Adjoint(@SMatrix [5 6; 7 8])
@test @inferred(vcat(A, B)) === SMatrix{4, 2}([Matrix(A); Matrix(B)])
end

# hcat/vcat + mixtures of Number and SVector / SMatrix
@test @inferred(vcat(SA[1,2,3], 4, 5, 6)) === SVector{6}((1,2,3,4,5,6))
@test @inferred(vcat(0, SA[1,2,3])) === SVector{4}((0,1,2,3))
@test @inferred(hcat(SMatrix{1,3}((1,2,3)), 4, 5, 6)) === SMatrix{1,6}((1,2,3,4,5,6))
@test @inferred(hcat(0, SMatrix{1,3}((1,2,3)))) === SMatrix{1,4}((0,1,2,3))
@test @inferred(vcat(MVector((1,2,3)), 4, 5, 6))::MVector == [1,2,3,4,5,6]

@test @inferred(vcat(SA[1,2,3])) === SA[1,2,3]
@test @inferred(vcat(SA[1 2 3])) === SA[1 2 3]
@test @inferred(hcat(SA[1,2,3])) === SMatrix{3,1}(1,2,3)
@test @inferred(hcat(SA[1 2 3])) === SA[1 2 3]
end

0 comments on commit b6f58ad

Please sign in to comment.