Use var to speed up normalisation #1973

Merged: mcabbott merged 1 commit into FluxML:master from use_var on May 27, 2022
Conversation

@mcabbott (Member) commented May 27, 2022

In this comparison, I think that one line accounts for Flux's extra memory use compared to Lux's version, which uses var. (Perhaps that wasn't supported earlier?) This PR fixes it:

julia> x = rand(rng, Float32, 128, 1000);

julia> @btime gradient(model -> sum(model(x)), model);
  4.906 ms (1305 allocations: 29.19 MiB)  # before
  3.136 ms (1219 allocations: 23.32 MiB)  # after

julia> v, re = destructure(model);

julia> @btime gradient(v -> sum(re(v)(x)), v);
  5.044 ms (1562 allocations: 29.47 MiB)  # before
  3.306 ms (1476 allocations: 23.60 MiB)  # after

Compared to Lux, on the same machine with the same sizes:

julia> @btime gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps);
  4.069 ms (2679 allocations: 23.41 MiB)

julia> ca = ComponentArray(ps);  # to store a flat vector

julia> @btime gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ca);
  4.327 ms (3061 allocations: 24.62 MiB)

julia> v, re = Optimisers.destructure(ps);  # this works too

julia> @btime gradient(v -> sum(Lux.apply(model, x, re(v), st)[1]), v);
  4.093 ms (2847 allocations: 23.69 MiB)
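
For context, the kind of one-line change this makes inside src/layers/normalise.jl looks roughly like the sketch below; reduce_dims stands in for whatever dims the layer reduces over, and this is an illustration rather than the exact diff:

# before: materialises (x .- μ) .^ 2 as a temporary array before reducing it
σ² = mean((x .- μ) .^ 2; dims=reduce_dims)

# after: Statistics.var with a precomputed mean skips that temporary
σ² = var(x; dims=reduce_dims, mean=μ, corrected=false)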

@codecov-commenter commented

Codecov Report

Merging #1973 (a9d5f44) into master (28ee7b4) will not change coverage.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master    #1973   +/-   ##
=======================================
  Coverage   87.94%   87.94%           
=======================================
  Files          19       19           
  Lines        1485     1485           
=======================================
  Hits         1306     1306           
  Misses        179      179           
Impacted Files            Coverage Δ
src/layers/normalise.jl   88.81% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@ToucheSir (Member) left a comment


I think this was even discussed at some point but fell through the cracks. LGTM.

@mcabbott mcabbott merged commit e4f8678 into FluxML:master May 27, 2022
@mcabbott mcabbott deleted the use_var branch May 27, 2022 04:35
@CarloLucibello (Member) commented:

Does the gradient propagate to the keyword argument? I think that was the problem.

@cossio (Contributor) commented May 27, 2022

There don't seem to be any tests checking the gradient of BatchNorm?

@mcabbott (Member, Author) commented:

> Does the gradient propagate to the keyword argument? I think that was the problem.

The gradient of the keyword is correctly zero, I think:

julia> gradient(x3, m3) do x, m
         var(x; mean=m)
       end
([-0.07914716919668174, 0.04399098883554592, 0.035156180361135936], nothing)

julia> ForwardDiff.gradient(x3) do x
         var(x; mean=m3)
       end
3-element Vector{Float64}:
 -0.07914716919668174
  0.04399098883554592
  0.035156180361135936

julia> ForwardDiff.derivative(m3) do m
         var(x3; mean=m)
       end
-1.1102230246251565e-16
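
(For completeness, a hypothetical setup under which a check like the one above runs; the original x3 and m3 aren't shown in the thread, and any small vector with m3 = mean(x3) gives the same near-zero derivative for the mean keyword:)

julia> using Statistics, Zygote, ForwardDiff

julia> x3 = randn(3);  # hypothetical stand-in for the x3 used above

julia> m3 = mean(x3);  # derivative of var(x3; mean=m3) w.r.t. m3 should be ≈ 0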

Xref FluxML/Zygote.jl#478

@mcabbott (Member, Author) commented:

You could save a few more copies here by making some μ, σ² = mean_var(x) whose gradient rule computes both at once. Or a rule for _norm_layer_forward(x, μ, σ², ϵ), or better, fuse these two.
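
A rough sketch of what such a fused mean_var could look like, with a hand-written reverse rule so that both gradients come out of one pass. The name mean_var, the dims handling, and the ChainRulesCore rule are assumptions for illustration, not code from this PR:

using Statistics, ChainRulesCore

# Fused mean and (uncorrected) variance, as the norm layers use them.
function mean_var(x; dims=:)
    μ = mean(x; dims=dims)
    σ² = mean(abs2.(x .- μ); dims=dims)
    return μ, σ²
end

function ChainRulesCore.rrule(::typeof(mean_var), x; dims=:)
    μ = mean(x; dims=dims)
    diff = x .- μ
    σ² = mean(abs2.(diff); dims=dims)
    n = length(x) ÷ length(μ)  # elements reduced per statistic
    function mean_var_pullback(ȳ)
        dμ, dσ² = ȳ[1], ȳ[2]
        # ∂μ/∂x = 1/n, and since Σ(x - μ) = 0, ∂σ²/∂x = 2(x - μ)/n.
        # (Real code would also handle ZeroTangent cotangents here.)
        dx = @. dμ / n + 2 * diff * dσ² / n
        return (NoTangent(), dx)
    end
    return (μ, σ²), mean_var_pullback
end

This reuses diff = x .- μ for both the variance and its pullback, which is where the extra copies would be saved.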

@cossio (Contributor) commented May 27, 2022

The centered second moment <(x - m)^2> has a minimum when m coincides with the mean, m = <x>, so the gradient is correctly zero in this case.
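
Spelled out in the same notation: d/dm <(x - m)^2> = -2<x - m> = -2(<x> - m), which vanishes exactly at m = <x>. So the zero gradient for the mean keyword, and the ≈ 0 ForwardDiff.derivative above, are expected rather than a bug.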

@CarloLucibello (Member) commented:

I didn't check, but maybe we have wrong second derivatives?

@mcabbott (Member, Author) commented:

Here's a gist: https://gist.github.com/mcabbott/57befcf926b839e5e528ace38f018a66

tl;dr: 2nd derivatives are fine when you compute the mean one line above. If you supply the mean from completely outside, then my head hurts; it's some overparameterised 2nd-order tangent story.

@CarloLucibello (Member) commented:

OK, thanks. I don't understand why, but we are good.
