Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: use pairwise summation for sum, cumsum, and cumprod #4039

Merged
merged 3 commits into from
Aug 13, 2013
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
@@ -931,17 +931,28 @@ function (!=)(A::AbstractArray, B::AbstractArray)
return false
end

for (f, op) = ((:cumsum, :+), (:cumprod, :*) )
for (f, fp, op) = ((:cumsum, :cumsum_pairwise, :+),
(:cumprod, :cumprod_pairwise, :*) )
# in-place cumsum of c = s+v(i1:n), using pairwise summation as for sum
@eval function ($fp)(v::AbstractVector, c::AbstractVector, s, i1, n)
if n < 128
@inbounds c[i1] = ($op)(s, v[i1])
for i = i1+1:i1+n-1
@inbounds c[i] = $(op)(c[i-1], v[i])
end
else
n2 = div(n,2)
($fp)(v, c, s, i1, n2)
($fp)(v, c, c[(i1+n2)-1], i1+n2, n-n2)
end
end

@eval function ($f)(v::AbstractVector)
n = length(v)
c = $(op===:+ ? (:(similar(v,typeof(+zero(eltype(v)))))) :
(:(similar(v))))
if n == 0; return c; end

c[1] = v[1]
for i=2:n
c[i] = ($op)(c[i-1], v[i])
end
($fp)(v, c, $(op==:+ ? :(zero(eltype(v))) : :(one(eltype(v)))), 1, n)
return c
end

@@ -1367,17 +1378,37 @@ prod(A::AbstractArray{Bool}) =
prod(A::AbstractArray{Bool}, region) =
error("use all() instead of prod() for boolean arrays")

function sum{T}(A::AbstractArray{T})
if isempty(A)
return zero(T)
end
v = A[1]
for i=2:length(A)
@inbounds v += A[i]
# Pairwise (cascade) summation of A[i1:i1+n-1], which O(log n) error growth
# [vs O(n) for a simple loop] with negligible performance cost if
# the base case is large enough. See, e.g.:
# http://en.wikipedia.org/wiki/Pairwise_summation
# Higham, Nicholas J. (1993), "The accuracy of floating point
# summation", SIAM Journal on Scientific Computing 14 (4): 783–799.
# In fact, the root-mean-square error growth, assuming random roundoff
# errors, is only O(sqrt(log n)), which is nearly indistinguishable from O(1)
# in practice. See:
# Manfred Tasche and Hansmartin Zeuner, Handbook of
# Analytic-Computational Methods in Applied Mathematics (2000).
function sum_pairwise(A::AbstractArray, i1,n)
if n < 128
@inbounds s = A[i1]
for i = i1+1:i1+n-1
@inbounds s += A[i]
end
return s
else
n2 = div(n,2)
return sum_pairwise(A, i1, n2) + sum_pairwise(A, i1+n2, n-n2)
end
v
end

function sum{T}(A::AbstractArray{T})
n = length(A)
n == 0 ? zero(T) : sum_pairwise(A, 1, n)
end

# Kahan (compensated) summation: O(1) error growth, at the expense
# of a considerable increase in computational expense.
function sum_kbn{T<:FloatingPoint}(A::AbstractArray{T})
n = length(A)
if (n == 0)
25 changes: 23 additions & 2 deletions base/reduce.jl
Original file line number Diff line number Diff line change
@@ -151,6 +151,27 @@ function mapreduce(f::Callable, op::Function, v0, itr)
return v
end

# mapreduce for associative operations, using pairwise recursive reduction
# for improved accuracy (see sum_pairwise)
function mr_pairwise(f::Callable, op::Function, A::AbstractArray, i1,n)
if n < 128
@inbounds v = f(A[i1])
for i = i1+1:i1+n-1
@inbounds v = op(v,f(A[i]))
end
return v
else
n2 = div(n,2)
return op(mr_pairwise(f,op,A, i1,n2), mr_pairwise(f,op,A, i1+n2,n-n2))
end
end
function mapreduce_associative(f::Callable, op::Function, A::AbstractArray)
n = length(A)
n == 0 ? op() : mr_pairwise(f,op,A, 1,n)
end
# can't easily do pairwise reduction without random access, so punt:
mapreduce_associative(f::Callable, op::Function, itr) = mapreduce(f, op, itr)

function any(itr)
for x in itr
if x
@@ -171,8 +192,8 @@ end

max(f::Function, itr) = mapreduce(f, max, itr)
min(f::Function, itr) = mapreduce(f, min, itr)
sum(f::Function, itr) = mapreduce(f, + , itr)
prod(f::Function, itr) = mapreduce(f, * , itr)
sum(f::Function, itr) = mapreduce_associative(f, + , itr)
prod(f::Function, itr) = mapreduce_associative(f, * , itr)

function count(pred::Function, itr)
s = 0
23 changes: 17 additions & 6 deletions base/statistics.jl
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ function mean(iterable)
end
return total/count
end
mean(v::AbstractArray) = sum(v) / length(v)
mean(v::AbstractArray, region) = sum(v, region) / prod(size(v)[region])

function median!{T<:Real}(v::AbstractVector{T}; checknan::Bool=true)
@@ -28,16 +29,26 @@ end
median{T<:Real}(v::AbstractArray{T}; checknan::Bool=true) =
median!(vec(copy(v)), checknan=checknan)

## variance with known mean
function varm(v::AbstractVector, m::Number)
## variance with known mean, using pairwise summation
function varm_pairwise(A::AbstractArray, m, i1,n) # see sum_pairwise
if n < 128
@inbounds s = abs2(A[i1] - m)
for i = i1+1:i1+n-1
@inbounds s += abs2(A[i] - m)
end
return s
else
n2 = div(n,2)
return varm_pairwise(A, m, i1, n2) + varm_pairwise(A, m, i1+n2, n-n2)
end
end
function varm(v::AbstractArray, m::Number)
n = length(v)
if n == 0 || n == 1
return NaN
end
x = v - m
return dot(x, x) / (n - 1)
return varm_pairwise(v, m, 1,n) / (n - 1)
end
varm(v::AbstractArray, m::Number) = varm(vec(v), m)
varm(v::Ranges, m::Number) = var(v)

## variance
@@ -52,7 +63,7 @@ end
var(v::AbstractArray) = varm(v, mean(v))
function var(v::AbstractArray, region)
x = v .- mean(v, region)
return sum(x.*x, region) / (prod(size(v)[region]) - 1)
return sum(abs2(x), region) / (prod(size(v)[region]) - 1)
end

## standard deviation with known mean
9 changes: 9 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
@@ -308,6 +308,15 @@ v[2,2,1,1] = 40.0

@test isequal(v,sum(z,(3,4)))

z = rand(10^6)
let es = sum_kbn(z), es2 = sum_kbn(z[1:10^5])
@test (es - sum(z)) < es * 1e-13
cs = cumsum(z)
@test (es - cs[end]) < es * 1e-13
@test (es2 - cs[10^5]) < es2 * 1e-13
end
@test sum(sin(z)) == sum(sin, z)

## large matrices transpose ##

for i = 1 : 3
4 changes: 4 additions & 0 deletions test/statistics.jl
Original file line number Diff line number Diff line change
@@ -32,6 +32,10 @@
@test all(hist([1:100]/100,0.0:0.01:1.0)[2] .==1)
@test hist([1,1,1,1,1])[2][1] == 5

A = Complex128[exp(i*im) for i in 1:10^4]
@test_approx_eq varm(A,0.) sum(map(abs2,A))/(length(A)-1)
@test_approx_eq varm(A,mean(A)) var(A,1)

@test midpoints(1.0:1.0:10.0) == 1.5:1.0:9.5
@test midpoints(1:10) == 1.5:9.5
@test midpoints(Float64[1.0:1.0:10.0]) == Float64[1.5:1.0:9.5]