Merge #1468
1468: add Upsample and PixelShuffle layers r=DhairyaLGandhi a=CarloLucibello


### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
4 people authored Feb 11, 2021
2 parents cdb445c + 9acaae9 commit ddb5e9c
Showing 12 changed files with 184 additions and 12 deletions.
8 changes: 4 additions & 4 deletions Manifest.toml
@@ -93,9 +93,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.4+0"

[[DataAPI]]
git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f"
git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.5.0"
version = "1.5.1"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -236,9 +236,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "573cc0d31f9697b9d2b060130a7a3c05a4f36b78"
git-tree-sha1 = "df42d0816edfc24f5b82a728f46381613c4dff79"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.12"
version = "0.7.14"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
1 change: 1 addition & 0 deletions NEWS.md
@@ -13,6 +13,7 @@
* Moved GPU CI to use buildkite instead of GitLab
* New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks.
* Feature additions and bug fixes for BatchNorm, LayerNorm, InstanceNorm, and GroupNorm [normalization layers](https://github.com/FluxML/Flux.jl/pull/1397)
* Added [Upsample and PixelShuffle layers](https://github.com/FluxML/Flux.jl/pull/1468)

## v0.11.2

2 changes: 1 addition & 1 deletion Project.toml
@@ -34,7 +34,7 @@ Colors = "0.12"
Functors = "0.1, 0.2"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.10"
NNlib = "0.7.14"
Reexport = "0.2, 1.0"
StatsBase = "0.33"
ZipFile = "0.9"
7 changes: 7 additions & 0 deletions docs/src/models/layers.md
@@ -29,6 +29,13 @@ Flux.convfilter
Flux.depthwiseconvfilter
```

## Upsampling Layers

```@docs
Upsample
PixelShuffle
```

## Recurrent Layers

Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
8 changes: 8 additions & 0 deletions docs/src/models/nnlib.md
@@ -51,6 +51,14 @@ NNlib.conv
NNlib.depthwiseconv
```

## Upsampling

```@docs
NNlib.upsample_nearest
NNlib.upsample_bilinear
NNlib.pixel_shuffle
```
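
As a rough sketch of how these NNlib functions are called directly (shapes follow the layer methods and tests added in this PR, assuming the NNlib 0.7.14 release pinned in Project.toml):

```julia
using NNlib

x = rand(Float32, 3, 4, 2, 1)                   # WHCN: width, height, channels, batch

size(NNlib.upsample_nearest(x, (2, 3)))         # (6, 12, 2, 1): repeat along the first two dims
size(NNlib.upsample_bilinear(x, (2, 3)))        # (6, 12, 2, 1): bilinear interpolation
size(NNlib.upsample_bilinear(x; size = (4, 6))) # (4, 6, 2, 1): target spatial size instead of a scale

y = rand(Float32, 3, 4, 18, 1)
size(NNlib.pixel_shuffle(y, 3))                 # (9, 12, 2, 1): 3×3 channel blocks moved into space
```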

## Batched Operations

```@docs
2 changes: 2 additions & 0 deletions src/Flux.jl
@@ -16,6 +16,7 @@ export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
Upsample, PixelShuffle,
params, fmap, cpu, gpu, f32, f64,
testmode!, trainmode!

@@ -42,6 +43,7 @@ include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")

include("outputsize.jl")

79 changes: 79 additions & 0 deletions src/layers/upsample.jl
@@ -0,0 +1,79 @@
"""
Upsample(mode = :nearest; [scale, size])
Upsample(scale, mode = :nearest)
An upsampling layer. One of two keywords must be given:
If `scale` is a number, this applies to all but the last two dimensions (channel and batch) of the input.
It may also be a tuple, to control dimensions individually. Alternatively, keyword
`size` accepts a tuple, to directly specify the leading dimensions of the output.
Currently supported upsampling `mode`s
and corresponding NNlib's methods are:
- `:nearest` -> [`NNlib.upsample_nearest`](@ref)
- `:bilinear` -> [`NNlib.upsample_bilinear`](@ref)
# Examples
```juliarepl
julia> m = Upsample(scale = (2, 3))
Upsample(:nearest, scale = (2, 3))
julia> m(ones(2, 2, 1, 1)) |> size
(4, 6, 1, 1)
julia> m = Upsample(:bilinear, size = (4, 5))
Upsample(:bilinear, size = (4, 5))
julia> m(ones(2, 2, 1, 1)) |> size
(4, 5, 1, 1)
"""
struct Upsample{mode, S, T}
scale::S
size::T
end

function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing)
mode in [:nearest, :bilinear] ||
throw(ArgumentError("mode=:$mode is not supported."))
  if !(isnothing(scale) ⊻ isnothing(size))
throw(ArgumentError("Either scale or size should be specified (but not both)."))
end
return Upsample{mode,typeof(scale),typeof(size)}(scale, size)
end

Upsample(scale, mode::Symbol = :nearest) = Upsample(mode; scale)

(m::Upsample{:nearest})(x::AbstractArray) =
NNlib.upsample_nearest(x, m.scale)
function (m::Upsample{:nearest, Int})(x::AbstractArray{T, N}) where {T, N}
NNlib.upsample_nearest(x, ntuple(i -> m.scale, N-2))
end
(m::Upsample{:nearest, Nothing})(x::AbstractArray) =
NNlib.upsample_nearest(x; size=m.size)

(m::Upsample{:bilinear})(x::AbstractArray) =
NNlib.upsample_bilinear(x, m.scale)
(m::Upsample{:bilinear, Nothing})(x::AbstractArray) =
NNlib.upsample_bilinear(x; size=m.size)

function Base.show(io::IO, u::Upsample{mode}) where {mode}
print(io, "Upsample(")
print(io, ":", mode)
u.scale !== nothing && print(io, ", scale = $(u.scale)")
u.size !== nothing && print(io, ", size = $(u.size)")
print(io, ")")
end

"""
PixelShuffle(r::Int)
Pixel shuffling layer with upscale factor `r`.
See [`NNlib.pixel_shuffle`](@ref).
"""
struct PixelShuffle
r::Int
end

(m::PixelShuffle)(x) = NNlib.pixel_shuffle(x, m.r)
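
For context, a minimal sketch of how the two new layers compose, e.g. as the upscaling tail of a super-resolution style model (the convolution and its channel counts are purely illustrative, not part of this PR):

```julia
using Flux

# PixelShuffle(r) expects the input to carry r^2 * c channels and rearranges
# them into c channels on an r-times larger spatial grid.
r = 2
head = Chain(
    Conv((3, 3), 1 => 4 * r^2, relu; pad = 1),   # (8, 8, 1, 1)   -> (8, 8, 16, 1)
    PixelShuffle(r),                             # (8, 8, 16, 1)  -> (16, 16, 4, 1)
    Upsample(:bilinear, scale = (2, 2)),         # (16, 16, 4, 1) -> (32, 32, 4, 1)
)

x = rand(Float32, 8, 8, 1, 1)
@assert size(head(x)) == (32, 32, 4, 1)
```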
2 changes: 1 addition & 1 deletion src/outputsize.jl
@@ -35,7 +35,7 @@ Base.isless(::Nil, ::Number) = true
Base.isless(::Number, ::Nil) = true

Base.isnan(::Nil) = false

Base.isfinite(::Nil) = true
Base.typemin(::Type{Nil}) = nil
Base.typemax(::Type{Nil}) = nil

11 changes: 10 additions & 1 deletion test/cuda/layers.jl
@@ -79,6 +79,15 @@ gpu_gradtest("GroupNorm 3d", groupnorm, rand(Float32, 8, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 2d", groupnorm, rand(Float32, 8, 8, 12, 4), 12, 3, setmode=true)
gpu_gradtest("GroupNorm 1d", groupnorm, rand(Float32, 8, 3, 12, 4), 12, 3, setmode=true)

upsample = [x -> Upsample(scale=x)]
gpu_gradtest("Upsample 2d", upsample, rand(Float32, 3, 4, 2, 3), (2,2))
gpu_gradtest("Upsample 1d", upsample, rand(Float32, 3, 4, 2, 3), (2,))

pixelshuffle = [PixelShuffle]
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)


@testset "function layers" begin
x = rand(Float32, 3,3)
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
@@ -168,4 +177,4 @@ end
  @test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
  @test l.b ∉ gs.params
end
end
67 changes: 67 additions & 0 deletions test/layers/upsample.jl
@@ -0,0 +1,67 @@
@testset "upsample bilinear" begin
m = Upsample(:bilinear, scale=(2, 3))
x = rand(Float32, 3, 4, 2, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 12, 2, 3)

m = Upsample(:bilinear, scale=3)
x = rand(Float32, 3, 4, 2, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (9, 12, 2, 3)

m = Upsample(:bilinear, size=(4, 6))
x = rand(Float32, 3, 4, 2, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (4, 6, 2, 3)
end

@testset "upsample nearest" begin
x = rand(Float32, 3, 2, 3)
m = Upsample(:nearest, scale=(2,))
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (6, 2, 3)

x = rand(Float32, 3, 4, 2, 3)

m = Upsample(:nearest, scale=(2, 3))
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 12, 2, 3)

m = Upsample(:nearest, scale=(2,))
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 4, 2, 3)

m = Upsample(:nearest, scale=2)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 8, 2, 3)

m = Upsample(2)
y2 = m(x)
  @test y2 ≈ y

m = Upsample(:nearest, size=(6,8))
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (6, 8, 2, 3)
end

@testset "PixelShuffle" begin
m = PixelShuffle(2)
x = rand(Float32, 3, 18, 3)
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (6, 9, 3)

m = PixelShuffle(3)
x = rand(Float32, 3, 4, 18, 3)
y = m(x)
@test y isa Array{Float32, 4}
@test size(y) == (9, 12, 2, 3)
end
8 changes: 3 additions & 5 deletions test/outputsize.jl
@@ -146,9 +146,7 @@ end
@test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
@test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)

if VERSION >= v"1.1"
m = GroupNorm(16, 4)
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
end
m = GroupNorm(16, 4)
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -34,6 +34,7 @@ end
include("layers/stateless.jl")
include("layers/recurrent.jl")
include("layers/conv.jl")
include("layers/upsample.jl")
end

@testset "outputsize" begin
