Skip to content

Commit

Permalink
Don't use broader parameter type than allowed by StaticArray
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Dec 19, 2018
1 parent 93b7ea6 commit 3972ccd
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/MMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ end
end
end

@inline convert(::Type{MMatrix{S1,S2}}, a::StaticArray{<:Any, T}) where {S1,S2,T} = MMatrix{S1,S2,T}(Tuple(a))
@inline convert(::Type{MMatrix{S1,S2}}, a::StaticArray{<:Tuple, T}) where {S1,S2,T} = MMatrix{S1,S2,T}(Tuple(a))
@inline MMatrix(a::StaticMatrix) = MMatrix{size(typeof(a),1),size(typeof(a),2)}(Tuple(a))

# Simplified show for the type
Expand Down
2 changes: 1 addition & 1 deletion src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end
end
end

@inline convert(::Type{SMatrix{S1,S2}}, a::StaticArray{<:Any, T}) where {S1,S2,T} = SMatrix{S1,S2,T}(Tuple(a))
@inline convert(::Type{SMatrix{S1,S2}}, a::StaticArray{<:Tuple, T}) where {S1,S2,T} = SMatrix{S1,S2,T}(Tuple(a))
@inline SMatrix(a::StaticMatrix) = SMatrix{size(typeof(a),1),size(typeof(a),2)}(Tuple(a))

# Simplified show for the type
Expand Down
2 changes: 1 addition & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ const StaticMatrixLike{T} = Union{
Diagonal{T, <:StaticVector{<:Any, T}}
}
const StaticVecOrMatLike{T} = Union{StaticVector{<:Any, T}, StaticMatrixLike{T}}
const StaticArrayLike{T} = Union{StaticVecOrMatLike{T}, StaticArray{<:Any, T}}
const StaticArrayLike{T} = Union{StaticVecOrMatLike{T}, StaticArray{<:Tuple, T}}

const AbstractScalar{T} = AbstractArray{T, 0} # not exported, but useful none-the-less
const StaticArrayNoEltype{S, N, T} = StaticArray{S, T, N}
Expand Down
4 changes: 1 addition & 3 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ end
end
end

# ambiguity with AbstractRNG and non-Float64... possibly an optimized form in Base?
@inline rand!(rng::MersenneTwister, a::SA) where {SA <: StaticArray{<:Any, Float64}} = _rand!(rng, Size(SA), a)
@inline rand!(rng::MersenneTwister, a::SA) where {SA <: StaticArray{<:Tuple, Float64, <:Any}} = _rand!(rng, Size(SA), a)
@inline rand!(rng::MersenneTwister, a::SA) where {SA <: StaticArray{<:Tuple, Float64}} = _rand!(rng, Size(SA), a)

@inline randn!(rng::AbstractRNG, a::SA) where {SA <: StaticArray} = _randn!(rng, Size(SA), a)
@generated function _randn!(rng::AbstractRNG, ::Size{s}, a::SA) where {s, SA <: StaticArray}
Expand Down
6 changes: 3 additions & 3 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ using Base.Broadcast: _bcsm
# A constructor that changes the style parameter N (array dimension) is also required
struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end
StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:StaticArray{<:Any, <:Any, N}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray{<:Any, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray{<:Any, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
# Precedence rules
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
DefaultArrayStyle(Val(max(M, N)))
Expand Down
2 changes: 1 addition & 1 deletion src/deque.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end
# could also be justified to live in src/indexing.jl
import Base.setindex
@propagate_inbounds setindex(a::StaticArray, x, index::Int) = _setindex(Length(a), a, convert(eltype(typeof(a)), x), index)
@generated function _setindex(::Length{L}, a::StaticArray{<:Any,T}, x::T, index::Int) where {L, T}
@generated function _setindex(::Length{L}, a::StaticArray{<:Tuple,T}, x::T, index::Int) where {L, T}
exprs = [:(ifelse($i == index, x, a[$i])) for i = 1:L]
return quote
@_propagate_inbounds_meta
Expand Down
4 changes: 2 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ end

# getindex

@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Any, Int}, Colon}...)
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, Colon}...)
_getindex(a, index_sizes(Size(a), inds...), inds)
end

Expand Down Expand Up @@ -247,7 +247,7 @@ end

# setindex!

@propagate_inbounds function setindex!(a::StaticArray, value, inds::Union{Int, StaticArray{<:Any, Int}, Colon}...)
@propagate_inbounds function setindex!(a::StaticArray, value, inds::Union{Int, StaticArray{<:Tuple, Int}, Colon}...)
_setindex!(a, value, index_sizes(Size(a), inds...), inds)
end

Expand Down
20 changes: 10 additions & 10 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,23 +176,23 @@ end
# with an initial value v0 = true and false.
#
# TODO: change to use Base.reduce_empty/Base.reduce_first
@inline iszero(a::StaticArray{<:Any,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)

@inline sum(a::StaticArray{<:Any,T}; dims=:) where {T} = reduce(+, a; dims=dims)
@inline sum(f, a::StaticArray{<:Any,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Any,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims) # avoid ambiguity
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(+, a; dims=dims)
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims) # avoid ambiguity

@inline prod(a::StaticArray{<:Any,T}; dims=:) where {T} = reduce(*, a; dims=dims)
@inline prod(f, a::StaticArray{<:Any,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Any,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(*, a; dims=dims)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)

@inline count(a::StaticArray{<:Any,Bool}; dims=:) = reduce(+, a; dims=dims)
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(+, a; dims=dims)
@inline count(f, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, +, a; dims=dims)

@inline all(a::StaticArray{<:Any,Bool}; dims=:) = reduce(&, a; dims=dims, init=true) # non-branching versions
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(&, a; dims=dims, init=true) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, &, a; dims=dims, init=true)

@inline any(a::StaticArray{<:Any,Bool}; dims=:) = reduce(|, a; dims=dims, init=false) # (benchmarking needed)
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(|, a; dims=dims, init=false) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, |, a; dims=dims, init=false) # (benchmarking needed)

_mean_denom(a, dims::Colon) = length(a)
Expand Down
2 changes: 1 addition & 1 deletion src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ end
end
end

@generated function partly_unrolled_multiply(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticArray{<:Any, Tb}) where {sa, sb, Ta, Tb}
@generated function partly_unrolled_multiply(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticArray{<:Tuple, Tb}) where {sa, sb, Ta, Tb}
if sa[2] != sb[1]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
end
Expand Down
8 changes: 4 additions & 4 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ end
end
end

@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::StaticArray{<:Any,TA}, B::UpperTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::UpperTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
m = sa[1]
if length(sa) == 1
n = 1
Expand Down Expand Up @@ -225,7 +225,7 @@ end
end
end

@generated function _A_mul_Bc(::Size{sa}, ::Size{sb}, A::StaticArray{<:Any,TA}, B::UpperTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
@generated function _A_mul_Bc(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::UpperTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
m = sa[1]
if length(sa) == 1
n = 1
Expand Down Expand Up @@ -284,7 +284,7 @@ end
end
end

@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::StaticArray{<:Any,TA}, B::LowerTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::LowerTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
m = sa[1]
if length(sa) == 1
n = 1
Expand Down Expand Up @@ -316,7 +316,7 @@ end
end
end

@generated function _A_mul_Bc(::Size{sa}, ::Size{sb}, A::StaticArray{<:Any,TA}, B::LowerTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
@generated function _A_mul_Bc(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::LowerTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB}
m = sa[1]
if length(sa) == 1
n = 1
Expand Down

0 comments on commit 3972ccd

Please sign in to comment.