-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
64 lines (54 loc) · 2.24 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
load pretrained models
"""
import torchvision.models as models
import torch.nn as nn
import torch
from PyTorch_CIFAR10 import cifar10_models
class Normalize(nn.Module):
def __init__(self, mean, std) :
super(Normalize, self).__init__()
self.register_buffer('mean', torch.Tensor(mean))
self.register_buffer('std', torch.Tensor(std))
def forward(self, input):
# Broadcasting
mean = self.mean.reshape(1, 3, 1, 1)
std = self.std.reshape(1, 3, 1, 1)
return (input - mean) / std
def load_model(model_name, dataset_name ,use_cuda = True):
"""
load pretrained DNN models
if you want to use GPU, set use_cuda = True
else, set use_cuda = False
"""
assert model_name in ['vgg16','resnet50','inception_v3']
assert dataset_name in ['imagenet','cifar10']
if dataset_name == 'imagenet':
if model_name == 'vgg16':
model = models.vgg16_bn(pretrained=True) ## 3,224,224
elif model_name == 'resnet50':
model = models.resnet50(pretrained=True) ## 3,224,224
elif model_name == 'inception_v3':
model = models.inception_v3(pretrained=True) ## 3,299,299
elif dataset_name == 'cifar10':
if model_name == 'vgg16':
model = cifar10_models.vgg16_bn(pretrained=True) ##
elif model_name == 'resnet50':
model = cifar10_models.resnet50(pretrained=True) ##
elif model_name == 'inception_v3':
model = cifar10_models.inception_v3(pretrained=True) ##
elif dataset_name == 'mnist':
raise Exception('Not Implemented')
if dataset_name == 'imagenet':
model = nn.Sequential(Normalize(mean = [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),model)
elif dataset_name == 'cifar10':
model = nn.Sequential(Normalize(mean = [0.4913, 0.4821, 0.4465], std=[0.2470, 0.2434, 0.2615]),model)
if use_cuda:
model = model.cuda()
model.eval()
return model
if __name__ == '__main__':
model_names = ['vgg16','resnet50','inception_v3']
dataset_names = ['imagenet','cifar10']
for model_name,dataset_name in zip(model_names,dataset_names):
model = load_model(model_name,dataset_name)