Skip to content

Commit

Permalink
WeightNormWeight is now called WeightNormParam
Browse files Browse the repository at this point in the history
WeightNorm for several params, single dim

Test for Scalar and Vector dims

Test newly created WN equality

Simplified some bits

Missing last constructor
  • Loading branch information
bhvieira committed Feb 11, 2020
1 parent 41feb43 commit 266c9af
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 29 deletions.
57 changes: 32 additions & 25 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,13 @@ end
Weight Normalization.
This layer reparametrizes weights (w) of a layer with its decomposition into magnitude (g) and direction (v).
WeightNorm(layer, weight::Union{Symbol,Int}, dim)
WeightNorm(layer, weight, dim)
``layer`` is the layer being normalized.
``weight`` is the parameter to be normalized.
``weight`` are the parameters to be normalized.
``dim`` is the dimension of normalization.
``dim`` are the dimension of normalization.
Often, its the dimension encoding the output channels.
Example:
Expand All @@ -390,55 +390,62 @@ wndB = WeightNorm(d, :W, 1:2); #Now we normalize all directions together, keepin
Link : https://arxiv.org/pdf/1602.07868.pdf
"""

struct WeightNormWeight{T,N,I}
struct WeightNormParam{T,N,I}
g::AbstractArray{T,N}
v::AbstractArray{T,N}
dim::I
end

Base.size(w::WeightNormWeight, i...) = size(w.v, i...)
Base.size(w::WeightNormWeight) = size(w.v)
Base.iterate(w::WeightNormWeight, i...) = iterate(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
Base.getindex(w::WeightNormWeight, i...) = getindex(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
Base.ndims(w::WeightNormWeight) = ndims(w.v)
Base.size(w::WeightNormParam, i...) = size(w.v, i...)
Base.size(w::WeightNormParam) = size(w.v)
Base.iterate(w::WeightNormParam, i...) = iterate(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
Base.getindex(w::WeightNormParam, i...) = getindex(w.g .* w.v ./ WN_mag(w.v, w.dim), i...)
Base.ndims(w::WeightNormParam) = ndims(w.v)
Base.length(w::WeightNormParam) = length(w.v)

Flux.@functor WeightNormWeight
@functor WeightNormParam

WN_mag(p, dim) = sqrt.(sum(abs2.(p), dims = dim))
WN_dir(p, mag, eps) = p ./ (mag .+ eps)
WN_dir(p, mag) = WN_dir(p, mag, eps(eltype(p)))
WN_mag(p, dim, eps) = sqrt.(sum(abs2.(p), dims = dim)) .+ eps
WN_mag(p, dim) = WN_mag(p, dim, eps(eltype(p)))
WN_dir(p, mag) = p ./ mag

import Base.*, Base./, Base.+, Base.-
for f in (:+, :-, :*, :/)
@eval ($f)(z::AbstractArray, w::WeightNormWeight) = ($f)(z, w.g .* w.v ./ WN_mag(w.v, w.dim))
@eval ($f)(w::WeightNormWeight, z::AbstractArray) = ($f)(w.g .* w.v ./ WN_mag(w.v, w.dim), z)
@eval ($f)(z::AbstractArray, w::WeightNormParam) = ($f)(z, w.g .* w.v ./ WN_mag(w.v, w.dim))
@eval ($f)(w::WeightNormParam, z::AbstractArray) = ($f)(w.g .* w.v ./ WN_mag(w.v, w.dim), z)
end

struct WeightNorm{L,E,I,W}
struct WeightNorm{L}
layer::L
eps::E
weight::W
dim::I
eps::Number
weight::Vector
dim::Vector
end

Flux.@functor WeightNorm
@functor WeightNorm

function Base.show(io::IO, wn::WeightNorm)
print(io, "WeightNorm(", wn.layer, ", ", wn.weight, ", ", wn.dim, ")")
end

function WeightNorm(layer, weight::Union{Symbol,Int}, dim)
function WeightNorm(layer, weight::Vector, dim::Vector)
#Expose layer fields and constructor
func, re = Flux.functor(layer)
#Get the fields
par = [getfield(layer, fn) for fn in keys(func)]
w = getfield(layer, weight)
g = WN_mag(w, dim)
v = WN_dir(w, g)
par[findfirst(keys(func) .== weight)] = WeightNormWeight(g, v, dim)
w = map(weight) do W
getfield(layer, W)
end
g = map((W, D) -> WN_mag(W, D), w, dim)
v = map((W, G) -> WN_dir(W, G), w, g)
par[indexin(weight,collect(keys(func)))] = WeightNormParam.(g, v, dim)
return WeightNorm(re(par), eps(Float32), weight, dim)
end

WeightNorm(layer, weight::Symbol, dim::Vector) = WeightNorm(layer, [weight], dim)
WeightNorm(layer, weight::Symbol, dim::Integer) = WeightNorm(layer, [weight], [dim])
WeightNorm(layer, weight::Vector, dim::Integer) = WeightNorm(layer, weight, [dim for _ in axes(weight,1)])

function (wn::WeightNorm)(x)
wn.layer(x)
end
14 changes: 10 additions & 4 deletions test/layers/normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,20 +196,26 @@ end
d = Dense(10, 9, tanh)
gs = gradient(() -> sum(abs2, d(fake_data)), params(d))
W = d.W
for WN_dim in [1, 2, 1:2]
for WN_dim in [[1], 1, [2], 2, [1:2]]
wnd = WeightNorm(d, :W, WN_dim)
gswn = gradient(() -> sum(abs2, wnd(fake_data)), params(wnd))
g = wnd.layer.W.g
v = wnd.layer.W.v
normv = sum(abs2, v, dims = WN_dim)

ΔW = gs[W]
Δg = gswn[g]
Δv = gswn[v]
@test sum(ΔW .* v ./ normv, dims = WN_dim) Δg
@test wnd(fake_data) d(fake_data)
if isa(WN_dim, Int)
normv = sum(abs2, v, dims = WN_dim)
@test sum(ΔW .* v ./ normv, dims = WN_dim) Δg
else
normv = sum(abs2, v, dims = WN_dim[1])
@test sum(ΔW .* v ./ normv, dims = WN_dim[1]) Δg
end
@test g ./ normv .* ΔW - g .* Δg .* v ./ (normv.^2) Δv
@test size(Δv) == size(ΔW)
@test isa(wnd.layer.W, WeightNormWeight)
@test isa(wnd.layer.W, Flux.WeightNormParam)
end
end
end
Expand Down

0 comments on commit 266c9af

Please sign in to comment.