
Reimplement cc_attention using pure pytorch #1201

Merged: 16 commits into open-mmlab:master on Sep 9, 2021

Conversation

@Leojc (Contributor) commented Jul 15, 2021

Motivation

Reimplement a faster cc_attention using pure PyTorch.

related PR #1186

Modification

Rewrite cc_attention.py
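
For readers of this PR, the core idea is roughly the following. This is a minimal sketch, not the merged code: the einsum pattern mirrors the verification snippet quoted later in this thread, while the helper name neg_inf_diag, the class name CrissCrossAttentionSketch, the query/key/value 1x1 convolutions with channel reduction by 8, and the zero-initialized residual scale are assumptions based on the original CCNet design.

import torch
import torch.nn as nn
import torch.nn.functional as F


def neg_inf_diag(n, device):
    # [n, n] matrix with -inf on the diagonal, so the pixel itself is not
    # counted twice when the H-axis and W-axis attention maps are concatenated.
    return torch.diag(torch.full((n,), float('-inf'), device=device))


class CrissCrossAttentionSketch(nn.Module):
    """Sketch of a pure-PyTorch criss-cross attention forward pass."""

    def __init__(self, in_channels):
        super().__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual scale

    def forward(self, x):
        B, C, H, W = x.size()
        query = self.query_conv(x)
        key = self.key_conv(x)
        value = self.value_conv(x)
        # Affinity along the H axis (same column) and along the W axis (same row).
        energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + neg_inf_diag(H, x.device)
        energy_H = energy_H.transpose(1, 2)
        energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
        # Softmax over the H + W positions on the criss-cross path.
        attn = F.softmax(torch.cat([energy_H, energy_W], dim=-1), dim=-1)
        # Aggregate the value features along the same two axes.
        out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
        out = out + torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
        return self.gamma * out + x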

@zhouzaida (Collaborator)

Hi @Leojc, please fix the conflict.

@zhouzaida (Collaborator)

Do you have some statistical numbers that verify the current implementation? And can I close #1186?

@Leojc (Contributor, Author) commented Jul 15, 2021

Do you have some statistical numbers that verify the current implementation? And can I close #1186?

I haven't verified this implementation yet, but the speed comparison in the docstring is valid. As for #1186, if you have decided not to use the einops library, then sure, you can close it.

@zhouzaida (Collaborator)

Do you have some statistical numbers that verify the current implementation? And can I close #1186?

I haven't verified this implementation yet, but the speed comparison in the docstring is valid. As for #1186, if you have decided not to use the einops library, then sure, you can close it.

Got it. If the pure-PyTorch implementation is not slower than the einops-based one after verifying the performance, #1186 will be closed.
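
For reference, one forward pass could be timed roughly like this. This is a minimal sketch rather than the benchmark actually used here; it assumes a CUDA-capable GPU and borrows the [2, 512, 97, 97] input size mentioned later in this thread.

import torch
from mmcv.ops import CrissCrossAttention

x = torch.randn(2, 512, 97, 97, device='cuda')
cca = CrissCrossAttention(512).cuda()

with torch.no_grad():
    for _ in range(10):  # warm up so one-time launch costs are excluded
        cca(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = cca(x)
    end.record()
    torch.cuda.synchronize()  # wait for the asynchronous kernels to finish
print(f'one forward pass: {start.elapsed_time(end):.2f} ms')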

@codecov (bot) commented Jul 15, 2021

Codecov Report

Merging #1201 (48cdefc) into master (5617ad7) will increase coverage by 0.86%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1201      +/-   ##
==========================================
+ Coverage   68.27%   69.14%   +0.86%     
==========================================
  Files         161      162       +1     
  Lines       10742    10746       +4     
  Branches     1972     1978       +6     
==========================================
+ Hits         7334     7430      +96     
+ Misses       3023     2927      -96     
- Partials      385      389       +4     
Flag Coverage Δ
unittests 69.14% <100.00%> (+0.86%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
mmcv/ops/cc_attention.py 90.90% <100.00%> (+52.19%) ⬆️
mmcv/cnn/utils/__init__.py 100.00% <0.00%> (ø)
mmcv/cnn/bricks/conv_module.py 100.00% <0.00%> (ø)
mmcv/cnn/utils/sync_bn.py 84.61% <0.00%> (ø)
mmcv/utils/config.py 90.12% <0.00%> (+0.02%) ⬆️
mmcv/cnn/utils/weight_init.py 85.92% <0.00%> (+0.36%) ⬆️
mmcv/runner/checkpoint.py 74.41% <0.00%> (+0.94%) ⬆️
mmcv/ops/saconv.py 87.87% <0.00%> (+4.54%) ⬆️
mmcv/ops/deform_conv.py 72.26% <0.00%> (+10.21%) ⬆️
mmcv/ops/modulated_deform_conv.py 83.05% <0.00%> (+34.74%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5617ad7...48cdefc. Read the comment docs.

@ZwwWayne (Collaborator)

Overall, the implementation looks good to me; let's see whether @xvjiarui has any suggestions.

@xvjiarui (Collaborator) commented Aug 6, 2021

Hi @Leojc
Sorry for the late reply.
Just to make sure: the outputs of the PyTorch and CUDA versions are exactly the same, right? So if we use the same checkpoint for CCNet inference in MMSeg, the results will be the same.

@xvjiarui (Collaborator) commented Aug 6, 2021

If the operator is equivalent, we may delete redundant .cpp files.

@Leojc (Contributor, Author) commented Aug 6, 2021

If the operator is equivalent, we may delete redundant .cpp files.

No, they are not equivalent. But the #1186 version got a similar IoU. I haven't verified the IoU of this PyTorch version yet.

@xvjiarui (Collaborator) commented Aug 6, 2021

Hi @Leojc
So there will be a BC-breaking change for CCNet models. Is there any way to make it compatible with old weights?

@Leojc (Contributor, Author) commented Aug 6, 2021

Is there any way to make it compatible with old weights?

No, as far as I know.

@xvjiarui (Collaborator)

Hi @Leojc
Sorry for the late update. The MMSeg team will help benchmark this operator soon.

@Junjun2016 (Contributor)

Hi @Leojc @xvjiarui
Benchmark results:
[benchmark results image]

@Junjun2016 (Contributor)

Is there any way to make it compatible with old weights?

No, as far as I know.

Hi @Leojc
We could check the difference together.

@Leojc (Contributor, Author) commented Aug 18, 2021

Hi @Leojc
We could check the difference together.

Hi @Junjun2016
I'd like to, but sorry I'm not familiar with CUDA code.

@Leojc (Contributor, Author) commented Aug 18, 2021

@Junjun2016 The benchmark results look pretty good, thanks!

@zhouzaida zhouzaida mentioned this pull request Aug 26, 2021
@Junjun2016 (Contributor) left a comment

import torch
import torch.nn.functional as F
from mmcv.ops.cc_attention import ca_map, ca_weight


def NEG_INF_DIAG(n, device):
    """Returns a diagonal matrix of size [n, n].

    The diagonal are all "-inf". This is for avoiding calculating the
    overlapped element in the Criss-Cross twice.
    """
    return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)


def cc_cuda(query):
    # Reference output from the existing CUDA op (ca_weight / ca_map).
    energy = ca_weight(query, query)
    attention = F.softmax(energy, 1)
    out = ca_map(attention, query)

    return out


def cc_pt(query):
    # Pure-PyTorch criss-cross attention using einsum.
    B, C, H, W = query.size()
    # Affinity along the H axis (same column); -inf on the diagonal so the
    # overlapping pixel is not counted twice.
    energy_H = torch.einsum('bchw,bciw->bwhi', query, query) + NEG_INF_DIAG(
        H, query.device)
    energy_H = energy_H.transpose(1, 2)
    # Affinity along the W axis (same row).
    energy_W = torch.einsum('bchw,bchj->bhwj', query, query)
    attn = F.softmax(
        torch.cat([energy_H, energy_W], dim=-1), dim=-1)  # [B,H,W,(H+W)]
    # Aggregate along the two axes and sum the contributions.
    out = torch.einsum('bciw,bhwi->bchw', query, attn[..., :H])
    out += torch.einsum('bchj,bhwj->bchw', query, attn[..., H:])

    return out


if __name__ == "__main__":
    n = 100
    for i in range(n):
        query = torch.randn((2, 2, 2, 2), dtype=torch.float).cuda()
        # print(query)
        out_cuda = cc_cuda(query=query)
        out_pt = cc_pt(query=query)
        print(torch.allclose(out_cuda, out_pt))
True
True
... (all 100 iterations print True)

We can test the CCNet checkpoint with this implementation.
Please refer to the benchmark of CCNet in MMSeg.
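
For example, a quick spot check of a pre-trained CCNet checkpoint might look like this. This is a rough sketch assuming the MMSegmentation 0.x inference API; the config, checkpoint, and image paths are placeholders.

from mmseg.apis import inference_segmentor, init_segmentor

# Placeholder paths -- substitute a real CCNet config, checkpoint, and image.
config_file = 'configs/ccnet/ccnet_config.py'
checkpoint_file = 'checkpoints/ccnet_checkpoint.pth'

# Build the model; the CrissCrossAttention op from mmcv is used when the
# CCNet decode head is constructed.
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo.png')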

@Leojc (Contributor, Author) commented Aug 27, 2021

Now the results of the CUDA and PyTorch versions are the same!
[image]

@zhouzaida (Collaborator)

Hi @Leojc, please fix the conflict. Thanks for your nice work!

@zhouzaida (Collaborator)

If the operator is equivalent, we may delete redundant .cpp files.

@Leojc we can delete redundant .cpp files.

@Leojc (Contributor, Author) commented Aug 27, 2021

If the operator is equivalent, we may delete redundant .cpp files.

@Leojc we can delete redundant .cpp files.

To be clear, I can delete these files, right?

mmcv/ops/csrc/pytorch/cc_attention.cpp
mmcv/ops/csrc/pytorch/cuda/cc_attention_cuda.cu

@zhouzaida (Collaborator)

Hi @Leojc, we also need to remove these lines.

void ca_forward(const Tensor t, const Tensor f, Tensor weight);
void ca_backward(const Tensor dw, const Tensor t, const Tensor f, Tensor dt,
                 Tensor df);
void ca_map_forward(const Tensor weight, const Tensor g, Tensor out);
void ca_map_backward(const Tensor dout, const Tensor weight, const Tensor g,
                     Tensor dw, Tensor dg);

m.def("ca_forward", &ca_forward, "ccattention forward", py::arg("t"),
      py::arg("f"), py::arg("weight"));
m.def("ca_backward", &ca_backward, "ccattention backward", py::arg("dw"),
      py::arg("t"), py::arg("f"), py::arg("dt"), py::arg("df"));
m.def("ca_map_forward", &ca_map_forward, "ccattention map forward",
      py::arg("weight"), py::arg("g"), py::arg("out"));
m.def("ca_map_backward", &ca_map_backward, "ccattention map backward",
      py::arg("dout"), py::arg("weight"), py::arg("g"), py::arg("dw"),
      py::arg("dg"));

@Leojc (Contributor, Author) commented Sep 6, 2021

hi @Leojc, we also need to remove those lines.

OK. Done.

@zhouzaida (Collaborator)

class TestCrissCrossAttention(object):

can be refactored to

class TestCrissCrossAttention(object):

    def test_cc_attention(self):
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        from mmcv.ops import CrissCrossAttention
        loss_func = Loss()

        input = np.fromfile(
            'tests/data/for_ccattention/ccattention_input.bin',
            dtype=np.float32)
        output = np.fromfile(
            'tests/data/for_ccattention/ccattention_output.bin',
            dtype=np.float32)
        input = input.reshape((1, 32, 45, 45))
        output = output.reshape((1, 32, 45, 45))
        label = torch.ones((1, 32, 45, 45))

        input = torch.FloatTensor(input)
        output = torch.FloatTensor(output)

        input.requires_grad = True

        shape = input.shape
        channel = shape[1]

        cca = CrissCrossAttention(channel)
        cca.to(device)
        input = input.to(device)
        label = label.to(device)
        cca.train()
        test_output = cca(input)
        test_loss = loss_func(test_output, label)
        test_loss.backward()
        test_output = test_output.detach().cpu().numpy()
        output = output.numpy()

        assert np.allclose(test_output, output)
        assert test_output.shape == shape

Leojc and others added 2 commits September 7, 2021 17:44
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
@ZwwWayne (Collaborator) commented Sep 7, 2021

This looks good to me now. It can be merged upon approval from @xvjiarui and @zhouzaida.

@zhouzaida (Collaborator) left a comment

LGTM

@ZwwWayne (Collaborator) commented Sep 8, 2021

Sorry that I missed a part before. As we delete an op, we should add a note in the docstring telling users that a CUDA op was used before v1.3.13 and that a pure-PyTorch, equivalent implementation is used since v1.3.13, and refer them to this PR #1201.

@Leojc Leojc requested a review from ZwwWayne September 8, 2021 03:17
@ZwwWayne (Collaborator) commented Sep 8, 2021

Thanks for the efforts! The last step is to correct the format of the docstring. As shown in the documentation preview https://mmcv--1201.org.readthedocs.build/en/1201/api.html#mmcv.ops.CrissCrossAttention, the docstring is not rendered correctly.

I suggest updating the format as follows:

.. note::
    Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
    to a pure PyTorch and equivalent implementation. For more
    details, please refer to PR #1201.
    Comparison of one forward pass with input size [2,512,97,97] and 1 NVIDIA GeForce RTX 2080 Ti:
    ... some results...

Args:
    xxx

See the Sphinx documentation on notes and warnings: https://sublime-and-sphinx-guide.readthedocs.io/en/latest/notes_warnings.html#notes
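
For context, embedded in the class the note could look like this. This is a sketch: the parameter name and its description are assumptions, and the speed numbers are elided as in the suggestion above.

class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention Module.

    .. note::
        Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
        to a pure PyTorch and equivalent implementation. For more
        details, please refer to PR #1201.

        Comparison of one forward pass with input size [2,512,97,97]
        and 1 NVIDIA GeForce RTX 2080 Ti:
        ... some results ...

    Args:
        in_channels (int): Channels of the input feature map.
    """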

@Leojc Leojc requested a review from ZwwWayne September 8, 2021 09:29
Leojc and others added 2 commits September 8, 2021 19:08
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
@ZwwWayne ZwwWayne merged commit 2a3d2d4 into open-mmlab:master Sep 9, 2021
@OpenMMLab-Coodinator

Hi @Leojc! First of all, we want to express our gratitude for your significant PR to this project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project in your personal time. We believe that many developers will benefit from your PR.

We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA

If you have a WeChat account, you are welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding the assistant. :)
Thank you again for your contribution. ❤
