forked from StephenTaylor1998/high-resolution-capsule
-
Notifications
You must be signed in to change notification settings - Fork 1
/
flops.py
54 lines (44 loc) · 2.3 KB
/
flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from core import models
from thop import profile
import argparse
def get_parameter_number(net):
    """Count the parameter elements held by ``net``.

    Args:
        net: Object whose ``parameters()`` method yields items exposing a
            ``numel()`` method and a ``requires_grad`` attribute (for
            example a ``torch.nn.Module``).

    Returns:
        A ``(total, trainable)`` tuple: the summed ``numel()`` over every
        parameter, and over only those whose ``requires_grad`` is truthy.
    """
    total = 0
    trainable = 0
    for param in net.parameters():
        size = param.numel()
        total += size
        # Only parameters that take part in gradient updates are trainable.
        if param.requires_grad:
            trainable += size
    return total, trainable
if __name__ == '__main__':
    # Command-line entry point: build one model from `models`, profile a single
    # forward pass with thop, and print operation and parameter counts.
    parser = argparse.ArgumentParser(description='Model Parameters and FLOPs Testing')

    # Every lowercase, non-dunder callable exposed by `models` is offered as a
    # selectable architecture name.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet18)')
    # Shape of the random tensor fed to the profiler; the default has four
    # dimensions (presumably batch, channels, height, width -- confirm).
    parser.add_argument('-s', '--shape', default=[1, 3, 224, 224], nargs='+', type=int,
                        help="Input image's shape.")
    parser.add_argument('-c', '--classes', default=1001, type=int, metavar='N',
                        help='number of classes (default: 1001)')

    # add on args
    # These options are not used directly in this script; they reach the model
    # constructor through the `args` namespace passed below.
    parser.add_argument('--in-shape', default=(3, 32, 32), nargs='+', type=int,
                        help='Input image.')
    parser.add_argument('--pose_dim', default=4, type=int, help='Capsule pose.')
    parser.add_argument('--routing-iter', default=3, type=int, help='Capsule routing iter.')
    parser.add_argument('--capsule-arch', default=[64, 8, 16, 16, 5], nargs='+', type=int,
                        help='Capsule arch.')
    parser.add_argument('--routing-name-list', default=["FPN"], nargs='+', type=str,
                        help='FPN routing.')
    # Help text previously duplicated the routing option's 'FPN routing.'.
    parser.add_argument('--backbone', default="resnet50_dwt_tiny_half", type=str,
                        help='Backbone network.')
    args = parser.parse_args()

    # The full namespace is forwarded so the selected constructor can read the
    # add-on options above.
    model = models.__dict__[args.arch](num_classes=args.classes, args=args)

    print(args.shape)
    inputs = torch.randn(*args.shape)

    # NOTE(review): thop documents the first value returned by `profile` as a
    # multiply-accumulate (MAC) count, while the report below labels it
    # "FLOPs" -- confirm which unit is intended. thop's own parameter count is
    # discarded in favour of get_parameter_number, which also reports the
    # trainable subset.
    macs, _ = profile(model, inputs=(inputs,))
    total_num, trainable_num = get_parameter_number(model)

    print('='*30)
    print(f"Model Name: {args.arch}")
    print(f"FLOPs: {macs/1000000000} GFLOPs")
    print(f"total_num: {total_num / 1000000} M")
    print(f"trainable: {trainable_num / 1000000} M")
    print('=' * 30)