Skip to content

Commit

Permalink
Merge pull request #5294 from lindahua/dh/reducedim2
Browse files Browse the repository at this point in the history
WIP: Fast implementation of reduction along dims
  • Loading branch information
lindahua committed Jan 5, 2014
2 parents 7155567 + da2d02c commit 85cb30f
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 146 deletions.
94 changes: 24 additions & 70 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1409,34 +1409,9 @@ function nnz{T}(a::AbstractArray{T})
return n
end

# for reductions that expand 0 dims to 1
reduced_dims(A, region) = ntuple(ndims(A), i->(in(i, region) ? 1 :
size(A,i)))

# keep 0 dims in place
reduced_dims0(A, region) = ntuple(ndims(A), i->(size(A,i)==0 ? 0 :
in(i, region) ? 1 :
size(A,i)))

reducedim(f::Function, A, region, v0) =
reducedim(f, A, region, v0, similar(A, reduced_dims(A, region)))

maximum{T}(A::AbstractArray{T}, region) =
isempty(A) ? similar(A,reduced_dims0(A,region)) : reducedim(scalarmax,A,region,typemin(T))
minimum{T}(A::AbstractArray{T}, region) =
isempty(A) ? similar(A,reduced_dims0(A,region)) : reducedim(scalarmin,A,region,typemax(T))
sum{T}(A::AbstractArray{T}, region) = reducedim(+,A,region,zero(T))
prod{T}(A::AbstractArray{T}, region) = reducedim(*,A,region,one(T))

all(A::AbstractArray{Bool}, region) = reducedim(&,A,region,true)
any(A::AbstractArray{Bool}, region) = reducedim(|,A,region,false)
sum(A::AbstractArray{Bool}, region) = reducedim(+,A,region,0,similar(A,Int,reduced_dims(A,region)))
sum(A::AbstractArray{Bool}) = nnz(A)
prod(A::AbstractArray{Bool}) =
error("use all() instead of prod() for boolean arrays")
prod(A::AbstractArray{Bool}, region) =
error("use all() instead of prod() for boolean arrays")


# a fast implementation of sum in sequential order (from left to right)
function sum_seq{T}(a::AbstractArray{T}, ifirst::Int, ilast::Int)
Expand Down Expand Up @@ -1594,56 +1569,33 @@ function cumsum_kbn{T<:FloatingPoint}(A::AbstractArray{T}, axis::Integer)
return B + C
end

function prod{T}(A::AbstractArray{T})
if isempty(A)
return one(T)
end
v = A[1]
for i=2:length(A)
@inbounds v *= A[i]
end
v
end

function minimum{T<:Real}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
v = A[1]
for i=2:length(A)
@inbounds x = A[i]
if x < v
v = x
end
function prod_rgn{T}(A::AbstractArray{T}, first::Int, last::Int)
if first > last
return one(T)
end
v
end

function maximum{T<:Real}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
v = A[1]
for i=2:length(A)
@inbounds x = A[i]
if x > v
v = x
end
i = first
v = A[i]
while i < last
@inbounds v *= A[i+=1]
end
v
return v
end
prod{T}(A::AbstractArray{T}) = prod_rgn(A, 1, length(A))

# specialized versions for floating-point, which deal with NaNs

function minimum{T<:FloatingPoint}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
n = length(A)
function minimum_rgn{T<:Real}(A::AbstractArray{T}, first::Int, last::Int)
if first > last; error("argument range must not be empty"); end

# locate the first non NaN number
v = A[1]
i = 2
while v != v && i <= n
v = A[first]
i = first + 1
while v != v && i <= last
@inbounds v = A[i]
i += 1
end

while i <= n
while i <= last
@inbounds x = A[i]
if x < v
v = x
Expand All @@ -1653,19 +1605,18 @@ function minimum{T<:FloatingPoint}(A::AbstractArray{T})
v
end

function maximum{T<:FloatingPoint}(A::AbstractArray{T})
if isempty(A); error("argument must not be empty"); end
n = length(A)
function maximum_rgn{T<:Real}(A::AbstractArray{T}, first::Int, last::Int)
if first > last; error("argument range must not be empty"); end

# locate the first non NaN number
v = A[1]
i = 2
while v != v && i <= n
v = A[first]
i = first + 1
while v != v && i <= last
@inbounds v = A[i]
i += 1
end

while i <= n
while i <= last
@inbounds x = A[i]
if x > v
v = x
Expand All @@ -1675,6 +1626,9 @@ function maximum{T<:FloatingPoint}(A::AbstractArray{T})
v
end

minimum{T<:Real}(A::AbstractArray{T}) = minimum_rgn(A, 1, length(A))
maximum{T<:Real}(A::AbstractArray{T}) = maximum_rgn(A, 1, length(A))

# extrema

function extrema{T<:Real}(A::AbstractArray{T})
Expand Down
76 changes: 0 additions & 76 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1401,82 +1401,6 @@ function indcopy(sz::Dims, I::(RangeIndex...))
end


## Reductions ##

# TODO:
# - find out why inner loop with dimsA[i] instead of size(A,i) is way too slow

let reducedim_cache = nothing
# generate the body of the N-d loop to compute a reduction
function gen_reducedim_func(n, f)
ivars = { symbol(string("i",i)) for i=1:n }
# limits and vars for reduction loop
lo = { symbol(string("lo",i)) for i=1:n }
hi = { symbol(string("hi",i)) for i=1:n }
rvars = { symbol(string("r",i)) for i=1:n }
setlims = { quote
# each dim of reduction is either 1:sizeA or ivar:ivar
if in($i,region)
$(lo[i]) = 1
$(hi[i]) = size(A,$i)
else
$(lo[i]) = $(hi[i]) = $(ivars[i])
end
end for i=1:n }
rranges = { :( $(lo[i]):$(hi[i]) ) for i=1:n } # lo:hi for all dims
body =
quote
_tot = v0
$(setlims...)
$(make_loop_nest(rvars, rranges,
:(_tot = ($f)(_tot, A[$(rvars...)]))))
R[_ind] = _tot
_ind += 1
end
quote
local _F_
function _F_(f, A, region, R, v0)
_ind = 1
$(make_loop_nest(ivars, { :(1:size(R,$i)) for i=1:n }, body))
end
_F_
end
end

global reducedim
function reducedim(f::Function, A, region, v0, R)
ndimsA = ndims(A)

if is(reducedim_cache,nothing)
reducedim_cache = Dict()
end

key = ndimsA
fname = :f

if (is(f,+) && (fname=:+;true)) ||
(is(f,*) && (fname=:*;true)) ||
(is(f,scalarmax) && (fname=:scalarmax;true)) ||
(is(f,scalarmin) && (fname=:scalarmin;true)) ||
(is(f,&) && (fname=:&;true)) ||
(is(f,|) && (fname=:|;true))
key = (fname, ndimsA)
end

if !haskey(reducedim_cache,key)
fexpr = gen_reducedim_func(ndimsA, fname)
func = eval(fexpr)
reducedim_cache[key] = func
else
func = reducedim_cache[key]
end

func(f, A, region, R, v0)

return R
end
end

## Filter ##

# given a function returning a boolean and an array, return matching elements
Expand Down
Loading

0 comments on commit 85cb30f

Please sign in to comment.