diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e446f7c9a..a4ba07f0c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,13 +48,13 @@ jobs: steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 - # `allow-failure` not available yet https://github.com/actions/toolkit/issues/399 + ## `allow-failure` not available yet https://github.com/actions/toolkit/issues/399 continue-on-error: ${{ matrix.julia-version == 'nightly' }} with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - uses: actions/cache@v1 - continue-on-error: ${{ matrix.julia-version == 'nightly' }} + # continue-on-error: ${{ matrix.julia-version == 'nightly' }} env: cache-name: cache-artifacts with: @@ -65,11 +65,11 @@ jobs: ${{ runner.os }}-test- ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - continue-on-error: ${{ matrix.julia-version == 'nightly' }} + # continue-on-error: ${{ matrix.julia-version == 'nightly' }} - uses: julia-actions/julia-runtest@v1 - continue-on-error: ${{ matrix.julia-version == 'nightly' }} + # continue-on-error: ${{ matrix.julia-version == 'nightly' }} - uses: codecov/codecov-action@v1 - continue-on-error: ${{ matrix.version == 'nightly' }} + # continue-on-error: ${{ matrix.version == 'nightly' }} with: file: lcov.info diff --git a/src/NNlib.jl b/src/NNlib.jl index 64b13d309..5f4851fd7 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -26,6 +26,7 @@ end include("activations.jl") include("softmax.jl") +include("misc.jl") include("batched/batchedmul.jl") include("gemm.jl") include("conv.jl") diff --git a/src/misc.jl b/src/misc.jl new file mode 100644 index 000000000..02c9fa53d --- /dev/null +++ b/src/misc.jl @@ -0,0 +1,25 @@ +export pixel_shuffle + +""" + pixel_shuffle(x, r) + +Pixel shuffling operation. `r` is the upscale factor for shuffling. +The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N] +Used extensively in super-resolution networks to upsample +towards high resolution features. + +Reference : https://arxiv.org/pdf/1609.05158.pdf +""" +function pixel_shuffle(x::AbstractArray, r::Integer) + @assert ndims(x) > 2 + d = ndims(x) - 2 + sizein = size(x)[1:d] + cin, n = size(x, d+1), size(x, d+2) + @assert cin % r^d == 0 + cout = cin ÷ r^d + # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866 + x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n) + perm = [d+1:2d 1:d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d] + x = permutedims(x, (perm..., 2d+1, 2d+2)) + return reshape(x, ((r .* sizein)..., cout, n)) +end diff --git a/test/misc.jl b/test/misc.jl new file mode 100644 index 000000000..2f2d724da --- /dev/null +++ b/test/misc.jl @@ -0,0 +1,63 @@ +@testset "pixel_shuffle" begin + x = reshape(1:16, (2, 2, 4, 1)) + # [:, :, 1, 1] = + # 1 3 + # 2 4 + # [:, :, 2, 1] = + # 5 7 + # 6 8 + # [:, :, 3, 1] = + # 9 11 + # 10 12 + # [:, :, 4, 1] = + # 13 15 + # 14 16 + + y_true = [1 9 3 11 + 5 13 7 15 + 2 10 4 12 + 6 14 8 16][:,:,:,:] + + y = pixel_shuffle(x, 2) + @test size(y) == size(y_true) + @test y_true == y + + x = reshape(1:32, (2, 2, 8, 1)) + y_true = zeros(Int, 4, 4, 2, 1) + y_true[:,:,1,1] .= [ 1 9 3 11 + 5 13 7 15 + 2 10 4 12 + 6 14 8 16 ] + + y_true[:,:,2,1] .= [ 17 25 19 27 + 21 29 23 31 + 18 26 20 28 + 22 30 24 32] + + y = pixel_shuffle(x, 2) + @test size(y) == size(y_true) + @test y_true == y + + x = reshape(1:4*3*27*2, (4,3,27,2)) + y = pixel_shuffle(x, 3) + @test size(y) == (12, 9, 3, 2) + # batch dimension is preserved + x1 = x[:,:,:,[1]] + x2 = x[:,:,:,[2]] + y1 = pixel_shuffle(x1, 3) + y2 = pixel_shuffle(x2, 3) + @test cat(y1, y2, dims=4) == y + + for d in [1, 2, 3] + r = rand(1:5) + n = rand(1:5) + c = rand(1:5) + insize = rand(1:5, d) + x = rand(insize..., r^d*c, n) + + y = pixel_shuffle(x, r) + @test size(y) == ((r .* insize)..., c, n) + + gradtest(x -> pixel_shuffle(x, r), x) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 480a88d96..0ef0676fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,3 +35,7 @@ end @testset "Softmax" begin include("softmax.jl") end + +@testset "Misc Stuff" begin + include("misc.jl") +end