-
Notifications
You must be signed in to change notification settings - Fork 0
/
cal_params.py
30 lines (26 loc) · 1.17 KB
/
cal_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse
from net import Net
import os
import time
from thop import profile
import torch
# NOTE(review): presumably set so the Intel OpenMP runtime tolerates being
# loaded twice (a known issue with some conda/Windows PyTorch installs);
# confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Models profiled when --model_names is not given on the command line.
MODEL_NAMES = ['ACM', 'ALCNet', 'DNANet', 'ISNet', 'RISTDnet', 'UIUNet', 'U-Net', 'RDIAN', 'ISTDU-Net']

parser = argparse.ArgumentParser(description="PyTorch BasicIRSTD Parameter and FLOPs")
parser.add_argument(
    "--model_names",
    # Copy so mutating opt.model_names never alters the MODEL_NAMES constant.
    default=list(MODEL_NAMES),
    nargs='+',
    # Built from MODEL_NAMES so the help text cannot drift from the default list.
    help="model_name: " + ", ".join(repr(name) for name in MODEL_NAMES),
)
# Parsed at import time so `opt` is available as a module-level attribute.
# (A `global` statement is meaningless at module scope, so none is used.)
opt = parser.parse_args()
def log_file_name(stamp=None):
    """Return the path of the parameter/FLOPs report file for this run.

    Args:
        stamp: Timestamp text to embed in the name; defaults to ``time.ctime()``.

    Spaces become ``_`` and colons become ``-``. ``time.ctime()`` contains
    ``HH:MM:SS``, and ``:`` is not allowed in Windows file names, so keeping
    the colons makes ``open`` fail on that platform.
    """
    if stamp is None:
        stamp = time.ctime()
    return './params_' + stamp.replace(' ', '_').replace(':', '-') + '.txt'


def format_report(model_name, params, flops):
    """Return the three-line report (name, params, FLOPs) for one model.

    Args:
        model_name: Name printed on the first line.
        params: Parameter count as returned by ``thop.profile``; shown in millions.
        flops: Operation count as returned by ``thop.profile``; shown in billions.
            NOTE(review): thop documents this value as MACs (multiply-accumulates),
            which this script labels "FLOPs" -- confirm the intended convention.

    Values are shown with two decimals. The earlier ``%2f`` format set a field
    *width* of 2, not a precision, and printed six decimals. ``%.2f`` was
    presumably intended.
    """
    return (
        f'{model_name}\n'
        f'Params: {params / 1e6:.2f}M\n'
        f'FLOPs: {flops / 1e9:.2f}GFLOPs\n'
    )


if __name__ == '__main__':
    # Fall back to CPU so the script also runs without a GPU; on a CUDA
    # machine this selects the same device the script previously hard-coded.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # One single-channel 256x256 random input, reused for every model.
    input_img = torch.rand(1, 1, 256, 256, device=device)
    # `with` closes (and flushes) the report even if a model fails mid-list.
    with open(log_file_name(), 'w', encoding='utf-8') as log_file:
        for model_name in opt.model_names:
            net = Net(model_name, mode='test').to(device)
            flops, params = profile(net, inputs=(input_img, ))
            report = format_report(model_name, params, flops)
            print(report, end='')
            # Extra newline leaves a blank line between models in the log file.
            log_file.write(report + '\n')