Reimplement cc_attention using pure pytorch #1201
Conversation
Hi @Leojc, please fix the conflict.
Do you have some statistical numbers that verify the current implementation? And can I close #1186?
I haven't verified this implementation yet, but the speed comparison in the docstring is valid. As for #1186, if you have decided not to use the einops library, then sure, you can close it.
Got it. If the pure PyTorch version is not slower than the einops version once the performance is verified, #1186 will be closed.
Codecov Report

@@            Coverage Diff             @@
##           master    #1201      +/-   ##
==========================================
+ Coverage   68.27%   69.14%   +0.86%
==========================================
  Files         161      162       +1
  Lines       10742    10746       +4
  Branches     1972     1978       +6
==========================================
+ Hits         7334     7430      +96
+ Misses       3023     2927      -96
- Partials      385      389       +4
Overall, the implementation looks good to me. Let's see if @xvjiarui has any suggestions.
Hi @Leojc, if the operator is equivalent, we may delete the redundant CUDA op.
No, they are not equivalent, but the #1186 version got a similar IoU. I haven't verified the IoU of this PyTorch version yet.
Hi @Leojc
No, as far as I know.
Hi @Leojc
Hi @Leojc
Hi @Junjun2016
@Junjun2016 The benchmark results look pretty good, thanks!
import torch
import torch.nn.functional as F

from mmcv.ops.cc_attention import ca_map, ca_weight


def NEG_INF_DIAG(n, device):
    """Return an [n, n] matrix whose diagonal entries are all -inf.

    Adding it to the energy masks out the self position, which would
    otherwise be counted twice (once along H and once along W) in the
    criss-cross softmax.
    """
    return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)


def cc_cuda(query):
    # Reference implementation built on the original CUDA ops.
    energy = ca_weight(query, query)
    attention = F.softmax(energy, 1)
    out = ca_map(attention, query)
    return out


def cc_pt(query):
    # Pure PyTorch implementation.
    B, C, H, W = query.size()
    # Energies along the column (H) and the row (W) of each position.
    energy_H = torch.einsum('bchw,bciw->bwhi', query, query) + NEG_INF_DIAG(
        H, query.device)
    energy_H = energy_H.transpose(1, 2)
    energy_W = torch.einsum('bchw,bchj->bhwj', query, query)
    # One softmax over the H + W positions on the criss-cross path.
    attn = F.softmax(
        torch.cat([energy_H, energy_W], dim=-1), dim=-1)  # [B,H,W,(H+W)]
    # Aggregate along the column, then along the row.
    out = torch.einsum('bciw,bhwi->bchw', query, attn[..., :H])
    out += torch.einsum('bchj,bhwj->bchw', query, attn[..., H:])
    return out


if __name__ == '__main__':
    n = 100
    for i in range(n):
        query = torch.randn((2, 2, 2, 2), dtype=torch.float).cuda()
        out_cuda = cc_cuda(query=query)
        out_pt = cc_pt(query=query)
        print(torch.allclose(out_cuda, out_pt))
Output: all 100 iterations print True, i.e. the CUDA op and the pure PyTorch implementation agree within torch.allclose tolerance.
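For the speed claim, a timing harness along these lines could be used to compare the two functions above (a minimal sketch reusing cc_cuda and cc_pt from the script; benchmark, the input shape, and the iteration counts are illustrative choices, not the settings behind the docstring numbers):

import time

import torch


def benchmark(fn, query, n_iters=100, n_warmup=10):
    # Warm up so one-time CUDA initialization does not skew the timing.
    for _ in range(n_warmup):
        fn(query)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iters):
        fn(query)
    torch.cuda.synchronize()
    # Average seconds per forward pass.
    return (time.time() - start) / n_iters


query = torch.randn(2, 64, 97, 97, device='cuda')
print('CUDA op:     ', benchmark(cc_cuda, query))
print('pure PyTorch:', benchmark(cc_pt, query))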
We can test the CCNet checkpoint with this implementation.
Please refer to the benchmark of CCNet in MMSeg.
Hi @Leojc, please fix the conflict. Thanks for your great work!
@Leojc, we can delete the redundant CUDA source files now.
To be clear, I can delete these files, right?
Hi @Leojc, we also need to remove those lines: mmcv/mmcv/ops/csrc/pytorch/pybind.cpp, lines 161 to 168 and lines 388 to 396 (at d3b0572).
OK. Done.
mmcv/tests/test_ops/test_cc_attention.py, line 17 (at d3b0572) can be refactored to:

class TestCrissCrossAttention(object):

    def test_cc_attention(self):
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        from mmcv.ops import CrissCrossAttention
        loss_func = Loss()  # Loss is assumed to be defined elsewhere in the test file

        # Reference input/output pair recorded from the original CUDA op.
        input = np.fromfile(
            'tests/data/for_ccattention/ccattention_input.bin',
            dtype=np.float32)
        output = np.fromfile(
            'tests/data/for_ccattention/ccattention_output.bin',
            dtype=np.float32)
        input = input.reshape((1, 32, 45, 45))
        output = output.reshape((1, 32, 45, 45))
        label = torch.ones((1, 32, 45, 45))

        input = torch.FloatTensor(input)
        output = torch.FloatTensor(output)
        input.requires_grad = True

        shape = input.shape
        channel = shape[1]

        cca = CrissCrossAttention(channel)
        cca.to(device)
        input = input.to(device)
        label = label.to(device)
        cca.train()

        # Forward and backward should both run, and the output should
        # match the recorded reference.
        test_output = cca(input)
        test_loss = loss_func(test_output, label)
        test_loss.backward()
        test_output = test_output.detach().cpu().numpy()
        output = output.numpy()

        assert np.allclose(test_output, output)
        assert test_output.shape == shape
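If useful, the refactored test can still be run on its own, e.g. with pytest tests/test_ops/test_cc_attention.py.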
This looks good to me now. It can be merged upon approval of @xvjiarui and @zhouzaida.
LGTM
Sorry that I missed a part before. Since we are deleting an op, we should add a note in the docstring telling users that a CUDA op was used before v1.3.13 and that a pure PyTorch, equivalent implementation is used since v1.3.13, with a reference to this PR, #1201.
Thanks for the efforts! The last step is to correct the format of the docstring. As shown in the documentation preview at https://mmcv--1201.org.readthedocs.build/en/1201/api.html#mmcv.ops.CrissCrossAttention, the docstring does not render correctly. We suggest updating the format; see the Sphinx documentation on notes and warnings: https://sublime-and-sphinx-guide.readthedocs.io/en/latest/notes_warnings.html#notes
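For illustration, the version note could look roughly like this once formatted as a Sphinx directive (a sketch only; the exact wording and the in_channels description are assumptions, not the final docstring):

import torch.nn as nn


class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention Module.

    .. note::
        Before v1.3.13, we used a CUDA op. Since v1.3.13, we use a pure
        PyTorch and equivalent implementation. For more details, please
        refer to PR #1201.

    Args:
        in_channels (int): Channels of the input feature map.
    """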
Hi @Leojc! First of all, we want to express our gratitude for your significant PR to this project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project in your personal time. We believe that many developers will benefit from your PR.

We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel, and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA

If you have a WeChat account, you are welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends. :)
Motivation
Reimplement a faster cc_attention using pure PyTorch.
Related PR: #1186
Modification
Rewrite cc_attention.py