Skip to content

Commit

Permalink
Merge pull request #260 from FluxML/cl/pixel
Browse files Browse the repository at this point in the history
add pixel_shuffle
  • Loading branch information
CarloLucibello authored Dec 31, 2020
2 parents 8f1116e + dcb015f commit c08258b
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 5 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ jobs:
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
# `allow-failure` not available yet https://github.com/actions/toolkit/issues/399
## `allow-failure` not available yet https://github.com/actions/toolkit/issues/399
continue-on-error: ${{ matrix.julia-version == 'nightly' }}
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
continue-on-error: ${{ matrix.julia-version == 'nightly' }}
# continue-on-error: ${{ matrix.julia-version == 'nightly' }}
env:
cache-name: cache-artifacts
with:
Expand All @@ -65,11 +65,11 @@ jobs:
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
continue-on-error: ${{ matrix.julia-version == 'nightly' }}
# continue-on-error: ${{ matrix.julia-version == 'nightly' }}
- uses: julia-actions/julia-runtest@v1
continue-on-error: ${{ matrix.julia-version == 'nightly' }}
# continue-on-error: ${{ matrix.julia-version == 'nightly' }}
- uses: codecov/codecov-action@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
# continue-on-error: ${{ matrix.version == 'nightly' }}
with:
file: lcov.info

Expand Down
1 change: 1 addition & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ end
include("activations.jl")

include("softmax.jl")
include("misc.jl")
include("batched/batchedmul.jl")
include("gemm.jl")
include("conv.jl")
Expand Down
25 changes: 25 additions & 0 deletions src/misc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
export pixel_shuffle

"""
pixel_shuffle(x, r)
Pixel shuffling operation. `r` is the upscale factor for shuffling.
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
Used extensively in super-resolution networks to upsample
towards high resolution features.
Reference : https://arxiv.org/pdf/1609.05158.pdf
"""
function pixel_shuffle(x::AbstractArray, r::Integer)
@assert ndims(x) > 2
d = ndims(x) - 2
sizein = size(x)[1:d]
cin, n = size(x, d+1), size(x, d+2)
@assert cin % r^d == 0
cout = cin ÷ r^d
# x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n)
perm = [d+1:2d 1:d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d]
x = permutedims(x, (perm..., 2d+1, 2d+2))
return reshape(x, ((r .* sizein)..., cout, n))
end
63 changes: 63 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
@testset "pixel_shuffle" begin
x = reshape(1:16, (2, 2, 4, 1))
# [:, :, 1, 1] =
# 1 3
# 2 4
# [:, :, 2, 1] =
# 5 7
# 6 8
# [:, :, 3, 1] =
# 9 11
# 10 12
# [:, :, 4, 1] =
# 13 15
# 14 16

y_true = [1 9 3 11
5 13 7 15
2 10 4 12
6 14 8 16][:,:,:,:]

y = pixel_shuffle(x, 2)
@test size(y) == size(y_true)
@test y_true == y

x = reshape(1:32, (2, 2, 8, 1))
y_true = zeros(Int, 4, 4, 2, 1)
y_true[:,:,1,1] .= [ 1 9 3 11
5 13 7 15
2 10 4 12
6 14 8 16 ]

y_true[:,:,2,1] .= [ 17 25 19 27
21 29 23 31
18 26 20 28
22 30 24 32]

y = pixel_shuffle(x, 2)
@test size(y) == size(y_true)
@test y_true == y

x = reshape(1:4*3*27*2, (4,3,27,2))
y = pixel_shuffle(x, 3)
@test size(y) == (12, 9, 3, 2)
# batch dimension is preserved
x1 = x[:,:,:,[1]]
x2 = x[:,:,:,[2]]
y1 = pixel_shuffle(x1, 3)
y2 = pixel_shuffle(x2, 3)
@test cat(y1, y2, dims=4) == y

for d in [1, 2, 3]
r = rand(1:5)
n = rand(1:5)
c = rand(1:5)
insize = rand(1:5, d)
x = rand(insize..., r^d*c, n)

y = pixel_shuffle(x, r)
@test size(y) == ((r .* insize)..., c, n)

gradtest(x -> pixel_shuffle(x, r), x)
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ end
@testset "Softmax" begin
include("softmax.jl")
end

@testset "Misc Stuff" begin
include("misc.jl")
end

0 comments on commit c08258b

Please sign in to comment.