diff --git a/base/cartesian.jl b/base/cartesian.jl
index 7bbf6e52ccfca..a892241f34e48 100644
--- a/base/cartesian.jl
+++ b/base/cartesian.jl
@@ -1,6 +1,6 @@
 module Cartesian
 
-export @ngenerate, @nsplat, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif, ngenerate
+export @ngenerate, @nsplat, @nloops, @nfunction, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif, ngenerate
 
 const CARTESIAN_DIMS = 4
 
@@ -299,6 +299,49 @@ function _nloops(N::Int, itersym::Symbol, rangeexpr::Expr, args::Expr...)
     ex
 end
 
+# Generate function f(pre,i_1::T,i_2::T,..) from @nfunction N f pre i::T body
+#
+# The generated definition takes the `pre` arguments followed by N copies of the
+# iterated argument (i_1, ..., i_N, each carrying the `::T` annotation when one is
+# given) and uses the last macro argument as its body.  The whole definition is
+# escaped, so it is evaluated in the scope of the macro caller.
+macro nfunction(N, fname, args...)
+    _nfunction(N, fname, args...)
+end
+
+# Accept either a bare symbol `x` or an annotated symbol `x::T` where T is a symbol.
+function _isnfunctionarg(arg)
+    isa(arg, Symbol) && return true
+    isa(arg, Expr) || return false
+    return arg.head == :(::) && isa(arg.args[1], Symbol) && isa(arg.args[2], Symbol)
+end
+
+# Build the escaped function definition for @nfunction.
+# `args` must hold (pre arguments..., iterated argument, body).
+function _nfunction(N::Int, fname::Symbol, args...)
+    nargs = length(args)
+    nargs >= 2 || error("argument missing")
+
+    leading = args[1:nargs-2]
+    itervar = args[nargs-1]
+    fbody = args[nargs]
+
+    for a in leading
+        _isnfunctionarg(a) || error("invalid argument type for pre arguments")
+    end
+    _isnfunctionarg(itervar) || error("invalid argument type for argument that will be iterated ")
+
+    # Expand i (or i::T) into i_1, ..., i_N (or i_1::T, ..., i_N::T)
+    if isa(itervar, Symbol)
+        expanded = [inlineanonymous(itervar, k) for k = 1:N]
+    else
+        expanded = [Expr(:(::), inlineanonymous(itervar.args[1], k), itervar.args[2]) for k = 1:N]
+    end
+    signature = Expr(:call, fname, leading..., expanded...)
+
+    return Expr(:escape, Expr(:function, signature, fbody))
+end
+
 # Generate expression A[i1, i2, ...]
 macro nref(N, A, sym)
     _nref(N, A, sym)
diff --git a/base/exports.jl b/base/exports.jl
index a0c47747b95cb..5899d784e4e40 100644
--- a/base/exports.jl
+++ b/base/exports.jl
@@ -548,7 +548,6 @@ export
     permutations,
     permute!,
     permutedims,
-    permutedims!,
     prod!,
     prod,
     promote_shape,
diff --git a/base/multidimensional.jl b/base/multidimensional.jl
index cde16603f0e9b..6bad90ccb61e4 100644
--- a/base/multidimensional.jl
+++ b/base/multidimensional.jl
@@ -478,6 +478,196 @@ for (V, PT, BT) in {((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)}
     end
 end
 
+# Copy B into P with its dimensions permuted by perm, so that size(P,d) == size(B,perm[d]).
+# Small problems (prod(size(P)) <= 4*basesize) are copied with a single loop nest; larger
+# ones are split repeatedly in two halves along the dimension d that maximizes
+# min(stride(B,perm[d]),stride(P,d))*blocksize_d, and the resulting blocks are copied one
+# at a time.  The splitting is iterative: the vec* arrays hold the state of every level.
+@ngenerate N typeof(P) function permutedimsnew!{T1,T2,N}(P::StridedArray{T1,N},B::StridedArray{T2,N},perm,basesize::Int=1024)
+    length(perm) == N || error("expected permutation of size $N, but length(perm)=$(length(perm))")
+    isperm(perm) || error("input is not a permutation")
+    dims = size(P)
+    for i = 1:N
+        dims[i] == size(B,perm[i]) || throw(DimensionMismatch("destination tensor of incorrect size"))
+    end
+    @nexprs N d->(stridesB_{d} = stride(B,perm[d]))
+    @nexprs N d->(stridesP_{d} = stride(P,d))
+    @nexprs N d->(dims_{d} = dims[d])
+
+    # For a SubArray, work on the parent array starting from the view's first_index
+    if isa(B, SubArray)
+        startB = B.first_index
+        B = B.parent
+    else
+        startB = 1
+    end
+    if isa(P, SubArray)
+        startP = P.first_index
+        P = P.parent
+    else
+        startP = 1
+    end
+
+    if prod(dims)<=4*basesize
+        # copy data
+        @nexprs 1 d->(indB_{N} = startB)
+        @nexprs 1 d->(indP_{N} = startP)
+        @nloops(N, i, d->1:dims_{d},
+            d->(indB_{d-1} = indB_{d};indP_{d-1}=indP_{d}), # PRE
+            d->(indB_{d} += stridesB_{d};indP_{d} += stridesP_{d}), # POST
+            @inbounds P[indP_0]=B[indB_0])
+    else
+        @nexprs N d->(minstrides_{d} = min(stridesB_{d},stridesP_{d}))
+
+        # levels 1..M-1 each split the current block in two, level M copies it
+        M=iceil(log2(prod(dims)/basesize))
+        step=zeros(Int,M)
+        level=1
+        @nexprs N d->(vecbdims_{d} = zeros(Int,M))
+        @nexprs N d->(vecbdims_{d}[level] = dims_{d})
+        vecoffsetB=zeros(Int,M)
+        vecoffsetP=zeros(Int,M)
+        vecdP=zeros(Int,M)
+        vecdB=zeros(Int,M)
+        vecdmax=zeros(Int,M)
+        vecnewdim=zeros(Int,M)
+        while level>0
+            if level==M
+                # deepest level: copy the current block and go back up
+                @nexprs N d->(bdims_{d} = vecbdims_{d}[M])
+                @nexprs 1 d->(indP_{N} = startP+vecoffsetP[M])
+                @nexprs 1 d->(indB_{N} = startB+vecoffsetB[M])
+                @nloops(N, i, d->1:bdims_{d},
+                    d->(indB_{d-1} = indB_{d};indP_{d-1}=indP_{d}), # PRE
+                    d->(indB_{d} += stridesB_{d};indP_{d} += stridesP_{d}), # POST
+                    @inbounds P[indP_0]=B[indB_0])
+                level-=1
+            elseif step[level]==0
+                # first visit: choose the dimension to halve, then descend into the first half
+                @nexprs N d->(bdims_{d} = vecbdims_{d}[level])
+                dmax=1
+                maxval=minstrides_1*bdims_1
+                newdim=bdims_1>>1
+                dP=stridesP_1
+                dB=stridesB_1
+                @nexprs N d->(newmax=minstrides_{d}*bdims_{d};if bdims_{d}>1 && newmax>maxval;dmax=d;newdim=bdims_{d}>>1;dP=stridesP_{d};dB=stridesB_{d};maxval=newmax;end)
+                vecnewdim[level]=newdim
+                vecdmax[level]=dmax
+                vecdP[level]=dP
+                vecdB[level]=dB
+
+                @nexprs N d->(vecbdims_{d}[level+1] = (d==dmax ? newdim : bdims_{d}))
+                vecoffsetP[level+1]=vecoffsetP[level]
+                vecoffsetB[level+1]=vecoffsetB[level]
+                step[level+1]=0
+
+                step[level]+=1
+                level+=1
+            elseif step[level]==1
+                # second visit: descend into the remaining half, shifted by the first half's extent
+                @nexprs N d->(bdims_{d} = vecbdims_{d}[level])
+
+                @nexprs N d->(vecbdims_{d}[level+1] = (d==vecdmax[level] ? bdims_{d}-vecnewdim[level] : bdims_{d}))
+                vecoffsetP[level+1]=vecoffsetP[level]+vecdP[level]*vecnewdim[level]
+                vecoffsetB[level+1]=vecoffsetB[level]+vecdB[level]*vecnewdim[level]
+                step[level+1]=0
+
+                step[level]+=1
+                level+=1
+            else
+                # both halves of this level are done
+                level-=1
+            end
+        end
+    end
+    # NOTE(review): P was rebound to its parent above, so a SubArray destination returns the
+    # parent array rather than the view -- confirm this is intended
+    return P
+end
+
+# Recursive variant of permutedimsnew! with the same arguments and checks.  Two local
+# functions are generated with @nfunction: innerbase copies one block and innerrec halves a
+# block along one dimension until it holds at most basesize elements.
+@ngenerate N typeof(P) function permutedimsnew2!{T1,T2,N}(P::StridedArray{T1,N},B::StridedArray{T2,N},perm,basesize::Int=1024)
+    length(perm) == N || error("expected permutation of size $N, but length(perm)=$(length(perm))")
+    isperm(perm) || error("input is not a permutation")
+    dims = size(P)
+    for i = 1:N
+        dims[i] == size(B,perm[i]) || throw(DimensionMismatch("destination tensor of incorrect size"))
+    end
+    @nexprs N d->(stridesB_{d} = stride(B,perm[d]))
+    @nexprs N d->(stridesP_{d} = stride(P,d))
+    @nexprs N d->(dims_{d} = dims[d])
+
+    # For a SubArray, work on the parent array starting from the view's first_index
+    if isa(B, SubArray)
+        startB = B.first_index
+        B = B.parent
+    else
+        startB = 1
+    end
+    if isa(P, SubArray)
+        startP = P.first_index
+        P = P.parent
+    else
+        startP = 1
+    end
+
+    # innerbase(offsetP, offsetB, bdims_1, ..., bdims_N) copies the block of size bdims whose
+    # first element lies offsetP positions past startP in P and offsetB past startB in B
+    @nfunction(N,innerbase,offsetP::Int,offsetB::Int,bdims::Int,begin
+        @nexprs 1 d->(indB_{N} = startB+offsetB)
+        @nexprs 1 d->(indP_{N} = startP+offsetP)
+        @nloops(N, i, d->1:bdims_{d},
+            d->(indB_{d-1} = indB_{d};indP_{d-1}=indP_{d}), # PRE
+            d->(indB_{d} += stridesB_{d};indP_{d} += stridesP_{d}), # POST
+            @inbounds P[indP_0]=B[indB_0])
+    end)
+
+    if prod(dims)<=4*basesize
+        @ncall N innerbase 0 0 dims
+    else
+        @nexprs N d->(minstrides_{d} = min(stridesB_{d},stridesP_{d}))
+
+        # innerrec takes the same arguments as innerbase; it copies small blocks directly and splits larger ones
+        @nfunction(N,innerrec,offsetP::Int,offsetB::Int,bdims::Int,begin
+            currentsize=1
+            @nexprs N d->(currentsize *=bdims_{d})
+            if currentsize<=basesize
+                @ncall N innerbase offsetP offsetB bdims
+            else
+                # pick the dimension with more than one element that maximizes minstrides*bdims
+                dmax=1
+                maxval=minstrides_1*bdims_1
+                @nexprs N d->(begin
+                    newmax=minstrides_{d}*bdims_{d}
+                    if bdims_{d}>1 && newmax>maxval
+                        dmax=d
+                        maxval=newmax
+                    end
+                end)
+                # @nexprs unrolls this for d = 1..N; only the copy whose d equals dmax executes
+                @nexprs N d->(begin
+                    if d==dmax
+                        olddim=bdims_{d}
+                        newdim=olddim>>1
+                        bdims_{d}=newdim
+                        @ncall N innerrec offsetP offsetB bdims
+                        bdims_{d}=olddim-newdim
+                        offsetP+=stridesP_{d}*newdim
+                        offsetB+=stridesB_{d}*newdim
+                        @ncall N innerrec offsetP offsetB bdims
+                    end
+                end)
+            end
+        end)
+        @ncall N innerrec 0 0 dims
+    end
+    # NOTE(review): P was rebound to its parent above, so a SubArray destination returns the
+    # parent array rather than the view -- confirm this is intended
+    return P
+end
+
 ## unique across dim
 
 # TODO: this doesn't fit into the new hashing scheme in any obvious way