Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add norm functions #452

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open

Add norm functions #452

wants to merge 6 commits into from

Conversation

ToucheSir
Copy link
Member

These roughly correspond to Flux's *Norm layers.

Constitutes the bulk of the work in #19. Dropout is a different beast and deserving of a separate discussion/its own PR. In the meantime, this should give NNlib{CUDA,ROC} a common interface to implement and allow Flux to start divesting its own norm function helpers.

Design Notes

The affine parameters are included as part of the function signature here because backends like cuDNN can fuse them into the main layer computation. This differs from Flux's use of Scale in LayerNorm, so how do we best reconcile the two? Activation functions are not included of this API for a similar reason (backends don't handle them).

Another thing not addressed in this PR is alternative memory layouts/dimension orderings such as channels-first. Hopefully once we figure out a way to represent inputs with different layouts, we can add internal extension points that dispatch on those. For now, everything is assumed to be spatial dims... x channels [x batch].

Relatedly, juggling dims is an absolute pain and both the most painful and time-consuming part of developing this PR. I wish there was a way to use named dims while not requiring users or backend implementors to do the same. Something for future work.

PR Checklist

  • Tests are added
  • Documentation, if applicable

@github-actions
Copy link
Contributor

github-actions bot commented Jan 3, 2023

Once the documentation build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/NNlib.jl/previews/PR452/

@ToucheSir ToucheSir force-pushed the bc/norm-functions branch 3 times, most recently from 75b42c3 to 3a52deb Compare January 3, 2023 04:00
Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some small comments below from a first pass.

But I also have a gist where I was fiddling with some approaches, and can compare them:

https://gist.github.com/mcabbott/6154bb78b735e8f0a9348767a7d59c86

The tl;dr is that trying to fuse the gradient got me 1.45x faster & 29% the memory usage of Flux's approach. This PR gets 88%. Is it worth trying to bolt something like that on?

The CUDA routine for batchnorm seems to do even better, I don't know how. But I may be measuring it wrong... I did not really think about the batchnorm case. That seems to be a lot of the complexity here.

src/normalization.jl Outdated Show resolved Hide resolved
Comment on lines 130 to 134
if ChainRulesCore.is_inplaceable_destination(running_mean)
stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
else
stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might it be simpler to just demand that these be mutable?

Copy link
Member Author

@ToucheSir ToucheSir Jan 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was this gets us SArray support for relatively cheap. Granted I didn't actually test with StaticArrays, so coverage might not be happy.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get this in abstract, but I think it's quite a bit of complexity, and there would have to be a real gain over just storing an MArray (and ideally someone who wants this).

I'd have to think what happens if an array has wider type than expected... it will be down-converted?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asking for parameter types to support similar so that Flux can always construct a RunningStats doesn't seem unreasonable. I also haven't throught about eltype mismatches.

error("both scale and bias must be provided or left as nothing")
end
scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will pay to separate out the inv & sqrt on N things, instead of calling them N^2 times. This was the major forward speedup I found.

@ToucheSir
Copy link
Member Author

ToucheSir commented Jan 3, 2023

My thought was that this PR (and any future functions which advertise ::AbstractArray) should be general enough to work on all array types. We then add overloads for native Arrays (in-repo and in extensions), GPU arrays and others (e.g. StaticArrays?) as applicable. So it's correct to say there's plenty of perf left on the table (I don't do a tenth of the tricks in https://github.com/chengchingwen/NeuralAttentionlib.jl/blob/b418c0d2a9e99c960e88879a5fd879d47d8e4c22/src/functional/layernorm.jl, for example), but generality comes first.

@mcabbott
Copy link
Member

mcabbott commented Jan 3, 2023

Ah I had not seen that. It looks optimised but I don't immediately see what it's actually doing! On the CPU it seems to give correct answers but is not especially efficient. However, I see some multi-arg mapreduce which I know hits a fallback defn, so I assume it's aimed at GPU arrays.

The counter to trying to bolt on speed afterwards is that it may want different organisation. I haven't absorbed how this PR structures things, but for me, absorbing the mean & var pullbacks completely seemed to be important.

@ToucheSir
Copy link
Member Author

ToucheSir commented Jan 3, 2023

It's likely optimized routines will want to replace everything and just fulfill the high-level *norm interface. The helper functions that do exist are meant to be internal, but implementors do have the choice of overriding specific parts. I suspect most will opt for the full replacement route, which is why functions like norm_helper don't have rrules defined and no equivalents to ∇conv_* are present—there aren't good one-size-fits-all solutions for either. To illustrate the extremes of the spectrum, cuDNN/oneDNN may want everything fused while XLA/LibTorch want everything decomposed into primitive ops internally.

@chengchingwen
Copy link
Member

Ah I had not seen that. It looks optimised but I don't immediately see what it's actually doing! On the CPU it seems to give correct answers but is not especially efficient. However, I see some multi-arg mapreduce which I know hits a fallback defn, so I assume it's aimed at GPU arrays.

It's focusing on GPU arrays. The idea is to use the GPUArrays' broadcast/mapreduce kernel for simple kernel fusion and reducing gpu allocation. So like in the layer norm forward funciton:

function layer_norm(epsilon, alpha, beta, x)
    # [...]
    sum_sum2 = mapreduce(_x_x2, .+, x; dims=1, init = (zero(T), zero(T)))
    return fma.(α, _normalize.(convert(T, 1//N), ϵ, x, sum_sum2), β)
end

This only take two gpu kernel to compute the layer norm, and only 1 intermediate array is created (sum_sum2). Then we do the same in the pullback function. It's not the most optimized implementation since it's accessing the reduced array from global memory, but should be a reasonable trade-off.

@mcabbott
Copy link
Member

mcabbott commented Jan 3, 2023

Cool. I've updated my gist above to time things... this is indeed fast.

On the GPU, I discover that var(x; mean, dims) seems to cost an entire copy. Despite that, the gradients for layer_norm and normal_new seem to be tied in memory alloc, suggesting there's one copy which could be avoided.

On the CPU it does some work N^2 times, but the cost of that can be largely solved by @fastmath sqrt, max.

What all this means for this PR, I don't know. I suppose that I think the basic implementation should (1) try to be efficient without doing anything exotic, Array and CuArray, (2) must not error on SArray, FillArray, etc, but little point optimising that, and (3) try to line up with cuDNN etc. routines to make such overloads easy, and perhaps try to make easy to hook on other overloads like using LoopVectorization.

Whether this mapreduce trick to avoid calling mean, var is too exotic, I don't know. Zygote would hate it but shouldn't see. It doesn't mean that there shouldn't be a function norm_stats, as BatchNorm wants to store that, but the fast path isn't then a fast version of that function at all.

@chengchingwen
Copy link
Member

Thanks for timing it. Were the reported numbers for layer_norm on gpu obtained with or without @fastmath?

I don't know much about fastmath, but I vaguely remember it is considered evil?

@mcabbott
Copy link
Member

mcabbott commented Jan 3, 2023

Not 100% sure but I think it made no difference on GPU. Weird beasts but computation often seems free compared to memory anyway.

Not sure I have an opinion on how evil it is. Here it seem pretty safe, it may fail to propagate NaN via the variance, but you'll get it again from x itself. But I didn't check very carefully.

@ToucheSir
Copy link
Member Author

ToucheSir commented Jan 4, 2023

If one really wants to go wild with optimizations, it's possible to fuse the computation of sum_sum2 into the forward pass as well and write out to mean + var/inverse std as you go along. https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html does this and is the fastest LayerNorm implementation I know of.

It doesn't mean that there shouldn't be a function norm_stats, as BatchNorm wants to store that, but the fast path isn't then a fast version of that function at all.

norm_stats is a kludge and wouldn't exist in an ideal world. The code reuse is nice, but its main purpose is to support maybe_norm_stats, which only exists to make the pullback for batch/instancenorm somewhat type stable. And we only care about type stability because Zygote gets cranky if a pullback in a deeply nested model isn't type stable. Indeed, a lot of code in this PR was written mainly to keep AD happy.

@chengchingwen
Copy link
Member

If one really wants to go wild with optimizations, it's possible to fuse the computation of sum_sum2 into the forward pass as well and write out to mean + var/inverse std as you go along. https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html does this and is the fastest LayerNorm implementation I know of.

The issue of that approach is that we would need to go down to kernel programming with either CUDA.jl or KernelAbstraction.jl. And the input size would also affect the performance, so in Pytorch they implement multiple kernel and dispatch with the input type and size. I guess part of the success of triton also comes from their compiler?

@ToucheSir
Copy link
Member Author

It would be interesting to know, because the kernels are so much shorter too. Wonder what performance we'd get with Julia versions. A Triton-like library would be a nice force multiplier in this ecosystem.

@mcabbott
Copy link
Member

mcabbott commented Jan 4, 2023

I think this sum_sum2 trick has catastrophic cancellation problems. Not sure if this matters in practice, but:

julia> data32 = rand(Float32, 100) .+ 10^4 |> cu;

julia> normal_now(data32) |> mean  # using mean, var, like Flux
-0.0031580066f0

julia> layer_norm(nothing, nothing, data32) |> mean
-87.890625f0

julia> hcat(normal_now(data32), layer_norm(nothing, nothing, data32))
100×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -1.28074    -35644.5
  1.02109     28418.0
 -1.64216    -45703.1
 -1.19302    -33203.1
  0.635108    17675.8
...

On CPU without @fastmath this gives an error from sqrt(negative).

@chengchingwen
Copy link
Member

That doesn't look good. Maybe it's worth switching to Welford algo for that.

@mcabbott
Copy link
Member

mcabbott commented Jan 4, 2023

It sounds like that's the done thing. Can it use mapreduce or do you need lower-level things?

@chengchingwen
Copy link
Member

It should be doable with mapreduce, just replace .+ with the correct update rule. I didn't do that because at the beginning I thought if the input is a large array with small values, the division-then-addition in the update rule would introduce more error. But looks like the catastrophic cancellation is more troublesome.

Comment on lines +14 to +33
function norm_stats(x, dims)
μ = mean(x; dims)
σ² = var(x; dims, mean = μ, corrected = false)
return μ, σ²
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this function need not be closely tied to norm, it's just mean-and-var for any purpose you like. Saying what it does might be clearer than saying what it's for.

So I'd propose, for name & signature, following Statistics closely. And returning a NamedTuple?

Suggested change
function norm_stats(x, dims)
μ = mean(x; dims)
σ² = var(x; dims, mean = μ, corrected = false)
return μ, σ²
end
function mean_var(x::AbstractArray; dims=:, corrected::Bool=true)
μ = mean(x; dims)
σ2 = var(x; dims, mean=μ, corrected)
(; mean=μ, var=σ2)
end

I almost wonder if this function should live upstream somewhere... like Statistics?

My current attempt at a one-pass GPU version, which can happily overload this signature, is here:

https://gist.github.com/mcabbott/6154bb78b735e8f0a9348767a7d59c86#file-layer_norm-jl-L59-L82

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason it lives here, is called norm_stats instead of mean_var and doesn't take keyword args are one and the same: we need a Zygote-friendly function that selectively looks up running stats or calculates them. https://github.com/FluxML/NNlib.jl/blob/bc/norm-functions/src/normalization.jl#L108 doesn't really work with kwargs, so ; dims=... is not an option. Thus it'd be a little weird to advertise this function as a general purpose fused mean + var, because it doesn't follow the same interface.

@CarloLucibello
Copy link
Member

What's the status here? Can we delay optimization/generalization to future PRs and focus on approximately porting existing functionality if that's what needed to move on?

@ToucheSir ToucheSir mentioned this pull request Feb 5, 2023
3 tasks
@pxl-th
Copy link
Member

pxl-th commented Feb 13, 2023

Kind ping on the status of this PR as it will unblock support for AMDGPU backend (and we can specialize batchnorm using MIOpen for which I plan to open a PR).

These roughly correspond to Flux's `*Norm` layers.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants