Fix layer_normalize gradients #3001

Merged — 8 commits merged into davisking:master on Sep 1, 2024

Conversation

@arrufat (Contributor) commented Aug 28, 2024

Closes #2902

@arrufat arrufat marked this pull request as draft August 28, 2024 13:52
@arrufat (Contributor, Author) commented Aug 28, 2024

It still doesn't fix the discrepancy between GPU and CPU, but it does fix a bug in the implementation.

@davisking (Owner) commented

I'm looking at this and I'm not sure what the code is supposed to be doing. Go through the contracts and make sure they are right. For example:

    void layer_normalize (
        const double eps,
        resizable_tensor& dest,
        resizable_tensor& means,
        resizable_tensor& invstds,
        const tensor& src,
        const tensor& gamma,
        const tensor& beta
    );
    /*!
        requires
            - eps > 0
            - src.num_samples() == gamma.size() == beta.size()
            - have_same_dimensions(gamma, beta) == true
            - beta.num_samples() == beta.nr() == gamma.nc() == 1

That's saying beta and gamma are the same shape and all their dimensions are 1 except k, which would have to be src.num_samples(). By that reading, what was in the code before this PR would make sense. But running some of these, I see that gamma and beta don't have the shape that the requires clause says they do. So there is inconsistency in how these are being interpreted, i.e. is gamma.size() == src.num_samples() or is it src.k() * src.nr() * src.nc()?

That's totally at the root of the problem here. Everything starts with having contracts that are right. Then put DLIB_ASSERT statements that check all the requires clauses (something along the lines of the sketch below) so you know for sure they are not being violated. That will chase down the problem. Although you've got to decide what the arguments are first; I'm not sure what you want them to be for this layer. I would think gamma.size() == src.num_samples(), as that's probably the most typical variant of layer norm though.
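
A minimal sketch of what those checks could look like at the top of layer_normalize, assuming the requires clause stays as quoted above (the exact set of asserts is illustrative, not the ones the PR ends up adding):

    // Illustrative DLIB_ASSERTs mirroring the quoted requires clause, so a
    // shape mismatch fails loudly instead of producing silently wrong gradients.
    DLIB_ASSERT(eps > 0);
    DLIB_ASSERT(src.num_samples() == gamma.size());
    DLIB_ASSERT(src.num_samples() == beta.size());
    DLIB_ASSERT(have_same_dimensions(gamma, beta));
    DLIB_ASSERT(beta.num_samples() == 1 && beta.nr() == 1 && beta.nc() == 1);

The same checks would go into layer_normalize_gradient once the intended shapes of gamma and beta are settled.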

@davisking (Owner) commented Aug 29, 2024

Inside layer_normalize_gradient() for the cuda code, there are also local resizable_tensor variables (dvars and dmeans). Those can't live there. It's really expensive to be creating and destroying tensors; they need to live in a layer object so they aren't created and destroyed each time the network runs, but rather are allocated once. More than that, kernels run asynchronously in CUDA, so that _cuda_layer_normalize_gradient kernel launches but then those two variables are immediately freed. That's probably why the cuda version isn't working, since it's running on dangling pointers.
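
A hedged sketch of that ownership pattern (not the actual diff in this PR; the struct and method names here are made up for illustration): the scratch tensors used by the asynchronous kernel belong in an object that outlives the call and is reused across passes, rather than being created as locals inside the gradient function.

    #include <dlib/dnn.h>

    // Sketch only: scratch tensors owned by a long-lived object so the device
    // memory stays valid after the asynchronous kernel launch and is not
    // reallocated on every backward pass.
    struct layer_norm_backward_scratch
    {
        dlib::resizable_tensor dvars;
        dlib::resizable_tensor dmeans;

        void resize_like(const dlib::tensor& invstds, const dlib::tensor& means)
        {
            // copy_size() only reallocates when the shape actually changes.
            dvars.copy_size(invstds);
            dmeans.copy_size(means);
        }
    };

Creating dvars and dmeans as locals inside layer_normalize_gradient() means their device buffers are freed as soon as the function returns, while the launched kernel may still be reading and writing them.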

Sorry I never really looked at this. You write such good PRs I just kinda skimmed this one and was like "yeah another Adria PR, going to be great and looks great 👍 :D " without really reading it all.

@arrufat (Contributor, Author) commented Aug 29, 2024

Oh, right, what was I thinking. It looks like I got confused half-way through the code, where I should normalize each channel independently, but I ended up trying to normalize along k, nr, and nc. Definitely, k should not be included in the normalization.

Hopefully, I will fix that and the dangling pointer tonight if life allows.

@arrufat (Contributor, Author) commented Aug 29, 2024

I was just checking this again:
https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

[Screenshot of the LayerNorm formula from the PyTorch docs: y = (x − E[x]) / sqrt(Var[x] + ε) * γ + β, where the mean and variance are computed over the last dimensions given by normalized_shape, and γ and β are learnable parameters of shape normalized_shape.]

It seems like it does normalize along C, H, W, and there's one beta and one gamma for each normalized element.

import torch
N, C, H, W = 20, 5, 10, 10
x = torch.randn(N, C, H, W)
layer_norm = torch.nn.LayerNorm([C, H, W])
y = layer_norm(x)
sum(p.numel() for p in layer_norm.parameters() if p.requires_grad)  # 1000 = 2 * (5 * 10 * 10) 

So, maybe the issue is just the dangling pointers? I will make sure the contracts are correct and respected, though.

EDIT: after necrobumping ConvNeXt, each LayerNorm only has 2 * C learnable parameters (beta and gamma), so the implementation here is wrong. You're right about the dimensions of beta and gamma: they should only have k parameters each.

@davisking (Owner) commented Aug 29, 2024 via email

@arrufat (Contributor, Author) commented Aug 30, 2024

I am now confident about the CPU implementation; however, the CUDA version still fails.
There's a mismatch between CPU and CUDA in the beta and gamma gradients, which are trivial, but I can't spot the mistake.
It's getting late, I'll try again later. Feel free to check, it must be something really stupid.

@arrufat (Contributor, Author) commented Aug 31, 2024

I honestly don't know what else to do.
The CPU version works correctly now, but not the CUDA version.

However, if you run test_layer_normalize with CUDA enabled, you'll see that all of the functionality is on par with the CPU version:

  • normalized output tensor
  • src_grad
  • gamma_grad
  • beta_grad
  • dmeans
  • dvars

All of them are within the error tolerance of 1e-5 or 1e-4 (along the lines of the comparison sketched below). However, test_layer with layer_norm still fails.
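
For reference, a comparison of that sort can be sketched as a standalone check. This is hypothetical code, not the actual dlib test: it assumes cuda::layer_normalize takes the same arguments as the CPU contract quoted earlier, and that gamma and beta carry one value per channel (size == src.k()).

    // Hypothetical check: run the CPU and CUDA layer_normalize on the same input
    // and print the largest elementwise difference between the outputs.
    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    int main()
    {
        resizable_tensor src(20, 5, 10, 10);       // N=20, K=5, 10x10 feature maps
        resizable_tensor gamma(1, 5), beta(1, 5);  // one scale/offset per channel
        tt::tensor_rand rnd;
        rnd.fill_gaussian(src);
        gamma = 1;
        beta = 0;

        resizable_tensor out_cpu, means_cpu, invstds_cpu;
        cpu::layer_normalize(1e-5, out_cpu, means_cpu, invstds_cpu, src, gamma, beta);

    #ifdef DLIB_USE_CUDA
        resizable_tensor out_cuda, means_cuda, invstds_cuda;
        cuda::layer_normalize(1e-5, out_cuda, means_cuda, invstds_cuda, src, gamma, beta);
        // The tolerance used in the thread for these comparisons is 1e-5 or 1e-4.
        std::cout << "max difference: " << max(abs(mat(out_cpu) - mat(out_cuda))) << "\n";
    #endif
    }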

@arrufat (Contributor, Author) commented Aug 31, 2024

Ok, fixed a race condition, now test_layer complains like this:

Average parameter gradient error is somewhat large at: 0.00713434

EDIT: after running a clean build, it's working!

@arrufat arrufat marked this pull request as ready for review August 31, 2024 14:22
@davisking (Owner) commented

Nice. I'm away from my computer. I'll look in a bit. Seems like you got it 🥳

@arrufat (Contributor, Author) commented Aug 31, 2024

It took an awful lot of time...

Review comment on dlib/cuda/cuda_dlib.cu, lines 2208 to 2225 (outdated):
// Per-(sample, channel) accumulation of the beta and gamma gradients, reduced
// across the warp and accumulated atomically into bg[k] and gg[k].
for (auto nk : grid_stride_range_y(0, ns * ks))
{
    const auto n = nk / ks;
    const auto k = nk % ks;
    const auto ps = s + (n * ks + k) * num;
    const auto pgi = gi + (n * ks + k) * num;
    float temp_bg = 0;
    float temp_gg = 0;
    for (auto i : grid_stride_range(0, num))
    {
        // x_hat is the normalized input: (x - mean) * invstd, both per sample.
        const float x_hat = (ps[i] - m[n]) * v[n];
        temp_bg += pgi[i];
        temp_gg += pgi[i] * x_hat;
    }
    warp_reduce_atomic_add(bg[k], temp_bg);
    warp_reduce_atomic_add(gg[k], temp_gg);
}
__syncthreads();
@davisking (Owner) commented:

Yeah that kind of warp reduction loop is the best way I know to do it too.
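
For context, a helper in that style typically does a shuffle-based intra-warp reduction followed by a single atomicAdd per warp. A hedged sketch (not dlib's actual warp_reduce_atomic_add, and assuming all 32 lanes of the warp are active):

    // Sketch of a warp-level "reduce, then atomically add" helper in CUDA.
    __device__ void warp_reduce_atomic_add_sketch(float& dest, float val)
    {
        // Tree reduction within the warp using register shuffles.
        for (int offset = warpSize / 2; offset > 0; offset /= 2)
            val += __shfl_down_sync(0xffffffff, val, offset);
        // Only the first lane of the warp touches global memory.
        if ((threadIdx.x & (warpSize - 1)) == 0)
            atomicAdd(&dest, val);
    }

This keeps the atomic traffic down to one atomicAdd per warp instead of one per thread, which is a big part of why the kernel is fast.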

@arrufat (Contributor, Author) commented Sep 1, 2024

After the previous two commits, the network went from training at 320 img/s to 2450 img/s (close to the official CUDA/cuDNN batch norm at 2560 img/s).

@davisking (Owner) commented

> After the previous two commits, the network went from training at 320 img/s to 2450 img/s (close to the official CUDA/cuDNN batch norm at 2560 img/s).

Yeah that's awesome. All the tests are passing for me too on a GPU machine. Passing for you too now? Anything else you want to change before I merge it? :D

@arrufat (Contributor, Author) commented Sep 1, 2024

Nothing else to add, I think it's done now. FINALLY.

And yes, tests are passing now :D

@davisking (Owner) commented

Yeah nice, thanks for all the good work. Looks perfect :D

davisking merged commit 253098e into davisking:master on Sep 1, 2024
10 checks passed
Successfully merging this pull request may close these issues.

[Bug]: Wrong gradients in the CUDA implementation of Layer Norm