diff --git a/src/Flux.jl b/src/Flux.jl
index 9969b32346..6a987fdfde 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,7 +11,7 @@ export gradient
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, fmap, cpu, gpu, f32, f64
+       WeightNorm, WeightNormParam, SkipConnection, params, fmap, cpu, gpu, f32, f64

 include("optimise/Optimise.jl")
 using .Optimise
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index b421d3e7f0..ba03ac2250 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -365,3 +365,91 @@ function Base.show(io::IO, l::GroupNorm)
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
 end
+
+struct WeightNormParam{T,N,I,A<:AbstractArray{T,N}}
+  g::A
+  v::A
+  dim::I
+end
+
+#The reparametrized weight, recomputed on demand: w = g .* v ./ ‖v‖
+WN_reparam(w::WeightNormParam) = w.g .* w.v ./ WN_mag(w.v, w.dim)
+
+Base.size(w::WeightNormParam, i...) = size(w.v, i...)
+Base.size(w::WeightNormParam) = size(w.v)
+Base.iterate(w::WeightNormParam, i...) = iterate(WN_reparam(w), i...)
+Base.getindex(w::WeightNormParam, i...) = getindex(WN_reparam(w), i...)
+Base.ndims(w::WeightNormParam) = ndims(w.v)
+Base.length(w::WeightNormParam) = length(w.v)
+
+@functor WeightNormParam
+
+WN_mag(p, dim, ε) = sqrt.(sum(abs2.(p), dims = dim)) .+ ε
+WN_mag(p, dim) = WN_mag(p, dim, eps(eltype(p)))
+WN_dir(p, mag) = p ./ mag
+
+import Base: +, -, *, /
+for f in (:+, :-, :*, :/)
+  @eval ($f)(z::AbstractArray, w::WeightNormParam) = ($f)(z, WN_reparam(w))
+  @eval ($f)(w::WeightNormParam, z::AbstractArray) = ($f)(WN_reparam(w), z)
+end
+
+"""
+    WeightNorm(layer, weight, dim)
+
+Weight normalization. Reparametrizes the weights `w` of a layer as the product
+of a magnitude `g` and a direction `v ./ ‖v‖`, so that `w = g .* v ./ ‖v‖`.
+WeightNorm is currently implemented only for `Dense` layers in Flux.
+
+`layer` is the layer being normalized.
+
+`weight` names the parameter(s) to be normalized (a `Symbol` or a `Vector` of `Symbol`s).
+
+`dim` is the dimension (or dimensions) along which to normalize.
+Often it is the dimension encoding the output channels.
+
+Example:
+```julia
+d = Dense(10, 9, tanh);
+wndA = WeightNorm(d, :W, 2);     #d.W is now normalized along the second dimension, i.e. one magnitude per output channel
+wndB = WeightNorm(d, :W, [1:2]); #normalize over both dimensions together, keeping a single magnitude
+```
+
+Reference: https://arxiv.org/pdf/1602.07868.pdf
+"""
+struct WeightNorm{L}
+  layer::L
+  eps::Number
+  weight::Vector
+  dim::Vector
+end
+
+@functor WeightNorm
+
+function Base.show(io::IO, wn::WeightNorm)
+  print(io, "WeightNorm(", wn.layer, ", ", wn.weight, ", ", wn.dim, ")")
+end
+
+function WeightNorm(layer, weight::Vector, dim::Vector)
+  isa(layer, Dense) || error("WeightNorm is defined only for Dense layers!")
+  #Expose the layer's fields and its reconstructor
+  func, re = Flux.functor(layer)
+  par = [getfield(layer, fn) for fn in keys(func)]
+  #Split each named weight into magnitude g and direction v
+  w = [getfield(layer, W) for W in weight]
+  g = map((W, D) -> WN_mag(W, D), w, dim)
+  v = map((W, G) -> WN_dir(W, G), w, g)
+  par[indexin(weight, collect(keys(func)))] = WeightNormParam.(g, v, dim)
+  return WeightNorm(re(par), eps(Float32), weight, dim)
+end
+
+WeightNorm(layer, weight::Symbol, dim::Vector) = WeightNorm(layer, [weight], dim)
+WeightNorm(layer, weight::Symbol, dim::Integer) = WeightNorm(layer, [weight], [dim])
+WeightNorm(layer, weight::Vector, dim::Integer) = WeightNorm(layer, weight, fill(dim, length(weight)))
+
+(wn::WeightNorm)(x) = wn.layer(x)
\ No newline at end of file
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 4399a25608..fa060a04fc 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -191,6 +191,35 @@
 end
 end
 
+@testset "WeightNorm" begin
+  let fake_data = randn(Float32, 10, 3)
+    d = Dense(10, 9, tanh)
+    gs = gradient(() -> sum(abs2, d(fake_data)), params(d))
+    W = d.W
+    for WN_dim in [[1], 1, [2], 2, [1:2]]
+      wnd = WeightNorm(d, :W, WN_dim)
+      gswn = gradient(() -> sum(abs2, wnd(fake_data)), params(wnd))
+      g = wnd.layer.W.g
+      v = wnd.layer.W.v
+
+      ΔW = gs[W]
+      Δg = gswn[g]
+      Δv = gswn[v]
+      #The wrapped layer must compute the same forward pass
+      @test wnd(fake_data) ≈ d(fake_data)
+      #Normalization dims: either a plain Int or the single entry of the dim vector
+      dims = WN_dim isa Int ? WN_dim : WN_dim[1]
+      #‖v‖ along the normalization dimension(s)
+      normv = sqrt.(sum(abs2, v, dims = dims))
+      #Chain rule through w = g .* v ./ ‖v‖
+      @test sum(ΔW .* v ./ normv, dims = dims) ≈ Δg
+      @test g ./ normv .* ΔW .- g .* Δg .* v ./ (normv .^ 2) ≈ Δv
+      @test size(Δv) == size(ΔW)
+      @test isa(wnd.layer.W, Flux.WeightNormParam)
+    end
+  end
+end
+
 if VERSION >= v"1.1"
 @testset "GroupNorm" begin
 # begin tests
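For reviewers checking the new testset: the asserted gradient identities are just the chain rule through the reparametrization. A sketch in the tests' own notation (this note is not part of the patch; `dims` denotes the normalized dimension(s) and `ΔW` the gradient of the loss with respect to the plain weight):

```julia
# w  = g .* v ./ ‖v‖                     (‖v‖ taken along dims)
# Δg = sum(ΔW .* v ./ ‖v‖, dims = dims)
# Δv = g ./ ‖v‖ .* ΔW .- g .* Δg .* v ./ (‖v‖ .^ 2)
```

Since the constructor sets `v = W ./ ‖W‖`, we have `‖v‖ ≈ 1` right after construction.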
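As a usage note (also not part of the patch), a minimal sketch of how the new wrapper composes with a stock Flux training loop on this branch; the data, loss, and optimiser settings here are made up for illustration:

```julia
using Flux

d   = Dense(10, 9, tanh)
wnd = WeightNorm(d, :W, 2)    #one magnitude g per output channel

x = randn(Float32, 10, 32)    #hypothetical batch of 32 inputs
y = randn(Float32, 9, 32)     #hypothetical targets

loss(x, y) = sum(abs2, wnd(x) .- y)

#g and v are exposed through @functor, so params and train! work unchanged
opt = Descent(0.01)
Flux.train!(loss, params(wnd), [(x, y)], opt)
```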