Merge pull request JuliaLang#70 from JuliaStats/dh/wsum2
New implementation of weighted sum
lindahua committed Jun 13, 2014
2 parents ff86193 + 87b8722 commit b976eac
Showing 7 changed files with 338 additions and 106 deletions.
3 changes: 2 additions & 1 deletion REQUIRE
@@ -1 +1,2 @@
julia 0.3-
julia 0.3-
ArrayViews 0.4.6-
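The new ArrayViews requirement exists because the reworked src/weights.jl (below) builds its BLAS fast paths on view and reshape_view. A minimal sketch of those two calls, assuming ArrayViews 0.4.6+ on Julia 0.3 (illustrative only, not part of this commit):

using ArrayViews

A = rand(3, 4, 2)
col = view(A, :, 2, 1)          # contiguous view of one column, no copy
M   = reshape_view(A, (3, 8))   # the same memory presented as a 3x8 matrix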
1 change: 1 addition & 0 deletions runtests.jl
@@ -1,6 +1,7 @@
using StatsBase

tests = ["mathfuns",
"weights",
"means",
"scalarstats",
"counts",
5 changes: 4 additions & 1 deletion src/StatsBase.jl
@@ -1,7 +1,10 @@
module StatsBase
using ArrayViews

import Base: length, isempty, eltype, values, sum, mean, mean!, show, quantile
import Base: rand, rand!
import Base.LinAlg: BlasReal, BlasFloat
import Base.Cartesian: @ngenerate, @nloops, @nref, @nextract

export

15 changes: 0 additions & 15 deletions src/means.jl
@@ -42,18 +42,3 @@ function trimmean(x::RealArray, p::FloatingPoint)
end
end

# Weighted means

function wmean{T<:Number}(v::AbstractArray{T}, w::AbstractArray)
Base.depwarn("wmean is deprecated, use mean(v, weights(w)) instead.", :wmean)
mean(v, weights(w))
end

Base.mean(v::AbstractArray, w::WeightVec) = sum(v, w) / sum(w)

Base.mean!(r::AbstractArray, v::AbstractArray, w::WeightVec, dim::Int) =
scale!(Base.sum!(r, v, w, dim), inv(sum(w)))

Base.mean{T<:Number,W<:Real}(v::AbstractArray{T}, w::WeightVec{W}, dim::Int) =
mean!(Array(typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W)), Base.reduced_dims(size(v), dim)), v, w, dim)

219 changes: 180 additions & 39 deletions src/weights.jl
@@ -18,60 +18,201 @@ values(wv::WeightVec) = wv.values
sum(wv::WeightVec) = wv.sum
isempty(wv::WeightVec) = isempty(wv.values)

Base.getindex(wv::WeightVec, i) = getindex(wv.values, i)


##### Weighted sum #####

## weighted sum over vectors

wsum(v::AbstractVector, w::AbstractVector) = dot(v, w)
wsum(v::AbstractArray, w::AbstractVector) = dot(vec(v), w)

# Note: the methods for BitArray and SparseMatrixCSC are to avoid ambiguities
Base.sum(v::BitArray, w::WeightVec) = wsum(v, values(w))
Base.sum(v::SparseMatrixCSC, w::WeightVec) = wsum(v, values(w))
Base.sum(v::AbstractArray, w::WeightVec) = wsum(v, values(w))

## wsum along dimension
#
# Brief explanation of the algorithm:
# ------------------------------------
#
# 1. _wsum! provides the core implementation, which assumes that
# the dimensions of all input arguments are consistent, and no
# dimension checking is performed therein.
#
# wsum and wsum! perform argument checking and call _wsum!
# internally.
#
# 2. _wsum! adopts a Cartesian-based implementation for general
# subtypes of AbstractArray. In particular, a faster routine
# that keeps a local accumulator is used when dim == 1.
#
# The internal function that implements this is _wsum_general!
#
# 3. _wsum! is specialized for the following cases:
# (a) A is a vector: we invoke the vector version wsum above.
# The internal function that implements this is _wsum1!
#
# (b) A is a dense matrix with eltype <: BlasReal: we call gemv!
# The internal function that implements this is _wsum2_blas!
#
# (c) A is a contiguous array with eltype <: BlasReal:
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
# otherwise: decompose A into multiple pages, and apply _wsum2_blas!
# for each
#
# (d) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
#
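# (Illustrative note, not part of the original source: for case (c) with a
# 5 x 6 x 7 x 8 array A and dim == 3, we get m = 5*6 = 30, n = 7, k = 8, so A is
# viewed as a (30, 7, 8) array Av and R as a (30, 8) array Rv; the loop then
# performs one BLAS gemv per page, Rv[:, i] = Av[:, :, i] * w for i = 1:k.)
#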

function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
r = wsum(A, w)
if init
R[1] = r
else
R[1] += r
end
return R
end

function _wsum2_blas!{T<:BlasReal}(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, dim::Int, init::Bool)
beta = ifelse(init, zero(T), one(T))
trans = dim == 1 ? 'T' : 'N'
BLAS.gemv!(trans, one(T), A, w, beta, R)
return R
end
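# (Descriptive note, not in the original source: for a matrix A, dim == 1
# computes R = A'w via gemv('T') and dim == 2 computes R = A*w via gemv('N');
# with init == false, beta = 1, so the previous contents of R are accumulated
# rather than overwritten.)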

function _wsumN!{T<:BlasReal,N}(R::ContiguousArray{T}, A::ContiguousArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool)
if dim == 1
m = size(A, 1)
n = div(length(A), m)
_wsum2_blas!(view(R,:), reshape_view(A, (m, n)), w, 1, init)
elseif dim == N
n = size(A, N)
m = div(length(A), n)
_wsum2_blas!(view(R,:), reshape_view(A, (m, n)), w, 2, init)
else # 1 < dim < N
m = 1
for i = 1:dim-1; m *= size(A, i); end
n = size(A, dim)
k = 1
for i = dim+1:N; k *= size(A, i); end
Av = reshape_view(A, (m, n, k))
Rv = reshape_view(R, (m, k))
for i = 1:k
_wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init)
end
end
return R
end

function _wsumN!{T<:BlasReal,N}(R::ContiguousArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool)
@assert N >= 3
if dim <= 2
m = size(A, 1)
n = size(A, 2)
npages = 1
for i = 3:N
npages *= size(A, i)
end
rlen = ifelse(dim == 1, n, m)
Rv = reshape_view(R, (rlen, npages))
for i = 1:npages
_wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init)
end
else
_wsum_general!(R, A, w, dim, init)
end
return R
end

# General Cartesian-based weighted sum across dimensions
@ngenerate N typeof(R) function _wsum_general!{T,RT,WT,N}(R::AbstractArray{RT},
A::AbstractArray{T,N}, w::AbstractVector{WT}, dim::Int, init::Bool)
init && fill!(R, zero(RT))
wi = zero(WT)
if dim == 1
@nextract N sizeR d->size(R,d)
sizA1 = size(A, 1)
@nloops N i d->(d>1? (1:size(A,d)) : (1:1)) d->(j_d = sizeR_d==1 ? 1 : i_d) begin
@inbounds r = (@nref N R j)
for i_1 = 1:sizA1
@inbounds r += (@nref N A i) * w[i_1]
end
@inbounds (@nref N R j) = r
end
else
@nloops N i A d->(if d == dim
wi = w[i_d]
j_d = 1
else
j_d = i_d
end) @inbounds (@nref N R j) += (@nref N A i) * wi
end
return R
end


# N = 1
_wsum!{T<:BlasReal}(R::ContiguousArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) =
_wsum1!(R, A, w, init)

# N = 2
_wsum!{T<:BlasReal}(R::ContiguousArray{T}, A::DenseArray{T,2}, w::StridedVector{T}, dim::Int, init::Bool) =
(_wsum2_blas!(view(R,:), A, w, dim, init); R)

# N >= 3
_wsum!{T<:BlasReal,N}(R::ContiguousArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) =
_wsumN!(R, A, w, dim, init)

_wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool) = _wsum_general!(R, A, w, dim, init)

## wsum! and wsum

wsumtype{T,W}(::Type{T}, ::Type{W}) = typeof(zero(T) * zero(W) + zero(T) * zero(W))
wsumtype{T<:BlasReal}(::Type{T}, ::Type{T}) = T
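# (Illustrative, not in the original source: the generic method promotes the
# element types, e.g. wsumtype(Int, Float64) == Float64, while the BlasReal
# method short-circuits to T when value and weight types coincide, e.g.
# wsumtype(Float32, Float32) == Float32.)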

function wsum!{T,N}(R::AbstractArray, A::AbstractArray{T,N}, w::AbstractVector, dim::Int; init::Bool=true)
1 <= dim <= N || error("dim should be within [1, $N]")
ndims(R) <= N || error("ndims(R) should not exceed $N")
length(w) == size(A,dim) || throw(DimensionMismatch("Inconsistent array dimension."))
# TODO: more careful examination of R's size
_wsum!(R, A, w, dim, init)
end

function wsum{T<:Number,W<:Real}(A::AbstractArray{T}, w::AbstractVector{W}, dim::Int)
length(w) == size(A,dim) || throw(DimensionMismatch("Inconsistent array dimension."))
_wsum!(Array(wsumtype(T,W), Base.reduced_dims(size(A), dim)), A, w, dim, true)
end

# extended sum! and sum

Base.sum!{W<:Real}(R::AbstractArray, A::AbstractArray, w::WeightVec{W}, dim::Int; init::Bool=true) =
wsum!(R, A, values(w), dim; init=init)

Base.sum{T<:Number,W<:Real}(A::AbstractArray{T}, w::WeightVec{W}, dim::Int) = wsum(A, values(w), dim)


###### Weighted means #####

function wmean{T<:Number}(v::AbstractArray{T}, w::AbstractVector)
Base.depwarn("wmean is deprecated, use mean(v, weights(w)) instead.", :wmean)
mean(v, weights(w))
end

Base.mean(v::AbstractArray, w::WeightVec) = sum(v, w) / sum(w)

Base.mean!(R::AbstractArray, A::AbstractArray, w::WeightVec, dim::Int) =
scale!(Base.sum!(R, A, w, dim), inv(sum(w)))

wmeantype{T,W}(::Type{T}, ::Type{W}) = typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W))
wmeantype{T<:BlasReal}(::Type{T}, ::Type{T}) = T

Base.mean{T<:Number,W<:Real}(A::AbstractArray{T}, w::WeightVec{W}, dim::Int) =
mean!(Array(wmeantype(T, W), Base.reduced_dims(size(A), dim)), A, w, dim)
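
Taken together, the new entry points are wsum / wsum! plus the extended Base.sum, Base.sum!, Base.mean and Base.mean! methods that accept a WeightVec and a dimension. A short usage sketch consistent with the definitions above and with the tests removed from test/means.jl below (values are straightforward to check by hand):

using StatsBase

A = [1.0 2.0 3.0;
     4.0 5.0 6.0]
w = [1.0, 2.0]

wsum(A, w, 1)                    # [9.0 12.0 15.0]  (column-wise A'w, BLAS gemv fast path)
wsum(A, [1.0, 1.0, 1.0], 2)      # a 2x1 array holding [6.0, 15.0]  (row-wise A*w)
sum(A, weights(w), 1)            # same as wsum(A, w, 1)
mean(A, weights(w), 1)           # [3.0 4.0 5.0]  (weighted sum divided by sum(w) = 3.0)

R = zeros(1, 3)
wsum!(R, A, w, 1)                # fills R in place; pass init=false to accumulate instead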


50 changes: 0 additions & 50 deletions test/means.jl
@@ -18,53 +18,3 @@ using Base.Test
@test_approx_eq trimmean([-100, 2, 3, 7, 200], 0.4) 4.0
@test_approx_eq trimmean([-100, 2, 3, 7, 200], 0.8) 3.0

@test_approx_eq sum([1.0, 2.0, 3.0], weights([1/3, 1/3, 1/3])) 2.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([1.0, 0.0, 0.0])) 1.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 1.0, 0.0])) 2.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 0.0, 1.0])) 3.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.5, 0.0, 0.5])) 2.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.5, 0.5, 0.0])) 1.5
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 0.5, 0.5])) 2.5

@test_approx_eq sum(1:3, weights([1/3, 1/3, 1/3])) 2.0
@test_approx_eq sum(1:3, weights([1.0, 0.0, 0.0])) 1.0
@test_approx_eq sum(1:3, weights([0.0, 1.0, 0.0])) 2.0
@test_approx_eq sum(1:3, weights([0.0, 0.0, 1.0])) 3.0
@test_approx_eq sum(1:3, weights([0.5, 0.0, 0.5])) 2.0
@test_approx_eq sum(1:3, weights([0.5, 0.5, 0.0])) 1.5
@test_approx_eq sum(1:3, weights([0.0, 0.5, 0.5])) 2.5
@test_approx_eq sum(1:3, weights([1.0, 1.0, 0.5])) 4.5
@test_approx_eq mean(1:3, weights([1.0, 1.0, 0.5])) 1.8

a = [1. 2. 3.; 4. 5. 6.]

@test size(mean(a, weights(ones(2)), 1)) == (1, 3)
@test_approx_eq sum(a, weights([1.0, 1.0]), 1) [5.0, 7.0, 9.0]
@test_approx_eq mean(a, weights([1.0, 1.0]), 1) [2.5, 3.5, 4.5]
@test_approx_eq sum(a, weights([1.0, 0.0]), 1) [1.0, 2.0, 3.0]
@test_approx_eq sum(a, weights([0.0, 1.0]), 1) [4.0, 5.0, 6.0]

@test size(mean(a, weights(ones(3)), 2)) == (2, 1)
@test_approx_eq wsum!(zeros(1, 2), a, [1.0, 1.0, 1.0], 2) [6.0 15.0]
@test_approx_eq wsum(a, [1.0, 1.0, 1.0], 2) [6.0 15.0]
@test_approx_eq sum!(zeros(1, 2), a, weights([1.0, 1.0, 1.0]), 2) [6.0 15.0]
@test_approx_eq sum(a, weights([1.0, 1.0, 1.0]), 2) [6.0 15.0]
@test_approx_eq mean(a, weights([1.0, 1.0, 1.0]), 2) [2.0 5.0]
@test_approx_eq sum(a, weights([1.0, 0.0, 0.0]), 2) [1.0 4.0]
@test_approx_eq sum(a, weights([0.0, 0.0, 1.0]), 2) [3.0 6.0]

@test_throws ErrorException mean(a, weights(ones(3)), 3)
@test_throws DimensionMismatch mean(a, weights(ones(2)), 2)
@test_throws DimensionMismatch mean!(ones(1, 1), a, weights(ones(3)), 2)

a = reshape(1.0:27.0, 3, 3, 3)

for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0])
@test_approx_eq sum(a, weights(wt), 1) sum(a.*reshape(wt, length(wt), 1, 1), 1)
@test_approx_eq sum(a, weights(wt), 2) sum(a.*reshape(wt, 1, length(wt), 1), 2)
@test_approx_eq sum(a, weights(wt), 3) sum(a.*reshape(wt, 1, 1, length(wt)), 3)
@test_approx_eq mean(a, weights(wt), 1) sum(a.*reshape(wt, length(wt), 1, 1), 1)/sum(wt)
@test_approx_eq mean(a, weights(wt), 2) sum(a.*reshape(wt, 1, length(wt), 1), 2)/sum(wt)
@test_approx_eq mean(a, weights(wt), 3) sum(a.*reshape(wt, 1, 1, length(wt)), 3)/sum(wt)
@test_throws ErrorException mean(a, weights(wt), 4)
end

