
[Feature]: support to calculate FLOPs of GN, IN, LN #894

Closed
zhouzaida wants to merge 3 commits

Conversation

zhouzaida
Collaborator

Related issue: #886

Support:

  1. Calculate FLOPs of GroupNorm, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, and LayerNorm
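
For context, such norm support can be sketched as a standard forward hook. The snippet below is a minimal illustration, not the exact formula merged in this PR; the hook name, the dict accumulator, and the "affine roughly doubles the elementwise work" assumption are mine:

```python
import torch
import torch.nn as nn

flops = {}


def norm_flops_counter_hook(module, inputs, output):
    # Normalization reads every input element once (mean/var + scaling);
    # a learnable affine transform roughly doubles the elementwise work.
    x = inputs[0]
    count = x.numel()
    if getattr(module, 'affine', False) or getattr(module, 'elementwise_affine', False):
        count *= 2
    name = type(module).__name__
    flops[name] = flops.get(name, 0) + count


ln = nn.LayerNorm(8)
gn = nn.GroupNorm(2, 8)
ln.register_forward_hook(norm_flops_counter_hook)
gn.register_forward_hook(norm_flops_counter_hook)

ln(torch.randn(4, 8))        # 4 * 8 = 32 elements, affine -> 64
gn(torch.randn(4, 8, 2, 2))  # 4 * 8 * 2 * 2 = 128 elements, affine -> 256
print(flops)                 # {'LayerNorm': 64, 'GroupNorm': 256}
```

Because these hooks attach to nn.Module instances, this approach covers all five norm layers above but not functional ops, which motivates the discussion below.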

Discussion:
1. How to support torch.bmm
Currently, FLOPs are only calculated for modules that inherit from nn.Module. Operations that are not nn.Module subclasses, such as torch.bmm and torch.nn.functional.conv2d, are therefore not counted.

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn.utils import get_model_complexity_info


class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 256, 3)
        self.conv3 = nn.Conv2d(256, 8, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        # nn.Module
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        # torch.nn.functional.conv1d is not supported
        filters = torch.randn(33, 16, 3)
        inputs = torch.randn(20, 16, 50)
        outputs = F.conv1d(inputs, filters)
        # torch.bmm is not supported
        inputs_1 = torch.randn(10, 3, 4)
        inputs_2 = torch.randn(10, 4, 5)
        outputs = torch.bmm(inputs_1, inputs_2)
        return outputs


get_model_complexity_info(Dummy(), (3, 16, 16))
"""
Dummy(
  0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, 
  (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
  (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
  (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
  (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
)
"""

One option is to use a decorator to wrap torch.bmm, torch.nn.functional.conv2d, and similar functional operations:

from collections import defaultdict
import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils import get_model_complexity_info


def bmm_flops_count(input, output):
    # input is (args, kwargs); args[0] and args[1] are the two batched matrices.
    # Counts multiply-accumulates per batch element (n * k * m); the batch
    # dimension (shape[0]) is not counted, which is why the example below
    # reports 60 (3 * 4 * 5) rather than 600.
    input1, input2, *remain = input[0]
    return np.prod(input1.shape[1:]) * input2.shape[-1]


method_mapping = {
    'bmm': bmm_flops_count,
}
flops_cnt = defaultdict(int)


def flops_count_wrapper(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        output = func(*args, **kwargs)
        name = func.__name__
        flops_cnt[name] += method_mapping[name](input=(args, kwargs),
                                                output=output)
        return output
    return wrapper


# decorate wrapper
torch.bmm = flops_count_wrapper(torch.bmm)


class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 256, 3)
        self.conv3 = nn.Conv2d(256, 8, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        # nn.Module
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        # torch.nn.functional.conv1d
        filters = torch.randn(33, 16, 3)
        inputs = torch.randn(20, 16, 50)
        outputs = F.conv1d(inputs, filters)
        # torch.bmm
        inputs_1 = torch.randn(10, 3, 4)
        inputs_2 = torch.randn(10, 4, 5)
        outputs = torch.bmm(inputs_1, inputs_2)
        return outputs


get_model_complexity_info(Dummy(), (3, 16, 16))
print(flops_cnt)
"""
Dummy(
  0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, 
  (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
  (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
  (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
  (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
)
defaultdict(<class 'int'>, {'bmm': 60})
"""

2. deconv_flops_counter_hook and conv_flops_counter_hook
Should deconv_flops_counter_hook and conv_flops_counter_hook be the same?
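
Arguably not: a conv hook counts kernel MACs over the output map, while a transposed conv performs the same kernel work once per input position. A minimal sketch under that assumption (function names and formulas are illustrative, not MMCV's implementation; bias is ignored for brevity):

```python
import numpy as np
import torch
import torch.nn as nn


def conv2d_flops(module, x, y):
    # Standard conv: kernel MACs per output position, summed over the output map.
    kernel_ops = np.prod(module.kernel_size) * module.in_channels // module.groups
    return int(np.prod(y.shape) * kernel_ops)


def deconv2d_flops(module, x, y):
    # Transposed conv: each input position is multiplied by the full kernel,
    # so the same kernel MACs are counted over the *input* map instead.
    kernel_ops = np.prod(module.kernel_size) * module.out_channels // module.groups
    batch, in_c, h, w = x.shape
    return int(batch * in_c * h * w * kernel_ops)


x = torch.randn(1, 3, 16, 16)

conv = nn.Conv2d(3, 8, 3)
y = conv(x)  # (1, 8, 14, 14)
print(conv2d_flops(conv, x, y))       # 1*8*14*14 * 3*3*3 = 42336

deconv = nn.ConvTranspose2d(3, 8, 3)
y2 = deconv(x)  # (1, 8, 18, 18)
print(deconv2d_flops(deconv, x, y2))  # 1*3*16*16 * 8*3*3 = 55296
```

The two counts differ whenever the input and output spatial sizes differ, so sharing one hook would over- or under-count transposed convolutions.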

@CLAassistant

CLAassistant commented Mar 17, 2021

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ zhouzaida
❌ 周再达


周再达 does not appear to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@codecov

codecov bot commented Mar 17, 2021

Codecov Report

Merging #894 (869f0eb) into master (73bff4e) will increase coverage by 0.01%.
The diff coverage is 75.00%.

❗ Current head 869f0eb differs from pull request most recent head 928da7f. Consider uploading reports for the commit 928da7f to get more accurate results
Impacted file tree graph

@@            Coverage Diff             @@
##           master     #894      +/-   ##
==========================================
+ Coverage   66.58%   66.59%   +0.01%     
==========================================
  Files         145      145              
  Lines        8828     8841      +13     
  Branches     1605     1606       +1     
==========================================
+ Hits         5878     5888      +10     
- Misses       2633     2637       +4     
+ Partials      317      316       -1     
Flag Coverage Δ
unittests 66.59% <75.00%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
mmcv/runner/base_module.py 79.41% <69.23%> (-6.31%) ⬇️
mmcv/cnn/utils/flops_counter.py 93.63% <100.00%> (+0.45%) ⬆️
mmcv/runner/__init__.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 73bff4e...928da7f.

@hellock requested a review from MeowZheng March 17, 2021 07:46
Collaborator

@MeowZheng left a comment

LGTM

@zhouzaida closed this Mar 18, 2021
@zhouzaida deleted the flops_cnt_support branch March 18, 2021 08:33
3 participants