forked from gngdb/ShuffleNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer.py
72 lines (59 loc) · 2.08 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import os
import json
from model import ShuffleNet
from torchvision import transforms
from torch.autograd import Variable
import torch
from PIL import Image
import numpy as np
def get_transformer():
    """Build the preprocessing pipeline applied to an input PIL image.

    Steps, in order:
      1. ``transforms.Resize(128)``; with an int size torchvision matches
         the image's smaller edge to 128 pixels, keeping the aspect ratio.
      2. ``transforms.ToTensor()``.
      3. Per-channel ``transforms.Normalize`` with the mean/std below.

    Returns:
        A ``transforms.Compose`` callable mapping an image to a tensor.
    """
    # Per-channel statistics; presumably the ImageNet values the checkpoint
    # was trained with -- confirm against the training pipeline.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(128),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def preprocess(image, transformer):
    """Turn a single image into a batch of one, ready for the network.

    Args:
        image: input accepted by ``transformer`` (a PIL image when used with
            ``get_transformer()``).
        transformer: callable mapping ``image`` to a tensor.

    Returns:
        The transformed tensor with a new leading batch dimension of size 1.
    """
    # ``torch.autograd.Variable`` has been a no-op wrapper since PyTorch 0.4:
    # tensors carry autograd state themselves, so return the tensor directly.
    return transformer(image).unsqueeze(0)
def infer(args):
    """Classify one image with a trained ShuffleNet and print the top 10 classes.

    Args:
        args: parsed command-line namespace providing ``image`` (image path),
            ``checkpoint`` (path to a ``torch.load``-able dict whose
            ``'state_dict'`` entry holds the model weights), ``idx_to_class``
            (path to a JSON object keyed by stringified class index) and
            ``num_classes``.

    Each JSON value is indexed with ``[1]`` to get the printed class name.
    This presumably matches a layout like
    ``{"0": ["n01440764", "tench"], ...}``; confirm against the mapping file
    in use.
    """
    # make ShuffleNet model
    print('Creating ShuffleNet model')
    net = ShuffleNet(num_classes=args.num_classes, in_channels=3)
    # load trained checkpoint; the map_location lambda keeps every storage on
    # the CPU so a GPU-saved checkpoint loads on a machine without CUDA
    print('Loading checkpoint')
    checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['state_dict'])
    print('Loading index-class map')
    # JSON text is UTF-8; don't depend on the platform's default encoding
    with open(args.idx_to_class, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    # image transformer
    transformer = get_transformer()
    # make input tensor
    print('Loading image')
    # Convert to RGB: the network is built with in_channels=3 and Normalize
    # uses 3-channel statistics, so grayscale/palette/RGBA files would fail.
    # Image.open is lazy and keeps the file open; the `with` closes it once
    # convert() has produced an in-memory copy.
    with Image.open(args.image) as img:
        image = img.convert('RGB')
    print('Preprocessing')
    x = preprocess(image, transformer)
    # predict output
    print('Inferring on image {}'.format(args.image))
    net.eval()
    # inference only: don't build an autograd graph
    with torch.no_grad():
        y = net(x)
    # indices of the 10 highest scores, best first
    scores = y.cpu().numpy().ravel()
    top_idxs = np.argsort(scores).tolist()[-10:][::-1]
    print('==========================================')
    for i, idx in enumerate(top_idxs):
        class_name = mapping[str(idx)][1]
        print('{}.\t{}'.format(i + 1, class_name))
    print('==========================================')
if __name__ == '__main__':
    # Command-line entry point: three required positional paths plus an
    # optional class count (default 1000); the parsed namespace goes
    # straight to infer().
    cli = argparse.ArgumentParser()
    for dest, help_text in (
            ('image', 'Path to image that we want to classify'),
            ('checkpoint', 'Path to ShuffleNet checkpoint with trained weights'),
            ('idx_to_class', 'Path to JSON file mapping indexes to class names')):
        cli.add_argument(dest, type=str, help=help_text)
    cli.add_argument('--num_classes', type=int, default=1000,
                     help='Number of classes to predict')
    infer(cli.parse_args())