diff --git a/src/NNlib.jl b/src/NNlib.jl index db2bce887..0f8641309 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -10,6 +10,7 @@ include("softmax.jl") include("gemm.jl") include("conv.jl") include("pooling.jl") +include("conv_op.jl") ## Include implementations include("impl/padding_edges.jl") diff --git a/src/conv_op.jl b/src/conv_op.jl new file mode 100644 index 000000000..4ca06b006 --- /dev/null +++ b/src/conv_op.jl @@ -0,0 +1,49 @@ +# Defines operations on convolutions +export pixel_shuffle + +using Base.Iterators:partition + +""" + pixel_shuffle(x,r) + +Pixel shuffling operation. `r` is the scale factor for shuffling. + +The operation converts an input of size [W,H,r²C,N] to [rW,rH,C,N] +Used extensively in super-resolution networks to upsample towrads high reolution feature ma + +Reference : https://arxiv.org/pdf/1609.05158.pdf +""" +function split_channels(x::AbstractArray,val::Int) # Split chaannels into `val` partitions + indices = collect(1:size(x)[end-1]) + channels_par = partition(indices,div(size(x)[end-1],val)) + + out = [] + for c in channels_par + c = [c_ for c_ in c] + push!(out,x[:,:,c,:]) + end + return out +end + +""" +phaseShift cyclically reshapes and permutes the channels +""" +function phase_shift(x,r) + W,H,C,N = size(x) + x = reshape(x,W,H,r,r,N) + x = [x[i,:,:,:,:] for i in 1:W] + x = cat([t for t in x]...,dims=2) + x = [x[i,:,:,:] for i in 1:size(x)[1]] + x = cat([t for t in x]...,dims=2) + x +end + +function pixel_shuffle(x,r=3) + ndims(x) == 4 || error("PixelShuffle defined only for arrays of dimension 4") + (size(x)[end-1])%(r*r) == 0 || error("Number of channels($(size(x)[end-1])) must be divisible by $(r*r)") + + C_out = div(size(x)[end-1],r*r) + sch = split_channels(x,C_out) + out = cat([phase_shift(c,r) for c in split_channels(x,C_out)]...,dims=3) + reshape(out,size(out)[1],size(out)[2],C_out,div(size(out)[end],C_out)) +end diff --git a/test/conv.jl b/test/conv.jl index 6a3c593b3..989df7ee6 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -638,4 +638,10 @@ end end println() end + + @testset "pixel_shuffle" begin + x = ones(2,2,18,5) + + @test size(pixel_shuffle(x,3)) == (6,6,2,5) + end end \ No newline at end of file