Does NATTEN support double backward? #191

Open
Luciennnnnnn opened this issue Jan 4, 2025 · 12 comments

Comments

@Luciennnnnnn

Luciennnnnnn commented Jan 4, 2025

Hi, gradient-based losses are common in deep learning optimization, and I would like to make sure NATTEN supports double backward so that such losses do not fail silently.

However, I currently find that I cannot use torch.autograd.gradcheck or torch.autograd.gradgradcheck to check whether NATTEN supports double backward, since those APIs require double-precision inputs while NATTEN only supports FP32/FP16/BF16.

The simple test code below already fails at torch.autograd.gradcheck:

import torch
import natten

q = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)
k = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)
v = torch.rand(2, 3, 3, 1, 4, device="cuda", requires_grad=True, dtype=torch.float32)

torch.autograd.gradcheck(lambda q, k, v: natten.functional.na2d(q, k, v, 3, scale=1.0), (q, k, v)) # fail
torch.autograd.gradgradcheck(lambda q, k, v: natten.functional.na2d(q, k, v, 3, scale=1.0), (q, k, v))

torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-0.0596, -0.0745, -0.0298,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0447,  0.0447,  0.0298,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0298, -0.0298, -0.0298,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0596,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0298, -0.0149, -0.0298],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0596,  0.0000,  0.0894]],
       device='cuda:0')
analytical:tensor([[-0.0416, -0.0353, -0.0043,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0080,  0.0035,  0.0059,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0239, -0.0478, -0.0411,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0379,  0.0065,  0.0214],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0069,  0.0138,  0.0097],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0212, -0.0316,  0.0403]],
       device='cuda:0')
@Luciennnnnnn
Author

I noticed that the unfused ops support double precision, so I ran the following check:

import torch
import natten

q = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
k = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(lambda q, k: natten.functional.na2d_qk(q, k, 3), (q, k))
torch.autograd.gradgradcheck(lambda q, k: natten.functional.na2d_qk(q, k, 3), (q, k))

It fails at gradgradcheck with the following error:

Traceback (most recent call last):
  File "/home/sist/luoxin/projects/PCM/scripts/test_second_order_grad.py", line 86, in <module>
    torch.autograd.gradgradcheck(lambda q, k: natten.functional.na2d_qk(q, k, 3), (q, k))
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2257, in gradgradcheck
    return gradcheck(
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2055, in gradcheck
    return _gradcheck_helper(**args)
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2084, in _gradcheck_helper
    _gradcheck_real_imag(
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 1493, in _gradcheck_real_imag
    gradcheck_fn(
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 1594, in _slow_gradcheck
    return _check_no_differentiable_outputs(
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 982, in _check_no_differentiable_outputs
    raise GradcheckError(
torch.autograd.gradcheck.GradcheckError: Numerical gradient for function expected to be zero

@Luciennnnnnn
Author

I can confirm that NATTEN does not support double backward: the gradients of na2d_qk with respect to q and k simply do not require grad:

import torch
import natten

q = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
k = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)

out = natten.functional.na2d_qk(q, k, 3)

grad_inputs = torch.autograd.grad(out, (q, k), grad_outputs=torch.ones_like(out), create_graph=True)
for grad_input in grad_inputs:
    print(f"{grad_input.requires_grad}")

But, can we support it?

@Luciennnnnnn
Author

I'm not sure how difficult it is to support double backward for NATTEN. If it's very challenging, is there a non-CUDA implementation that could serve as an alternative during training? That is, a purely PyTorch implementation, which is inherently double-differentiable.
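
For reference, here is a rough pure-PyTorch sketch of 2D neighborhood attention built on F.unfold, composed only of standard differentiable ops. The border handling is an assumption on my part: out-of-range positions are zero-padded and then masked, whereas NATTEN clamps the window at the borders, so values near the edges will not match na2d exactly. It also materializes all kernel_size² neighborhoods, so it is only practical for small inputs:

import torch
import torch.nn.functional as F

def na2d_reference(q, k, v, kernel_size, scale=None):
    # q, k, v: (B, H, W, heads, dim), matching the na2d snippet above
    B, H, W, heads, dim = q.shape
    scale = dim ** -0.5 if scale is None else scale
    pad = kernel_size // 2

    def to_maps(x):  # (B, H, W, heads, dim) -> (B*heads, dim, H, W)
        return x.permute(0, 3, 4, 1, 2).reshape(B * heads, dim, H, W)

    # Gather key/value neighborhoods: (B*heads, dim, kernel_size**2, H*W)
    k_nb = F.unfold(to_maps(k), kernel_size, padding=pad).reshape(B * heads, dim, kernel_size**2, H * W)
    v_nb = F.unfold(to_maps(v), kernel_size, padding=pad).reshape(B * heads, dim, kernel_size**2, H * W)
    q_flat = to_maps(q).reshape(B * heads, dim, 1, H * W)

    attn = (q_flat * k_nb).sum(dim=1) * scale  # (B*heads, kernel_size**2, H*W)
    # Mask out zero-padded (out-of-image) neighbors before the softmax
    ones = torch.ones(1, 1, H, W, device=q.device, dtype=q.dtype)
    valid = F.unfold(ones, kernel_size, padding=pad).reshape(1, kernel_size**2, H * W) > 0.5
    attn = attn.masked_fill(~valid, float("-inf")).softmax(dim=1)

    out = (attn.unsqueeze(1) * v_nb).sum(dim=2)  # (B*heads, dim, H*W)
    return out.reshape(B, heads, dim, H, W).permute(0, 3, 4, 1, 2)

Since this is built entirely from ordinary autograd ops, double backward (and gradgradcheck in FP64) should go through on it, though, again, it is not a drop-in numerical match for na2d at the borders.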

@Luciennnnnnn
Author

I confirmed that PyTorch's official scaled_dot_product_attention does indeed support double backward, so I think implementing it would not be impossible; see the following test code:

import torch
import torch.nn.functional as F

q = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
k = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
v = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)

out = F.scaled_dot_product_attention(q, k, v)

grad_inputs = torch.autograd.grad(out, (q, k, v), grad_outputs=torch.ones_like(out), create_graph=True)
for grad_input in grad_inputs:
    print(f"{grad_input.requires_grad}")
print(f"{grad_inputs=}")

torch.autograd.gradcheck(F.scaled_dot_product_attention, (q, k, v))
torch.autograd.gradgradcheck(F.scaled_dot_product_attention, (q, k, v))

@alihassanijr
Member

Thank you for your interest.

I'll look into this further, but here are a few things so far:

  1. Grad check does not necessarily need to be run to verify the correctness of the backward pass, since many kernels may not support double precision. Most fused attention kernels don't even support FP32, let alone FP64 (double precision). Therefore, gradcheck and gradgradcheck failing for NATTEN ops is expected, since there are no double-precision FNA kernels.

  2. According to this PyTorch post, it should be possible to perform double backward with custom operators, but it looks like it's not as simple as stacking existing operations together (a toy sketch of the pattern is below). I'll have to think about what the double gradient for attention looks like and see whether it requires new kernels, but I doubt that's the requirement if torch's SDPA API supports this.
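
For what it's worth, here is a toy sketch of that pattern: a custom Function whose backward is written with differentiable torch ops, which is what makes a second backward possible. The catch for NATTEN is that its backward calls opaque CUDA kernels rather than a chain of torch ops, so this pattern alone doesn't directly apply.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Written with differentiable torch ops, so autograd can
        # differentiate through this backward again (double backward).
        return 2 * x * grad_out

x = torch.tensor(3.0, requires_grad=True)
y = Square.apply(x)
(g,) = torch.autograd.grad(y, x, create_graph=True)  # dy/dx = 2x = 6
(gg,) = torch.autograd.grad(g, x)                    # d2y/dx2 = 2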

I'll update this thread when I find out more.

@alihassanijr
Member

Actually, while I'm looking into this, can you verify (without gradgradcheck) whether SDPA supports double backward in, say, FP16?

@Luciennnnnnn
Author

Hi, I understand that custom ops may not support FP64, so gradcheck and gradgradcheck cannot be used to verify if these ops correctly compute gradients or second-order gradients. However, I have a question: if we don't use finite-difference methods for the checks, do we have other alternatives? Do you have any opinions on how to check if SDPA supports double backward in FP16?

@Luciennnnnnn
Author

Regarding whether an additional kernel is needed, I think the answer is yes. I noticed that xformers does not support double backward (see this), but PyTorch's SDPA does. Given that both xformers and PyTorch rely on FlashAttention, there might be some implementation differences here.

@Luciennnnnnn
Author

Here is a reference for implementing double-backward-compatible ops: pytorch/pytorch#34773

@alihassanijr
Member

if we don't use finite-difference methods for the checks, do we have other alternatives?

Yes: comparing against operations that are functionally equivalent and do have double-precision implementations. This can be a CPU implementation; in the case of attention, it can be the unfused implementation.
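
As a rough sketch of that kind of check (using SDPA as a stand-in for the op under test; for NATTEN the reference would be an unfused NA implementation), one can compare first-order gradients of the low-precision op against a double-precision unfused reference with loose tolerances:

import torch
import torch.nn.functional as F

def unfused_attention(q, k, v, scale):
    # BMM-style reference attention: supports FP64 and plain autograd.
    attn = (q @ k.transpose(-2, -1)) * scale
    return attn.softmax(dim=-1) @ v

q64, k64, v64 = (torch.rand(2, 3, 8, 16, device="cuda", dtype=torch.float64, requires_grad=True)
                 for _ in range(3))
q16, k16, v16 = (t.detach().to(torch.float16).requires_grad_() for t in (q64, k64, v64))

out_ref = unfused_attention(q64, k64, v64, scale=16 ** -0.5)
out_test = F.scaled_dot_product_attention(q16, k16, v16)  # op under test, FP16

g_ref = torch.autograd.grad(out_ref.sum(), (q64, k64, v64))
g_test = torch.autograd.grad(out_test.sum(), (q16, k16, v16))
for r, t in zip(g_ref, g_test):
    # Tolerances are a judgment call given the FP16 op under test.
    torch.testing.assert_close(t.to(torch.float64), r, atol=1e-2, rtol=1e-2)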

Do you have any opinions on how to check if SDPA supports double backward in FP16?

If it doesn't, I would expect trying double backward with FP16 operators to fail.

Also note that SDPA does not guarantee using fused attention, but there are flags that let you control which specific fused attention implementations it can dispatch to (e.g. FA, xFormers' FMHA, etc.); a sketch of such a check is below.
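
Concretely, a sketch of that check (assuming PyTorch 2.3+ for torch.nn.attention.sdpa_kernel): pin SDPA to one backend at a time in FP16 and simply attempt a second-order grad; a backend whose kernels lack double backward would be expected to raise:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch >= 2.3

q, k, v = (torch.rand(2, 3, 8, 16, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
    try:
        with sdpa_kernel(backend):  # pin SDPA to one backend so a fallback can't mask the result
            out = F.scaled_dot_product_attention(q, k, v)
        grads = torch.autograd.grad(out.sum(), (q, k, v), create_graph=True)
        torch.autograd.grad(sum(g.sum() for g in grads), (q, k, v))  # second-order grad
        print(backend, "double backward OK")
    except (RuntimeError, NotImplementedError) as err:
        print(backend, "failed:", type(err).__name__)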

Regarding whether an additional kernel is needed, I think the answer is yes.

Thanks for the information, but I don't think we can definitively say that it is.
If it requires a completely new set of kernels, I would be certain that the same holds true for any other implementation, which means even FA or xFormers wouldn't support it. That would mean SDPA dispatches to unfused attention / an unfused implementation for double backward.

I noticed that xformers does not support double backward; see this, but PyTorch's SDPA does. Given that both xformers and PyTorch rely on FlashAttention, there might be some implementation differences here.

xFormers and PyTorch do not rely on Flash Attention. Flash Attention and xFormers' FMHA are different implementations for the same concept: fused attention. SDPA in PyTorch can dispatch to either unfused (BMM-style) attention, xFormers' FMHA, Flash Attention, or any other available implementations based on your use case and whether it's supported.
This does not say anything about what happens during double backward.

Here is a reference for implementing double-backward-compatible ops: pytorch/pytorch#34773

Thanks, yes this is included in the PyTorch blog I mentioned. It's just the API for supporting double backward, and not what the second derivative for attention is. That's what I'm trying to find out.

@Luciennnnnnn
Author

I found this PyTorch issue, which suggests that the fused kernels behind PyTorch's SDPA indeed do not support double backward, but the math implementation does.

@alihassanijr
Member

I see -- in that case you'd really need to consider efficiency with double backward. Even unfused self-attention very easily winds up being bound by memory bandwidth, and NA / localized attention would be even worse for a number of reasons.

But regardless of that, in the case of NA, this could mean we'd need additional kernels to support double backward.

I can't promise we can support this any time soon, at least without figuring out what the double backward for attention even looks like.

I'll try and spend some time on this in case there's an easier workaround.
