-
Notifications
You must be signed in to change notification settings - Fork 3
/
profile.py
109 lines (83 loc) · 3.31 KB
/
profile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import torch
import torch.nn as nn
import utils
def count_conv2d(m, x, y):
    """Forward hook that adds a Conv2d layer's operation count to ``m.total_ops``.

    Each output element is charged one multiply per kernel weight that feeds
    it, one fewer additions to sum those products, and one extra addition
    when the layer has a bias.

    Args:
        m: The convolution module. It must already carry a ``total_ops``
            tensor attribute (``profile`` registers it as a buffer).
        x: Tuple of positional inputs given to the module's forward. Unused;
            present only because forward hooks receive it.
        y: The module's output tensor; its element count scales the
            per-element cost.
    """
    kernel_h, kernel_w = m.kernel_size

    # With grouped convolution each output channel only reads
    # in_channels / groups input channels, so the receptive window shrinks
    # accordingly. For groups == 1 this is the plain in_channels.
    channels_per_group = m.in_channels // m.groups
    window = kernel_h * kernel_w * channels_per_group

    # ops per output element
    kernel_mul = window
    kernel_add = window - 1
    bias_ops = 1 if m.bias is not None else 0
    ops_per_element = kernel_mul + kernel_add + bias_ops

    # total ops
    total_ops = y.numel() * ops_per_element

    # Accumulate instead of assigning in case the same conv is used
    # multiple times during one forward pass.
    m.total_ops += torch.Tensor([int(total_ops)])
def count_bn2d(m, x, y):
    """Forward hook that adds a BatchNorm2d layer's operation count to ``m.total_ops``.

    Every input element is charged three operations: two counted as
    subtractions and one counted as a division.

    Args:
        m: The batch-norm module; must already have a ``total_ops`` tensor
            attribute.
        x: Tuple of positional inputs given to the module's forward; only
            the first entry is inspected.
        y: The module's output (unused).
    """
    features = x[0]  # hooks receive the inputs wrapped in a tuple
    element_count = features.numel()

    sub_ops = 2 * element_count
    div_ops = element_count

    m.total_ops += torch.Tensor([int(sub_ops + div_ops)])
def count_linear(m, x, y):
    """Forward hook that adds a Linear layer's operation count to ``m.total_ops``.

    Each output element is charged ``in_features`` multiplies and
    ``in_features - 1`` additions. The bias addition is not counted.

    Args:
        m: The linear module; must already have a ``total_ops`` tensor
            attribute.
        x: Tuple of positional inputs given to the module's forward (unused).
        y: The module's output tensor; its element count scales the
            per-element cost.
    """
    # cost of producing a single output element
    muls_per_output = m.in_features
    adds_per_output = m.in_features - 1
    ops_per_output = muls_per_output + adds_per_output

    layer_ops = ops_per_output * y.numel()

    m.total_ops += torch.Tensor([int(layer_ops)])
def profile(model, input_size, custom_ops=None):
    """Count operations and parameters of ``model`` for one all-zeros input.

    Every leaf module (a module without children) gets two buffers,
    ``total_ops`` and ``total_params``, and - when a counter is known for it -
    a temporary forward hook that fills ``total_ops``. The model is then run
    once on ``torch.zeros(input_size)`` and the per-leaf numbers are summed.

    Side effects on ``model``: it is switched to eval mode, and the two
    buffers stay attached to each leaf module after the call so per-layer
    numbers can still be inspected. The forward hooks are removed again, so
    calling ``profile`` repeatedly on the same model does not double-count.

    Args:
        model: The ``nn.Module`` to profile.
        input_size: Shape of the dummy input, e.g. ``(1, 3, 32, 32)``.
        custom_ops: Optional mapping from a module type to a forward-hook
            style counter ``fn(module, inputs, output)`` that adds to
            ``module.total_ops``. An entry takes precedence over the
            built-in counters for that exact type.

    Returns:
        Tuple ``(total_ops, total_params)`` of one-element float64 tensors.
        float64 is used for the sums because float32 cannot represent
        integers above 2**24 exactly.
    """
    custom_ops = {} if custom_ops is None else custom_ops
    hook_handles = []

    def add_hooks(m):
        # Containers only group other modules; counting happens on leaves.
        if len(list(m.children())) > 0:
            return

        m.register_buffer('total_ops', torch.zeros(1))
        param_count = sum(p.numel() for p in m.parameters())
        m.register_buffer('total_params', torch.tensor([float(param_count)]))

        if type(m) in custom_ops:
            counter = custom_ops[type(m)]
        elif isinstance(m, nn.Conv2d):
            counter = count_conv2d
        elif isinstance(m, nn.BatchNorm2d):
            counter = count_bn2d
        elif isinstance(m, nn.Linear):
            counter = count_linear
        else:
            print("Not implemented for ", m)
            return

        hook_handles.append(m.register_forward_hook(counter))

    model.eval()

    # Build the dummy input where the model's weights live; fall back to CPU
    # for parameter-free models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    try:
        model.apply(add_hooks)
        x = torch.zeros(input_size, device=device)
        # Counting needs no gradients; skipping autograd saves memory.
        with torch.no_grad():
            model(x)
    finally:
        # Always detach the counters, even if the forward pass failed, so
        # they never leak into later use of the model.
        for handle in hook_handles:
            handle.remove()

    total_ops = torch.zeros(1, dtype=torch.float64)
    total_params = torch.zeros(1, dtype=torch.float64)
    for m in model.modules():
        if len(list(m.children())) > 0:
            continue
        total_ops += m.total_ops.to(device="cpu", dtype=torch.float64)
        total_params += m.total_params.to(device="cpu", dtype=torch.float64)

    return total_ops, total_params
def main():
    """Profile a saved checkpoint, print its score, then run it on the test set.

    The score is the mean of two ratios: measured flops and measured
    parameter count, each divided by the hard-coded WideResNet reference
    value below.
    """
    checkpoint_path = "checkpoint/GKD_28-10_teaches_28-1_0_0_p1_25_pool.pth"
    # checkpoint_path = "checkpoint/WideResNet28-1.pth"

    # NOTE(review): torch.load unpickles arbitrary objects; only point this
    # at checkpoints you trust.
    checkpoint = torch.load(checkpoint_path)
    cpu_model = checkpoint["net"].module.cpu()

    ops_tensor, params_tensor = profile(cpu_model, (1, 3, 32, 32))
    flops = ops_tensor.item()
    params = params_tensor.item()

    # Reference values the measured numbers are normalised by.
    reference_params = 36536884
    reference_flops = 10490553344

    flops_ratio = flops / reference_flops
    params_ratio = params / reference_params
    final_score = (flops_ratio + params_ratio) / 2

    print("Flops: {}, Params: {}".format(flops, params))
    print("Score flops: {} Score Params: {}".format(flops_ratio, params_ratio))
    print("Final score: {}".format(final_score))

    # The checkpoint is loaded a second time rather than reusing the
    # profiled instance, which was moved to CPU above.
    eval_model = torch.load(checkpoint_path)["net"].module
    _train_loader, test_loader = utils.load_data(128)
    # Presumably evaluates the model on the test loader on the GPU -
    # confirm against utils.test.
    utils.test(eval_model, test_loader, "cuda", "no", show="error")


if __name__ == "__main__":
    main()