Does NATTEN support double backward? #191
Comments
I noticed that the unfused ops support double precision, so I ran the following check:

```python
import torch
import natten

q = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
k = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(lambda q, k: natten.functional.na2d_qk(q, k, 3), (q, k))
torch.autograd.gradgradcheck(lambda q, k: natten.functional.na2d_qk(q, k, 3), (q, k))
```

It fails with:

```
Traceback (most recent call last):
File "/home/sist/luoxin/projects/PCM/scripts/test_second_order_grad.py", line 86, in <module>
torch.autograd.gradgradcheck(lambda q, k: natten.functional.na2d_qk(q, k, 3), (q, k))
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2257, in gradgradcheck
return gradcheck(
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2055, in gradcheck
return _gradcheck_helper(**args)
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2084, in _gradcheck_helper
_gradcheck_real_imag(
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 1493, in _gradcheck_real_imag
gradcheck_fn(
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 1594, in _slow_gradcheck
return _check_no_differentiable_outputs(
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.4+cu121/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 982, in _check_no_differentiable_outputs
raise GradcheckError(
torch.autograd.gradcheck.GradcheckError: Numerical gradient for function expected to be zero |
I confirm that NATTEN does not support double backward; the gradients computed by the following code do not require grad:

```python
import torch
import natten

q = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
k = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)

out = natten.functional.na2d_qk(q, k, 3)
grad_inputs = torch.autograd.grad(out, (q, k), grad_outputs=torch.ones_like(out), create_graph=True)
for grad_input in grad_inputs:
    print(f"{grad_input.requires_grad}")
```

But can we support it?
I'm not sure how difficult it is to support double backward for NATTEN. If it's very challenging, is there a non-CUDA implementation that could serve as an alternative during training? That is, a purely PyTorch implementation, which is inherently double-differentiable.
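For what it's worth, here is a minimal sketch of what such a pure-PyTorch fallback for `na2d_qk` could look like, built only from differentiable torch ops. The name `na2d_qk_unfused`, the `(batch, heads, H, W, dim)` layout, and the zero-padded border handling are my assumptions; NATTEN shifts the neighborhood at borders instead, so this is only illustrative, not a drop-in equivalent:

```python
import torch
import torch.nn.functional as F

def na2d_qk_unfused(q, k, kernel_size):
    # q, k: (batch, heads, H, W, dim) -> scores: (batch, heads, H, W, kernel_size**2)
    # Pure-PyTorch and double-differentiable; borders are zero-padded (unlike NATTEN).
    B, heads, H, W, dim = q.shape
    pad = kernel_size // 2
    # Gather each key's kernel_size x kernel_size neighborhood via im2col/unfold.
    k_ = k.permute(0, 1, 4, 2, 3).reshape(B * heads, dim, H, W)
    k_ = F.unfold(k_, kernel_size, padding=pad)                    # (B*heads, dim*k*k, H*W)
    k_ = k_.reshape(B, heads, dim, kernel_size * kernel_size, H, W)
    q_ = q.permute(0, 1, 4, 2, 3).unsqueeze(3)                     # (B, heads, dim, 1, H, W)
    attn = (q_ * k_).sum(dim=2)                                    # (B, heads, k*k, H, W)
    return attn.permute(0, 1, 3, 4, 2)

# Because everything above is ordinary autograd-tracked math, double backward works:
q = torch.rand(2, 1, 3, 3, 4, requires_grad=True, dtype=torch.double)
k = torch.rand(2, 1, 3, 3, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradgradcheck(lambda q, k: na2d_qk_unfused(q, k, 3), (q, k))
```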
I confirmed that PyTorch's official `F.scaled_dot_product_attention` supports double backward:

```python
import torch
import torch.nn.functional as F

q = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
k = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
v = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)

out = F.scaled_dot_product_attention(q, k, v)
grad_inputs = torch.autograd.grad(out, (q, k, v), grad_outputs=torch.ones_like(out), create_graph=True)
for grad_input in grad_inputs:
    print(f"{grad_input.requires_grad}")
print(f"{grad_inputs=}")

torch.autograd.gradcheck(F.scaled_dot_product_attention, (q, k, v))
torch.autograd.gradgradcheck(F.scaled_dot_product_attention, (q, k, v))
```
Thank you for your interest. I'll look into this further, but here are a few things so far:
I'll update this thread when I find out more.
Actually, while I'm looking into this, can you verify (without
Hi, I understand that custom ops may not support FP64, so
Regarding whether an additional kernel is needed, I think the answer is yes. I noticed that xformers does not support double backward (see this), but PyTorch's SDPA does. Given that both xformers and PyTorch rely on FlashAttention, there might be some implementation differences here.
Here is a reference for implementing double-backward-compatible ops.
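If it helps, the core pattern in that kind of reference is to write `backward()` out of differentiable torch ops, so that autograd can record and differentiate it again when `create_graph=True`. A toy sketch (the `MySquare` name is purely illustrative, not anything in NATTEN):

```python
import torch

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Built from differentiable ops, so when backward() runs with
        # create_graph=True autograd can differentiate through it again.
        return 2 * x * grad_out

x = torch.rand(4, dtype=torch.double, requires_grad=True)
torch.autograd.gradgradcheck(MySquare.apply, (x,))  # passes
```

As I understand it, NATTEN's backward calls custom kernels that autograd cannot trace through, which is presumably why double backward currently stops there.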
Yes, operations that are functionally equivalent and have double-precision ops. This can be a CPU implementation; in the case of attention, it can be the unfused implementation.
If it doesn't, I would expect trying double backward with FP16 operators to fail. Also note that SDPA does not guarantee using fused attention, but there are flags that let you control even the specific fused attention impls to which it can dispatch (e.g. FA, xFormers' FMHA, etc.).
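For example, on a recent PyTorch (2.3+ exposes `torch.nn.attention.sdpa_kernel`; older releases have a similar `torch.backends.cuda.sdp_kernel` context manager) you can pin SDPA to the unfused math backend. A sketch, not taken from this thread:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
k = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)
v = torch.rand(2, 3, 3, 4, device="cuda", requires_grad=True, dtype=torch.double)

# Only the unfused "math" implementation is allowed inside this block;
# fused backends (Flash Attention, memory-efficient attention) are excluded.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)
    # The math path is ordinary differentiable torch ops, so this should pass:
    torch.autograd.gradgradcheck(F.scaled_dot_product_attention, (q, k, v))
```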
Thanks for the information, but I don't think we can definitively say that it does.
xFormers and PyTorch do not rely on Flash Attention. Flash Attention and xFormers' FMHA are different implementations of the same concept: fused attention. SDPA in PyTorch can dispatch to unfused (BMM-style) attention, xFormers' FMHA, Flash Attention, or any other available implementation, based on your use case and whether it's supported.
Thanks, yes, this is covered in the PyTorch blog I mentioned. It's just the API for supporting double backward, not what the second derivative of attention actually is; that's what I'm trying to find out.
I found this PyTorch issue, which suggests that the fused kernels behind PyTorch's SDPA indeed do not support double backward, but the math implementation does.
I see -- in that case you'd really need to consider efficiency with double backward. Even unfused self-attention very easily winds up being bound by memory bandwidth, and NA / localized attention would be even worse for a number of reasons. But regardless of that, in the case of NA, this could mean we'd need additional kernels to support double backward. I can't promise we can support this any time soon, at least without figuring out what the double backward for attention even looks like. I'll try and spend some time on this in case there's an easier workaround.
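For reference, a sketch of what is involved: the standard first backward of softmax attention (ignoring the $1/\sqrt{d}$ scale and NA's neighborhood masking) is

$$
\begin{aligned}
S &= QK^\top, \qquad P = \operatorname{softmax}(S), \qquad O = PV,\\
dV &= P^\top dO, \qquad dP = dO\, V^\top,\\
dS_{ij} &= P_{ij}\Big(dP_{ij} - \textstyle\sum_k dP_{ik} P_{ik}\Big),\\
dQ &= dS\, K, \qquad dK = dS^\top Q,
\end{aligned}
$$

and double backward amounts to differentiating these expressions again with respect to $Q$, $K$, $V$, and $dO$, which is what any additional kernels would have to implement.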
Hi, gradient-based losses are common in deep learning optimization, and I would like to make sure NATTEN supports double backward so that such losses do not fail silently.
However, I currently cannot use `torch.autograd.gradcheck` or `torch.autograd.gradgradcheck` to check whether NATTEN supports double backward, since those APIs require double-precision inputs while NATTEN only supports FP32/FP16/BF16. The simple test code below just fails at `torch.autograd.gradcheck`:
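(A sketch of what such a check might look like; this is reconstructed for illustration with float32 inputs and the shapes used elsewhere in the thread, not the author's original snippet.)

```python
import torch
import natten

# float32 inputs, since the fused NATTEN ops only accept FP32/FP16/BF16.
q = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True)
k = torch.rand(2, 1, 3, 3, 4, device="cuda", requires_grad=True)

# gradcheck expects double-precision inputs; with float32 it warns about
# imprecision and its numerical/analytical comparison typically fails.
torch.autograd.gradcheck(lambda q, k: natten.functional.na2d_qk(q, k, 3), (q, k))
```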