From da2d02c0da44a06f958bc2f73e912ea59175eb71 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Sat, 4 Jan 2014 15:32:58 -0600 Subject: [PATCH] more flexible code generation functions Now the code-gen allows generating reduction functions that may contain more than one input arrays. --- base/reducedim.jl | 79 ++++++++++++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/base/reducedim.jl b/base/reducedim.jl index 879c45ecc7863..20f424146d56c 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -163,10 +163,13 @@ function rcompress_dims{N}(siz::NTuple{N,Int}, region) return (isrd[1], sdims) end -function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!) +function generate_reducedim_funcs(fname, params, args, sizexpr, comb, sker, ker0!, ker1!) # Parameters: # - # - fname: the interface function name (e.g. sum, maximum) + # - fname: the interface function name (e.g. sum, maximum) + # - params: a list of input parameters in function signatures + # - args: a list of argument symbols + # - sizexpr: the expression that calculates the input size # - comb: the combination operation (e.g. +) # - sker: a kernel function that reduces a vector (or a range of it) to a scalar # - ker0!: a kernel that initializes an accumulator array using the first column of terms @@ -179,32 +182,32 @@ function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!) quote global $(fname!) - function $(fname!)(dst::Array, a::Array, dim::Integer) - nd = ndims(a) - siz = size(a) + function $(fname!)(dst::Array, $(params...), dim::Integer) + siz = $(sizexpr) + nd = length(siz) if 1 <= dim <= nd if dim == 1 - $(fa!)(true, dst, 0, a, 0, prod(siz[2:nd]), siz[1]) + $(fa!)(true, dst, 0, $(args...), 0, prod(siz[2:nd]), siz[1]) elseif dim == nd - $(fb!)(true, dst, 0, a, 0, siz[nd], prod(siz[1:nd-1])) + $(fb!)(true, dst, 0, $(args...), 0, siz[nd], prod(siz[1:nd-1])) else - $(fb!)(true, dst, 0, a, 0, prod(siz[dim+1:nd]), siz[dim], prod(siz[1:dim-1])) + $(fb!)(true, dst, 0, $(args...), 0, prod(siz[dim+1:nd]), siz[dim], prod(siz[1:dim-1])) end else - $(ker0!)(dst, 1, a, 1, length(a)) + $(ker0!)(dst, 1, $(args...), 1, prod(siz)) end dst end - function $(fname!)(dst::Array, a::Array, region) + function $(fname!)(dst::Array, $(params...), region) if length(region) == 1 - $(fname!)(dst, a, region[1]) + $(fname!)(dst, $(args...), region[1]) else - isrd1, secs = rcompress_dims(size(a), region) + isrd1, secs = rcompress_dims($(sizexpr), region) if isrd1 - $(fa!)(true, dst, 0, a, 0, secs[end:-1:1]...) + $(fa!)(true, dst, 0, $(args...), 0, secs[end:-1:1]...) else - $(fb!)(true, dst, 0, a, 0, secs[end:-1:1]...) + $(fb!)(true, dst, 0, $(args...), 0, secs[end:-1:1]...) end end dst @@ -212,44 +215,44 @@ function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!) # $(fa!) global $(fa!) - function $(fa!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int) + function $(fa!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int) if isinit - dst[od+1] = $(sker)(a, oa+1, oa+n1) + dst[od+1] = $(sker)($(args...), oa+1, oa+n1) else - dst[od+1] = $(comb)(dst[od+1], $(sker)(a, oa+1, oa+n1)) + dst[od+1] = $(comb)(dst[od+1], $(sker)($(args...), oa+1, oa+n1)) end end - function $(fa!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int) + function $(fa!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int) if isinit for j = 1:n1 alast = oa + n2 - dst[od+j] = $(sker)(a, oa+1, alast) + dst[od+j] = $(sker)($(args...), oa+1, alast) oa = alast end else for j = 1:n1 alast = oa + n2 - dst[od+j] = $(comb)(dst[od+j], $(sker)(a, oa+1, alast)) + dst[od+j] = $(comb)(dst[od+j], $(sker)($(args...), oa+1, alast)) oa = alast end end end - function $(fa!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...) + function $(fa!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...) as::Int = *(n2, n3, ns...) if length(ns) & 1 == 0 - $(fa!)(isinit, dst, od, a, oa, n2, n3, ns...) + $(fa!)(isinit, dst, od, $(args...), oa, n2, n3, ns...) oa += as for j = 2:n1 - $(fa!)(false, dst, od, a, oa, n2, n3, ns...) + $(fa!)(false, dst, od, $(args...), oa, n2, n3, ns...) oa += as end else ds::Int = *(n3, ns[2:2:end]...) for j = 1:n1 - $(fa!)(isinit, dst, od, a, oa, n2, n3, ns...) + $(fa!)(isinit, dst, od, $(args...), oa, n2, n3, ns...) od += ds oa += as end @@ -258,43 +261,42 @@ function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!) # $(fb!) global $(fb!) - function $(fb!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int) + function $(fb!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int) if isinit - $(ker0!)(dst, od+1, a, oa+1, n1) + $(ker0!)(dst, od+1, $(args...), oa+1, n1) else - $(ker1!)(dst, od+1, a, oa+1, n1) + $(ker1!)(dst, od+1, $(args...), oa+1, n1) end end - function $(fb!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int) + function $(fb!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int) if isinit - $(ker0!)(dst, od+1, a, oa+1, n2) + $(ker0!)(dst, od+1, $(args...), oa+1, n2) else - $(ker1!)(dst, od+1, a, oa+1, n2) + $(ker1!)(dst, od+1, $(args...), oa+1, n2) end oa += n2 - for j = 2:n1 - $(ker1!)(dst, od+1, a, oa+1, n2) + $(ker1!)(dst, od+1, $(args...), oa+1, n2) oa += n2 end end - function $(fb!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...) + function $(fb!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...) as = *(n2, n3, ns...) if length(ns) & 1 == 0 ds::Int = *(n3, ns[2:2:end]...) for j = 1:n1 - $(fb!)(isinit, dst, od, a, oa, n2, n3, ns...) + $(fb!)(isinit, dst, od, $(args...), oa, n2, n3, ns...) od += ds oa += as end else - $(fb!)(isinit, dst, od, a, oa, n2, n3, ns...) + $(fb!)(isinit, dst, od, $(args...), oa, n2, n3, ns...) oa += as for j = 2:n1 - $(fb!)(false, dst, od, a, oa, n2, n3, ns...) + $(fb!)(false, dst, od, $(args...), oa, n2, n3, ns...) oa += as end end @@ -302,6 +304,11 @@ function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!) end end +function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!) + # specialized method to generate functions with single input arguments + generate_reducedim_funcs(fname, [:(a::Array)], [:a], :(size(a)), comb, sker, ker0!, ker1!) +end + macro code_reducedim(fname, comb, sker, ker0, ker1) esc(generate_reducedim_funcs(fname, comb, sker, ker0, ker1)) end