Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Fast implementation of reduction along dims #5294

Merged
merged 12 commits into from
Jan 5, 2014
94 changes: 24 additions & 70 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1409,34 +1409,9 @@ function nnz{T}(a::AbstractArray{T})
return n
end

# for reductions that expand 0 dims to 1
reduced_dims(A, region) = ntuple(ndims(A), i->(in(i, region) ? 1 :
size(A,i)))

# keep 0 dims in place
reduced_dims0(A, region) = ntuple(ndims(A), i->(size(A,i)==0 ? 0 :
in(i, region) ? 1 :
size(A,i)))

reducedim(f::Function, A, region, v0) =
reducedim(f, A, region, v0, similar(A, reduced_dims(A, region)))

maximum{T}(A::AbstractArray{T}, region) =
isempty(A) ? similar(A,reduced_dims0(A,region)) : reducedim(scalarmax,A,region,typemin(T))
minimum{T}(A::AbstractArray{T}, region) =
isempty(A) ? similar(A,reduced_dims0(A,region)) : reducedim(scalarmin,A,region,typemax(T))
sum{T}(A::AbstractArray{T}, region) = reducedim(+,A,region,zero(T))
prod{T}(A::AbstractArray{T}, region) = reducedim(*,A,region,one(T))

all(A::AbstractArray{Bool}, region) = reducedim(&,A,region,true)
any(A::AbstractArray{Bool}, region) = reducedim(|,A,region,false)
sum(A::AbstractArray{Bool}, region) = reducedim(+,A,region,0,similar(A,Int,reduced_dims(A,region)))
sum(A::AbstractArray{Bool}) = nnz(A)
prod(A::AbstractArray{Bool}) =
error("use all() instead of prod() for boolean arrays")
prod(A::AbstractArray{Bool}, region) =
error("use all() instead of prod() for boolean arrays")


# a fast implementation of sum in sequential order (from left to right)
function sum_seq{T}(a::AbstractArray{T}, ifirst::Int, ilast::Int)
Expand Down Expand Up @@ -1594,56 +1569,33 @@ function cumsum_kbn{T<:FloatingPoint}(A::AbstractArray{T}, axis::Integer)
return B + C
end

function prod{T}(A::AbstractArray{T})
if isempty(A)
return one(T)
end
v = A[1]
for i=2:length(A)
@inbounds v *= A[i]
end
v
end

function minimum{T<:Real}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
v = A[1]
for i=2:length(A)
@inbounds x = A[i]
if x < v
v = x
end
function prod_rgn{T}(A::AbstractArray{T}, first::Int, last::Int)
if first > last
return one(T)
end
v
end

function maximum{T<:Real}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
v = A[1]
for i=2:length(A)
@inbounds x = A[i]
if x > v
v = x
end
i = first
v = A[i]
while i < last
@inbounds v *= A[i+=1]
end
v
return v
end
prod{T}(A::AbstractArray{T}) = prod_rgn(A, 1, length(A))

# specialized versions for floating-point, which deal with NaNs

function minimum{T<:FloatingPoint}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
n = length(A)
function minimum_rgn{T<:Real}(A::AbstractArray{T}, first::Int, last::Int)
if first > last; error("argument range must not be empty"); end

# locate the first non NaN number
v = A[1]
i = 2
while v != v && i <= n
v = A[first]
i = first + 1
while v != v && i <= last
@inbounds v = A[i]
i += 1
end

while i <= n
while i <= last
@inbounds x = A[i]
if x < v
v = x
Expand All @@ -1653,19 +1605,18 @@ function minimum{T<:FloatingPoint}(A::AbstractArray{T})
v
end

function maximum{T<:FloatingPoint}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
n = length(A)
function maximum_rgn{T<:Real}(A::AbstractArray{T}, first::Int, last::Int)
if first > last; error("argument range must not be empty"); end

# locate the first non NaN number
v = A[1]
i = 2
while v != v && i <= n
v = A[first]
i = first + 1
while v != v && i <= last
@inbounds v = A[i]
i += 1
end

while i <= n
while i <= last
@inbounds x = A[i]
if x > v
v = x
Expand All @@ -1675,6 +1626,9 @@ function maximum{T<:FloatingPoint}(A::AbstractArray{T})
v
end

minimum{T<:Real}(A::AbstractArray{T}) = minimum_rgn(A, 1, length(A))
maximum{T<:Real}(A::AbstractArray{T}) = maximum_rgn(A, 1, length(A))

# extrema

function extrema{T<:Real}(A::AbstractArray{T})
Expand Down
76 changes: 0 additions & 76 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1401,82 +1401,6 @@ function indcopy(sz::Dims, I::(RangeIndex...))
end


## Reductions ##

# TODO:
# - find out why inner loop with dimsA[i] instead of size(A,i) is way too slow

let reducedim_cache = nothing
# generate the body of the N-d loop to compute a reduction
function gen_reducedim_func(n, f)
ivars = { symbol(string("i",i)) for i=1:n }
# limits and vars for reduction loop
lo = { symbol(string("lo",i)) for i=1:n }
hi = { symbol(string("hi",i)) for i=1:n }
rvars = { symbol(string("r",i)) for i=1:n }
setlims = { quote
# each dim of reduction is either 1:sizeA or ivar:ivar
if in($i,region)
$(lo[i]) = 1
$(hi[i]) = size(A,$i)
else
$(lo[i]) = $(hi[i]) = $(ivars[i])
end
end for i=1:n }
rranges = { :( $(lo[i]):$(hi[i]) ) for i=1:n } # lo:hi for all dims
body =
quote
_tot = v0
$(setlims...)
$(make_loop_nest(rvars, rranges,
:(_tot = ($f)(_tot, A[$(rvars...)]))))
R[_ind] = _tot
_ind += 1
end
quote
local _F_
function _F_(f, A, region, R, v0)
_ind = 1
$(make_loop_nest(ivars, { :(1:size(R,$i)) for i=1:n }, body))
end
_F_
end
end

global reducedim
function reducedim(f::Function, A, region, v0, R)
ndimsA = ndims(A)

if is(reducedim_cache,nothing)
reducedim_cache = Dict()
end

key = ndimsA
fname = :f

if (is(f,+) && (fname=:+;true)) ||
(is(f,*) && (fname=:*;true)) ||
(is(f,scalarmax) && (fname=:scalarmax;true)) ||
(is(f,scalarmin) && (fname=:scalarmin;true)) ||
(is(f,&) && (fname=:&;true)) ||
(is(f,|) && (fname=:|;true))
key = (fname, ndimsA)
end

if !haskey(reducedim_cache,key)
fexpr = gen_reducedim_func(ndimsA, fname)
func = eval(fexpr)
reducedim_cache[key] = func
else
func = reducedim_cache[key]
end

func(f, A, region, R, v0)

return R
end
end

## Filter ##

# given a function returning a boolean and an array, return matching elements
Expand Down
Loading