
Added WeightNorm #1005

Closed · wants to merge 8 commits

Conversation

@bhvieira (Contributor) commented on Jan 25, 2020

In light of #993, I wanted to create a WeightNorm constructor that could operate over any layer, param, and dim.

This proved harder than I thought, but with the help of @chengchingwen and @mcabbott I was finally able to make it work.

The execution is simply this (a rough sketch follows the list):

  1. User calls WeightNorm on a layer
  2. We catch the param they want to normalize, and keep its direction and magnitude in a new type, WeightNormWeight
  3. We redefine several operations over this type (so it behaves like an Array), while Zygote transparently sees that it's actually a direction and a magnitude
  4. We reconstruct the layer using the functor functionality with the substitution in place
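
A minimal sketch of the idea behind steps 2 and 3 (not the PR's exact code; materialize is just an illustrative helper name):

# Sketch only: store the weight as magnitude `g` and direction `v`, and
# recompute the effective weight g .* v ./ ‖v‖ wherever the layer uses it.
struct WeightNormWeight
    g     # magnitude along `dim`
    v     # unnormalised direction
    dim   # dimension(s) the norm is taken over
end

WN_mag(v, dim) = sqrt.(sum(abs2.(v), dims = dim))

materialize(w::WeightNormWeight) = w.g .* w.v ./ WN_mag(w.v, w.dim)

# e.g. Dense would then hit a method like this instead of plain matmul:
Base.:*(w::WeightNormWeight, x::AbstractArray) = materialize(w) * x

With something like that in place, Zygote sees g and v as ordinary parameters while the layer code keeps treating the weight as an array.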

The catch is that it is not ready yet, though:

- [ ] Make it work with NNlib.conv (I really want to make this work, but I've noticed it's not only Conv that will be laborious to fix: simple things like RNNs will also take some iteration to get working, simply because they hide their params inside cells)

- [ ] Remove dead code (I'm pretty sure some of what I did there shouldn't be needed, especially regarding the Base array operations, but I couldn't make WeightNormWeight a subtype of AbstractArray either)
- [ ] Make it accept more than one weight at a time, potentially with differing dims as well
- [ ] Add tests

So here's a first try, at least. I could throw an error for Conv, but I want to make it work, and perhaps it's better to wait until everything is right.

Inputs welcome!

@bhvieira changed the title from "Added WeightNorm and export" to "Added WeightNorm" on Jan 25, 2020
@bhvieira (Contributor, Author) commented

Conv is proving harder than the other layers. I'll have to investigate further, but if anyone has a clue, here's where it fails to produce grads.

import NNlib: conv, ∇conv_data, depthwiseconv, ConvDims, DenseConvDims

# DenseConvDims only needs the weight's shape; conv and ∇conv_data get the
# reconstructed weight g .* v ./ ‖v‖.
DenseConvDims(x::AbstractArray, w::WeightNormWeight; kwargs...) = DenseConvDims(x, w.g .* w.v; kwargs...)
conv(x::AbstractArray, w::WeightNormWeight, cdims::ConvDims; kwargs...) = conv(x, w.g .* w.v ./ Flux.WN_mag(w.v, w.dim), cdims; kwargs...)
∇conv_data(x::AbstractArray, w::WeightNormWeight, cdims::ConvDims; kwargs...) = ∇conv_data(x, w.g .* w.v ./ Flux.WN_mag(w.v, w.dim), cdims; kwargs...)

c = Conv((3,3), 1=>3, tanh);
c2 = WeightNorm(c, :weight, 4)
c_fake_data = randn(Float32, 5,5,1,5);
c2(c_fake_data) ≈ c(c_fake_data) #true

gs = gradient(() -> sum(c2(c_fake_data)), params(c2));
gs[c2.layer.weight.g] #empty
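
A quick diagnostic (untested sketch) to narrow this down would be to check whether the direction v receives a gradient while g stays empty:

# Sketch only: gs.grads is the underlying IdDict, so get avoids a KeyError
# for parameters that were never tracked. Assumes v is stored on the wrapper like g.
gs = gradient(() -> sum(c2(c_fake_data)), params(c2));
get(gs.grads, c2.layer.weight.g, nothing)  # reported empty above
get(gs.grads, c2.layer.weight.v, nothing)  # does the direction get a gradient?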

@CarloLucibello (Member) commented on Feb 6, 2020

I don't know why convs are not working with your code, but this alternative solution should work. Essentially the approach is very similar, except that I reconstruct the original layer at each forward pass:

struct WeightNorm
  layer
  g
  v
  weight::Symbol
  dims
  eps 
end

WN_mag(p, dims) = sqrt.(sum(abs2.(p), dims = dims))
WN_dir(p, mag, eps=eps(eltype(p))) = p ./ (mag .+ eps)
WN_reconstr(wn::WeightNorm) = wn.g .* wn.v ./ WN_mag(wn.v, wn.dims)

function WeightNorm(layer, weight::Union{Symbol,Int}; dims)
    #Expose layer fields and constructor
    func, re = Flux.functor(layer)
    #Get the fields
    w = getfield(layer, weight)
    g = WN_mag(w, dims)
    v = WN_dir(w, g)
    
    # Reconstruct the layer, swapping w for v (let's not waste memory)
    replace(name) = name == weight ?  v : getfield(layer, name)
    par = [replace(name) for name in keys(func)]
    WeightNorm(re(par), g, v, weight, dims, eps(Float32))
end

function (wn::WeightNorm)(x)
    func, re = Flux.functor(wn.layer)
    w = WN_reconstr(wn)
    replace(name) = name == wn.weight ?  w : getfield(wn.layer, name)
    par = [replace(name) for name in keys(func)]
    re(par)(x)
end 

Flux.@functor WeightNorm

This approach seems simpler: we don't have to define custom arrays.
The linear layer seems fine, although I didn't check that the grads are correct.

julia> m = Flux.Dense(2,3)
Dense(2, 3)

julia> wn = FluxDNN.WeightNorm(m, :W, dims=1)
FluxDNN.WeightNorm(Dense(2, 3), Float32[1.1312975 0.5932242], Float32[-0.2702223 -0.22867282; -0.23115203 -0.31670466; 0.93463814 0.9205468], :W, 1, 1.1920929f-7)

## FORWARDS ARE THE SAME

julia> m(ones(2))
3-element Array{Float32,1}:
 -0.44135612
 -0.44937864
  1.6034446 

julia> wn(ones(2))
3-element Array{Float32,1}:
 -0.44135612
 -0.44937867
  1.6034446 

 #### BACKWARDS 

julia> Flux.gradient(params(m)) do
          sum(m(ones(2)).^2)
       end.grads
IdDict{Any,Any} with 4 entries:
  RefValue{typeof(^)}(^)       => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}())   => RefValue{Any}((x = nothing,))
  Float32[-0.305702 -0.135654 => Float32[-0.882712 -0.882712; -0.898757 -0.898757; 3.20
  Float32[0.0, 0.0, 0.0]       => Float32[-0.882712, -0.898757, 3.20689]

julia> Flux.gradient(params(wn)) do
          sum(wn(ones(2)).^2)
       end.grads
IdDict{Any,Any} with 5 entries:
  RefValue{typeof(^)}(^)       => RefValue{Any}((x = nothing,))
  Float32[1.1313 0.593224]     => Float32[3.44356 3.43859]
  Float32[-0.270222 -0.228673 => Float32[0.0540923 -0.0571875; -0.116265 0.112866; -0.0
  RefValue{Val{2}}(Val{2}())   => RefValue{Any}((x = nothing,))
  Float32[0.0, 0.0, 0.0]       => Float32[-1.76542, -1.79751, 6.41378]
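
To sanity-check those grads, one option (untested sketch) is central finite differences on g, compared against what Zygote returns; wnloss and fd_grad_g are just illustrative names:

# Sketch only: perturb each entry of g, recompute the loss, and compare the
# finite-difference estimate with Zygote's gradient.
using Flux

wnloss(wn, x) = sum(wn(x).^2)

function fd_grad_g(wn, x; δ = 1f-3)
    fd = similar(wn.g)
    for i in eachindex(wn.g)
        orig = wn.g[i]
        wn.g[i] = orig + δ; lp = wnloss(wn, x)
        wn.g[i] = orig - δ; lm = wnloss(wn, x)
        wn.g[i] = orig
        fd[i] = (lp - lm) / (2δ)
    end
    return fd
end

x  = ones(Float32, 2)
gs = Flux.gradient(() -> wnloss(wn, x), Flux.params(wn))
isapprox(gs[wn.g], fd_grad_g(wn, x), rtol = 1e-2)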

@CarloLucibello (Member) commented

Convolutions also seem fine. I didn't test on the GPU yet.

julia> m = Flux.Conv((2,2), 2=>3)
Conv((2, 2), 2=>3)

julia> wn = FluxDNN.WeightNorm(m, :weight, dims=1);

julia> m(ones(3,3,2,2))
2×2×3×2 Array{Float64,4}:
[:, :, 1, 1] =
 0.0310524  0.0310524
 0.0310524  0.0310524

[:, :, 2, 1] =
 0.676353  0.676353
 0.676353  0.676353

[:, :, 3, 1] =
 -0.733513  -0.733513
 -0.733513  -0.733513

[:, :, 1, 2] =
 0.0310524  0.0310524
 0.0310524  0.0310524

[:, :, 2, 2] =
 0.676353  0.676353
 0.676353  0.676353

[:, :, 3, 2] =
 -0.733513  -0.733513
 -0.733513  -0.733513

julia> wn(ones(3,3,2,2))
2×2×3×2 Array{Float64,4}:
[:, :, 1, 1] =
 0.0310524  0.0310524
 0.0310524  0.0310524

[:, :, 2, 1] =
 0.676353  0.676353
 0.676353  0.676353

[:, :, 3, 1] =
 -0.733513  -0.733513
 -0.733513  -0.733513

[:, :, 1, 2] =
 0.0310524  0.0310524
 0.0310524  0.0310524

[:, :, 2, 2] =
 0.676353  0.676353
 0.676353  0.676353

[:, :, 3, 2] =
 -0.733513  -0.733513
 -0.733513  -0.733513

julia> Flux.gradient(params(m)) do
          sum(m(ones(3,3,2,2)).^2)
       end.grads
IdDict{Any,Any} with 4 entries:
  Float32[0.0, 0.0, 0.0]       => [0.496838, 10.8217, -11.7362]
  Float32[0.429876 -0.0354687 => [0.496838 0.496838; 0.496838 0.496838]
  RefValue{typeof(^)}(^)       => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}())   => RefValue{Any}((x = nothing,))

julia> Flux.gradient(params(wn)) do
          sum(wn(ones(3,3,2,2)).^2)
       end.grads
IdDict{Any,Any} with 5 entries:
  Float32[0.653553 -0.241225; => [0.348865 0.0517; 0.301239 -0.0128509]
  RefValue{typeof(^)}(^)       => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}())   => RefValue{Any}((x = nothing,))
  Float32[0.0, 0.0, 0.0]       => [0.993675, 21.6433, -23.4724]
  Float32[0.657753 0.147035]  => [-0.0513368 -0.602016]
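
Regarding the GPU check: a minimal smoke test could look like this (untested sketch; gpu is a no-op when no GPU is available):

# Sketch only: move both the wrapped and the plain layer to the GPU and
# compare forward passes.
using Flux
m  = Flux.Conv((2,2), 2=>3)
wn = FluxDNN.WeightNorm(m, :weight, dims = 1) |> gpu
x  = gpu(ones(Float32, 3, 3, 2, 2))
wn(x) ≈ gpu(m)(x)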

@bhvieira (Contributor, Author) commented on Feb 6, 2020

I thought of something similar, but reconstructing the layer at every call sounded like it could incur a hefty performance penalty. We can benchmark the two implementations, and also use the gradient tests we already have to compare them later.

@CarloLucibello (Member) commented

I don't think constructing an object from precomputed fields is expensive, at least not compared to the weight reconstruction that we have to perform anyway. In any case, let's run some benchmarks.

@CarloLucibello (Member) commented

@bhvieira would you like to implement #1005 (comment) here, or do you want me to open a separate PR?

@bhvieira (Contributor, Author) commented on Feb 9, 2020

> @bhvieira would you like to implement #1005 (comment) here, or do you want me to open a separate PR?

We should really benchmark both before that, but feel free to open a PR against my branch at bhvieira:weightnorm; I really like that model of crediting the work, since that way we both keep authorship of our commits in this PR.

@bhvieira (Contributor, Author) commented on Feb 10, 2020

@CarloLucibello I ran the benchmarks. The backward passes don't look that different (only about 2 times slower), but your forward pass is around 5 times slower than mine (for comparison, mine is already about 3 times slower than the plain layer without WeightNorm).


using Flux

struct WeightNorm2
    layer
    g
    v
    weight::Symbol
    dims
    eps 
end

WN_mag(p, dims) = sqrt.(sum(abs2.(p), dims = dims))
WN_dir(p, mag, eps=eps(eltype(p))) = p ./ (mag .+ eps)
WN_reconstr(wn::WeightNorm2) = wn.g .* wn.v ./ WN_mag(wn.v, wn.dims)

function WeightNorm2(layer, weight::Union{Symbol,Int}; dims)
    #Expose layer fields and constructor
    func, re = Flux.functor(layer)
    #Get the fields
    w = getfield(layer, weight)
    g = WN_mag(w, dims)
    v = WN_dir(w, g)
    
    # Reconstruct the layer, swapping w for v (let's not waste memory)
    replace(name) = name == weight ?  v : getfield(layer, name)
    par = [replace(name) for name in keys(func)]
    WeightNorm2(re(par), g, v, weight, dims, eps(Float32))
end

function (wn::WeightNorm2)(x)
    func, re = Flux.functor(wn.layer)
    w = WN_reconstr(wn)
    replace(name) = name == wn.weight ?  w : getfield(wn.layer, name)
    par = [replace(name) for name in keys(func)]
    re(par)(x)
end 

Flux.@functor WeightNorm2

using BenchmarkTools

m = Flux.Dense(2,3);
wn = Flux.WeightNorm(m, :W, 1);
wn2 = WeightNorm2(m, :W, dims=1);

@benchmark m(x) setup=(x=randn(Float32, 2, 100))
@benchmark wn(x) setup=(x=randn(Float32, 2, 100))
@benchmark wn2(x) setup=(x=randn(Float32, 2, 100))

@benchmark gradient(() -> sum(abs2.(m(x))), Flux.params(m)) setup=(x=randn(Float32, 2, 100))
@benchmark gradient(() -> sum(abs2.(wn(x))), Flux.params(wn)) setup=(x=randn(Float32, 2, 100))
@benchmark gradient(() -> sum(abs2.(wn2(x))), Flux.params(wn2)) setup=(x=randn(Float32, 2, 100))

# julia> @benchmark m(x) setup=(x=randn(Float32, 2, 100))
# BenchmarkTools.Trial: 
#   memory estimate:  2.66 KiB
#   allocs estimate:  2
#   --------------
#   minimum time:     1.910 μs (0.00% GC)
#   median time:      2.230 μs (0.00% GC)
#   mean time:        3.056 μs (5.02% GC)
#   maximum time:     233.120 μs (96.59% GC)
#   --------------
#   samples:          10000
#   evals/sample:     10

# julia> @benchmark wn(x) setup=(x=randn(Float32, 2, 100))
# BenchmarkTools.Trial: 
#   memory estimate:  3.23 KiB
#   allocs estimate:  14
#   --------------
#   minimum time:     5.083 μs (0.00% GC)
#   median time:      6.817 μs (0.00% GC)
#   mean time:        9.644 μs (4.06% GC)
#   maximum time:     1.129 ms (0.00% GC)
#   --------------
#   samples:          10000
#   evals/sample:     6

# julia> @benchmark wn2(x) setup=(x=randn(Float32, 2, 100))
# BenchmarkTools.Trial: 
#   memory estimate:  4.48 KiB
#   allocs estimate:  35
#   --------------
#   minimum time:     26.399 μs (0.00% GC)
#   median time:      29.700 μs (0.00% GC)
#   mean time:        37.198 μs (2.46% GC)
#   maximum time:     4.677 ms (98.98% GC)
#   --------------
#   samples:          10000
#   evals/sample:     1

# julia> @benchmark gradient(() -> sum(abs2.(m(x))), Flux.params(m)) setup=(x=randn(Float32, 2, 100))
# BenchmarkTools.Trial: 
#   memory estimate:  29.67 KiB
#   allocs estimate:  1009
#   --------------
#   minimum time:     42.299 μs (0.00% GC)
#   median time:      47.200 μs (0.00% GC)
#   mean time:        59.509 μs (4.91% GC)
#   maximum time:     3.110 ms (96.45% GC)
#   --------------
#   samples:          10000
#   evals/sample:     1

# julia> @benchmark gradient(() -> sum(abs2.(wn(x))), Flux.params(wn)) setup=(x=randn(Float32, 2, 100))
# BenchmarkTools.Trial: 
#   memory estimate:  35.73 KiB
#   allocs estimate:  1193
#   --------------
#   minimum time:     86.400 μs (0.00% GC)
#   median time:      97.000 μs (0.00% GC)
#   mean time:        143.115 μs (4.00% GC)
#   maximum time:     7.478 ms (93.34% GC)
#   --------------
#   samples:          10000
#   evals/sample:     1

# julia> @benchmark gradient(() -> sum(abs2.(wn2(x))), Flux.params(wn2)) setup=(x=randn(Float32, 2, 100))
# BenchmarkTools.Trial: 
#   memory estimate:  41.69 KiB
#   allocs estimate:  1364
#   --------------
#   minimum time:     195.599 μs (0.00% GC)
#   median time:      230.299 μs (0.00% GC)
#   mean time:        331.383 μs (2.41% GC)
#   maximum time:     6.263 ms (85.39% GC)
#   --------------
#   samples:          10000
#   evals/sample:     1

@bhvieira (Contributor, Author) commented

The gradients are correct, though, so perhaps it's just a matter of optimizing your solution, since it appears to work with any layer out of the box:

data = randn(Float32, 2, 100);
gs_mine = gradient(() -> sum(abs2.(wn(data))), Flux.params(wn));   # this PR's WeightNorm
gs_prop = gradient(() -> sum(abs2.(wn2(data))), Flux.params(wn2)); # proposed WeightNorm2
gs = gradient(() -> sum(abs2.(m(data))), Flux.params(m));          # plain Dense

# Both implementations agree on the g and v gradients
isapprox(gs_mine[wn.layer.W.g], gs_prop[wn2.g])
isapprox(gs_mine[wn.layer.W.v], gs_prop[wn2.v])

# And the g gradient matches what the chain rule predicts from the plain layer's gradient
ΔW = gs[m.W];
Δg = gs_prop[wn2.g];
v = wn2.v;
normv = WN_mag(wn2.v, wn2.dims);
WN_dim = wn2.dims;

sum(ΔW .* v ./ normv, dims = WN_dim) ≈ Δg
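
For reference, that last check is just the chain rule for the reparameterisation W = g ⊙ v / ‖v‖ (with v held fixed):

$$
\frac{\partial L}{\partial g} = \sum_{\text{dims}} \frac{\partial L}{\partial W} \odot \frac{v}{\lVert v \rVert},
$$

which in code is sum(ΔW .* v ./ normv, dims = WN_dim).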

@CarloLucibello (Member) commented

Hmm, I wonder what is causing this overhead. Could you try with larger layers, e.g. m = Flux.Dense(10,10) and m = Flux.Dense(100,100)? Hopefully the overhead with respect to the current PR will decrease.
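
Something along these lines (untested sketch, reusing the WeightNorm2 and WN_* definitions from the benchmark script above):

# Sketch only: repeat the forward-pass benchmark at larger layer sizes.
using Flux, BenchmarkTools

for n in (10, 100)
    m   = Flux.Dense(n, n)
    wn2 = WeightNorm2(m, :W, dims = 1)
    x   = randn(Float32, n, 100)
    println("n = ", n)
    @btime $m($x)
    @btime $wn2($x)
end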

@bhvieira (Contributor, Author) commented

The new tests should pass, but I identified a big problem just now: WeightNorm refuses to work with the bias parameter of Dense.

@bhvieira (Contributor, Author) commented

I'll try that later when I find some time @CarloLucibello

Commits added:

- WeightNorm for several params, single dim
- Test for Scalar and Vector dims
- Test newly created WN equality
- Simplified some bits
- Missing last constructor
@bhvieira (Contributor, Author) commented

Can't make the following example work no matter what I try:

fake_data = randn(Float32, 10,3)
d = Dense(10, 9, tanh)
wnd = WeightNorm(d, [:b], [1])
d(fake_data) ≈ wnd(fake_data) #true
gs = gradient(() -> sum(abs2.(wnd(fake_data))), params(wnd)); #error

I can envision what could cause WeightNorm to fail, but I have no clue where the mutation occurs.
Full stacktrace:

julia> gs = gradient(() -> sum(abs2.(wnd(fake_data))), params(wnd));
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::getfield(Zygote, Symbol("##999#1000")))(::Nothing) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\lib\array.jl:49
 [3] (::getfield(Zygote, Symbol("##2680#back#1001")){getfield(Zygote, Symbol("##999#1000"))})(::Nothing) at C:\Users\bhebl\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [4] copyto! at .\abstractarray.jl:649 [inlined]
 [5] (::typeof(∂(copyto!)))(::Array{Float32,1}) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [6] _collect at .\array.jl:563 [inlined]
 [7] (::typeof(∂(Base._collect)))(::Array{Float32,1}) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [8] collect at .\array.jl:557 [inlined]
 [9] (::typeof(∂(collect)))(::Array{Float32,1}) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [10] broadcastable at .\broadcast.jl:617 [inlined]
 [11] (::typeof(∂(Base.Broadcast.broadcastable)))(::Array{Float32,1}) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [12] broadcasted at .\broadcast.jl:1172 [inlined]
 [13] Dense at C:\Users\bhebl\Documents\GitHub\Flux.jl\src\layers\basic.jl:102 [inlined]
 [14] (::typeof(∂(invoke)))(::Array{Float32,2}) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [15] WeightNorm at C:\Users\bhebl\Documents\GitHub\Flux.jl\src\layers\basic.jl:113 [inlined]
 [16] (::typeof(∂(λ)))(::Array{Float32,2}) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [17] (::typeof(∂(getfield(Main, Symbol("##5#6"))())))(::Float32) at .\REPL[29]:1
 [18] (::getfield(Zygote, Symbol("##38#39")){Zygote.Params,Zygote.Context,typeof(∂(getfield(Main, Symbol("##5#6"))()))})(::Float32) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface.jl:101
 [19] gradient(::Function, ::Zygote.Params) at C:\Users\bhebl\.julia\packages\Zygote\oMScO\src\compiler\interface.jl:47      
 [20] top-level scope at none:0
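
Judging from frames [10]-[13], the error comes from Base.Broadcast.broadcastable collecting the wrapped bias inside Dense's broadcast, and Zygote can't differentiate through that copyto!. An untested idea for a workaround would be to hand broadcasting the reconstructed array directly:

# Untested sketch: let broadcasting see a plain array so Zygote never hits
# the collect/copyto! path. Reuses the wrapper fields g, v, dim.
Base.Broadcast.broadcastable(w::WeightNormWeight) =
    w.g .* w.v ./ Flux.WN_mag(w.v, w.dim)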

@bhvieira (Contributor, Author) commented

I'm closing this PR; I haven't advanced on it and I'm out of time these days.

@mcabbott mentioned this pull request on Sep 11, 2022