Skip to content

Commit

Permalink
Merge pull request #659 from mateuszbaran/mbaran/faster-reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
c42f authored Sep 26, 2019
2 parents 8b77542 + d3623b9 commit ae78e52
Showing 1 changed file with 38 additions and 23 deletions.
61 changes: 38 additions & 23 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,23 @@ end
end
end

@inline _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S} =
_mapreduce(f, op, Val(D), nt, sz, a)
@inline function _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S}
# Body of this function is split because constant propagation (at least
# as of Julia 1.2) can't always correctly propagate here and
# as a result the function is not type stable and very slow.
# This makes it at least fast for three dimensions but people should use
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
if D == 1
return _mapreduce(f, op, Val(1), nt, sz, a)
elseif D == 2
return _mapreduce(f, op, Val(2), nt, sz, a)
elseif D == 3
return _mapreduce(f, op, Val(3), nt, sz, a)
else
return _mapreduce(f, op, Val(D), nt, sz, a)
end
end


@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Expand Down Expand Up @@ -161,7 +174,9 @@ end
## reduce ##
############

@inline reduce(op, a::StaticArray; kw...) = mapreduce(identity, op, a; kw...)
@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)

@inline _reduce(op, a::StaticArray, dims, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)

#######################
## related functions ##
Expand All @@ -186,38 +201,38 @@ end
# TODO: change to use Base.reduce_empty/Base.reduce_first
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)

@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(+, a; dims=dims)
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims) # avoid ambiguity
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity

@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(*, a; dims=dims)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)

@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(+, a; dims=dims)
@inline count(f, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, +, a; dims=dims)
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, NamedTuple(), Size(a), a)

@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(&, a; dims=dims, init=true) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, &, a; dims=dims, init=true)
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, (init=true,)) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, (init=true,), Size(a), a)

@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(|, a; dims=dims, init=false) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, |, a; dims=dims, init=false) # (benchmarking needed)
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, (init=false,)) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, (init=false,), Size(a), a) # (benchmarking needed)

@inline Base.in(x, a::StaticArray) = mapreduce(==(x), |, a, init=false)
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, (init=false,), Size(a), a)

_mean_denom(a, dims::Colon) = length(a)
_mean_denom(a, dims::Int) = size(a, dims)
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
_mean_denom(a, ::Type{Val{D}}) where {D} = size(a, D)

@inline mean(a::StaticArray; dims=:) = sum(a; dims=dims) / _mean_denom(a,dims)
@inline mean(f::Function, a::StaticArray;dims=:) = sum(f, a; dims=dims) / _mean_denom(a,dims)
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) / _mean_denom(a, dims)

@inline minimum(a::StaticArray; dims=:) = reduce(min, a; dims=dims) # base has mapreduce(idenity, scalarmin, a)
@inline minimum(f::Function, a::StaticArray; dims=:) = mapreduce(f, min, a; dims=dims)
@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(idenity, scalarmin, a)
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, NamedTuple(), Size(a), a)

@inline maximum(a::StaticArray; dims=:) = reduce(max, a; dims=dims) # base has mapreduce(idenity, scalarmax, a)
@inline maximum(f::Function, a::StaticArray; dims=:) = mapreduce(f, max, a; dims=dims)
@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(idenity, scalarmax, a)
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, NamedTuple(), Size(a), a)

# Diff is slightly different
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)
Expand Down

0 comments on commit ae78e52

Please sign in to comment.