forked from haitongli/knowledge-distillation-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
count_model_size.py
58 lines (42 loc) · 1.78 KB
/
count_model_size.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
'''Count # of parameters in a trained model'''
import argparse
import os
import numpy as np
import torch
import utils
import model.net as net
import model.resnet as resnet
import model.wrn as wrn
import model.resnext as resnext
import utils
# Command-line interface. --model selects which architecture/checkpoint pair
# the script loads; the allowed names mirror the branches of the __main__ block.
parser = argparse.ArgumentParser(
    description='Count # of parameters in a trained model')
parser.add_argument('--model', default='resnet18',
                    # Restricting to known names makes argparse reject typos
                    # up front with a usage message instead of a later crash.
                    choices=['resnet18', 'wrn', 'distill_resnext',
                             'distill_densenet', 'cnn'],
                    help="name of the model")
def count_parameters(model):
    """Return how many scalar weights in ``model`` are trainable.

    Every parameter tensor whose ``requires_grad`` flag is set contributes
    its element count (``numel()``); tensors with ``requires_grad`` False
    are skipped.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: total number of trainable elements (0 if there are none).
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total
if __name__ == '__main__':
    # Script entry point: build the requested network, load its best
    # checkpoint, and print how many trainable parameters it has.
    args = parser.parse_args()

    # Select the architecture and the checkpoint file for --model.
    if args.model == "resnet18":
        model = resnet.ResNet18()
        model_checkpoint = 'experiments/base_resnet18/best.pth.tar'
    elif args.model == "wrn":
        model = wrn.wrn(depth=28, num_classes=10, widen_factor=10, dropRate=0.3)
        model_checkpoint = 'experiments/base_wrn/best.pth.tar'
    elif args.model == "distill_resnext":
        # A ResNet-18 network; the checkpoint path suggests it is a student
        # distilled from a ResNeXt teacher.
        model = resnet.ResNet18()
        model_checkpoint = 'experiments/resnet18_distill/resnext_teacher/best.pth.tar'
    elif args.model == "distill_densenet":
        # A ResNet-18 network; the checkpoint path suggests it is a student
        # distilled from a DenseNet teacher.
        model = resnet.ResNet18()
        model_checkpoint = 'experiments/resnet18_distill/densenet_teacher/best.pth.tar'
    elif args.model == "cnn":
        # Only the small CNN is configured from params.json, so the config
        # file is required (and read) only for this choice.
        cnn_dir = 'experiments/cnn_distill'
        json_path = os.path.join(cnn_dir, 'params.json')
        # Explicit check rather than `assert`, which `python -O` strips.
        if not os.path.isfile(json_path):
            raise FileNotFoundError(
                "No json configuration file found at {}".format(json_path))
        params = utils.Params(json_path)
        model = net.Net(params)
        model_checkpoint = 'experiments/cnn_distill/best.pth.tar'
    else:
        # Without this branch an unknown name would reach the lines below
        # with `model` / `model_checkpoint` unbound (NameError).
        parser.error("unknown --model {!r}".format(args.model))

    # Load the trained weights into the freshly built network.
    # NOTE(review): whether a mismatched checkpoint raises depends on
    # utils.load_checkpoint, which is not visible here.
    utils.load_checkpoint(model_checkpoint, model)
    model_size = count_parameters(model)
    print("Number of parameters in {} is: {}".format(args.model, model_size))