Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More efficient permutedims #6517

Closed
wants to merge 14 commits into from
30 changes: 29 additions & 1 deletion base/cartesian.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Cartesian

export @ngenerate, @nsplat, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif, ngenerate
export @ngenerate, @nsplat, @nloops, @nfunction, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif, ngenerate

const CARTESIAN_DIMS = 4

Expand Down Expand Up @@ -299,6 +299,34 @@ function _nloops(N::Int, itersym::Symbol, rangeexpr::Expr, args::Expr...)
ex
end

# Generate function f(pre,i_1::T,i_2::T,..) from @nfunction N f pre i::T body
macro nfunction(N, fname, args...)
_nfunction(N, fname, args...)
end

function _nfunction(N::Int, fname::Symbol, args...)
if length(args) < 2
error("argument missing")
end

prearg = args[1:end-2]
for k=1:length(prearg)
if !(isa(prearg[k],Symbol) || (isa(prearg[k],Expr) && prearg[k].head==:(::) && isa(prearg[k].args[1],Symbol) && isa(prearg[k].args[2],Symbol)))
error("invalid argument type for pre arguments")
end
end
iterarg = args[end-1]
if !(isa(iterarg,Symbol) || (isa(iterarg,Expr) && iterarg.head==:(::) && isa(iterarg.args[1],Symbol) && isa(iterarg.args[2],Symbol)))
error("invalid argument type for argument that will be iterated ")
end
iterarglist=(isa(iterarg,Symbol) ? [inlineanonymous(iterarg,i) for i=1:N] : [Expr(:(::),inlineanonymous(iterarg.args[1],i),iterarg.args[2]) for i=1:N])
fcall=Expr(:call,fname,prearg...,iterarglist...)

body = args[end]

ex=Expr(:escape,Expr(:function,fcall,body))
end

# Generate expression A[i1, i2, ...]
macro nref(N, A, sym)
_nref(N, A, sym)
Expand Down
1 change: 0 additions & 1 deletion base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,6 @@ export
permutations,
permute!,
permutedims,
permutedims!,
prod!,
prod,
promote_shape,
Expand Down
166 changes: 166 additions & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,172 @@ for (V, PT, BT) in {((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)}
end
end

@ngenerate N typeof(P) function permutedimsnew!{T1,T2,N}(P::StridedArray{T1,N},B::StridedArray{T2,N},perm,basesize::Int=1024)
length(perm) == N || error("expected permutation of size $N, but length(perm)=$(length(perm))")
isperm(perm) || error("input is not a permutation")
dims = size(P)
for i = 1:N
dims[i] == size(B,perm[i]) || throw(DimensionMismatch("destination tensor of incorrect size"))
end
@nexprs N d->(stridesB_{d} = stride(B,perm[d]))
@nexprs N d->(stridesP_{d} = stride(P,d))
@nexprs N d->(dims_{d} = dims[d])

if isa(B, SubArray)
startB = B.first_index
B = B.parent
else
startB = 1
end
if isa(P, SubArray)
startP = P.first_index
P = P.parent
else
startP = 1
end

if prod(dims)<=4*basesize
# copy data
@nexprs 1 d->(indB_{N} = startB)
@nexprs 1 d->(indP_{N} = startP)
@nloops(N, i, d->1:dims_{d},
d->(indB_{d-1} = indB_{d};indP_{d-1}=indP_{d}), # PRE
d->(indB_{d} += stridesB_{d};indP_{d} += stridesP_{d}), # POST
@inbounds P[indP_0]=B[indB_0])
else
@nexprs N d->(minstrides_{d} = min(stridesB_{d},stridesP_{d}))

M=iceil(log2(prod(dims)/basesize))
step=zeros(Int,M)
level=1
@nexprs N d->(vecbdims_{d} = zeros(Int,M))
@nexprs N d->(vecbdims_{d}[level] = dims_{d})
vecoffsetB=zeros(Int,M)
vecoffsetP=zeros(Int,M)
vecdP=zeros(Int,M)
vecdB=zeros(Int,M)
vecdmax=zeros(Int,M)
vecnewdim=zeros(Int,M)
while level>0
if level==M
@nexprs N d->(bdims_{d} = vecbdims_{d}[M])
@nexprs 1 d->(indP_{N} = startP+vecoffsetP[M])
@nexprs 1 d->(indB_{N} = startB+vecoffsetB[M])
@nloops(N, i, d->1:bdims_{d},
d->(indB_{d-1} = indB_{d};indP_{d-1}=indP_{d}), # PRE
d->(indB_{d} += stridesB_{d};indP_{d} += stridesP_{d}), # POST
@inbounds P[indP_0]=B[indB_0])
level-=1
elseif step[level]==0
@nexprs N d->(bdims_{d} = vecbdims_{d}[level])
dmax=1
maxval=minstrides_1*bdims_1
newdim=bdims_1>>1
dP=stridesP_1
dB=stridesB_1
@nexprs N d->(newmax=minstrides_{d}*bdims_{d};if bdims_{d}>1 && newmax>maxval;dmax=d;newdim=bdims_{d}>>1;dP=stridesP_{d};dB=stridesB_{d};maxval=newmax;end)
vecnewdim[level]=newdim
vecdmax[level]=dmax
vecdP[level]=dP
vecdB[level]=dB

@nexprs N d->(vecbdims_{d}[level+1] = (d==dmax ? newdim : bdims_{d}))
vecoffsetP[level+1]=vecoffsetP[level]
vecoffsetB[level+1]=vecoffsetB[level]
step[level+1]=0

step[level]+=1
level+=1
elseif step[level]==1
@nexprs N d->(bdims_{d} = vecbdims_{d}[level])

@nexprs N d->(vecbdims_{d}[level+1] = (d==vecdmax[level] ? bdims_{d}-vecnewdim[level] : bdims_{d}))
vecoffsetP[level+1]=vecoffsetP[level]+vecdP[level]*vecnewdim[level]
vecoffsetB[level+1]=vecoffsetB[level]+vecdB[level]*vecnewdim[level]
step[level+1]=0

step[level]+=1
level+=1
else
level-=1
end
end
end
return P
end

@ngenerate N typeof(P) function permutedimsnew2!{T1,T2,N}(P::StridedArray{T1,N},B::StridedArray{T2,N},perm,basesize::Int=1024)
length(perm) == N || error("expected permutation of size $N, but length(perm)=$(length(perm))")
isperm(perm) || error("input is not a permutation")
dims = size(P)
for i = 1:N
dims[i] == size(B,perm[i]) || throw(DimensionMismatch("destination tensor of incorrect size"))
end
@nexprs N d->(stridesB_{d} = stride(B,perm[d]))
@nexprs N d->(stridesP_{d} = stride(P,d))
@nexprs N d->(dims_{d} = dims[d])

if isa(B, SubArray)
startB = B.first_index
B = B.parent
else
startB = 1
end
if isa(P, SubArray)
startP = P.first_index
P = P.parent
else
startP = 1
end

@nfunction(N,innerbase,offsetP::Int,offsetB::Int,bdims::Int,begin
@nexprs 1 d->(indB_{N} = startB+offsetP)
@nexprs 1 d->(indP_{N} = startP+offsetB)
@nloops(N, i, d->1:bdims_{d},
d->(indB_{d-1} = indB_{d};indP_{d-1}=indP_{d}), # PRE
d->(indB_{d} += stridesB_{d};indP_{d} += stridesP_{d}), # POST
@inbounds P[indP_0]=B[indB_0])
end)

if prod(dims)<=4*basesize
@ncall N innerbase 0 0 dims
else
@nexprs N d->(minstrides_{d} = min(stridesB_{d},stridesP_{d}))

@nfunction(N,innerrec,offsetP::Int,offsetB::Int,bdims::Int,begin
currentsize=1
@nexprs N d->(currentsize *=bdims_{d})
if currentsize<=basesize
@ncall N innerbase offsetP offsetB bdims
else
dmax=1
maxval=minstrides_1*bdims_1
@nexprs N d->(begin
newmax=minstrides_{d}*bdims_{d}
if bdims_{d}>1 && newmax>maxval
dmax=d
maxval=newmax
end
end)
@nexprs N d->(begin
if d==dmax
olddim=bdims_{d}
newdim=olddim>>1
bdims_{d}=newdim
@ncall N innerrec offsetP offsetB bdims
bdims_{d}=olddim-newdim
offsetP+=stridesP_{d}*newdim
offsetB+=stridesB_{d}*newdim
@ncall N innerrec offsetP offsetB bdims
end
end)
end
end)
@ncall N innerrec 0 0 dims
end
return P
end

## unique across dim

# TODO: this doesn't fit into the new hashing scheme in any obvious way
Expand Down