-
Notifications
You must be signed in to change notification settings - Fork 9
/
classify_image.py
111 lines (101 loc) · 4.14 KB
/
classify_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from nets.vgg_f import vggf
from nets.caffenet import caffenet
from nets.vgg_16 import vgg16
from nets.vgg_19 import vgg19
from nets.googlenet import googlenet
from nets.resnet_50 import resnet50
from nets.resnet_152 import resnet152
from nets.inception_v3 import inceptionv3
from misc.utils import *
import tensorflow as tf
import numpy as np
import argparse
import sys
def validate_arguments(args):
    """Check parsed command-line arguments, exiting on invalid input.

    Args:
        args: Parsed options object. Reads the ``network``, ``evaluate``,
            ``img_list`` and ``gt_labels`` attributes.

    Prints a message and exits the process with status -1 when the network
    name is not one of the supported models, or when evaluation is requested
    without both an image list and a ground-truth label file.
    """
    # Supported model names; keep in sync with the constructor table in
    # choose_net().
    nets = ['vggf', 'caffenet', 'vgg16', 'vgg19', 'googlenet',
            'resnet50', 'resnet152', 'inceptionv3']
    if args.network not in nets:
        print('invalid network')
        # sys.exit rather than the site-injected exit(), which is not
        # guaranteed to be available (e.g. `python -S`, frozen builds).
        sys.exit(-1)
    if args.evaluate and (args.img_list is None or args.gt_labels is None):
        print('provide image list and labels')
        sys.exit(-1)
def choose_net(network):
    """Build the requested network on top of a fresh input placeholder.

    Args:
        network: Model name; must be a key of the constructor table below.

    Returns:
        A ``(net, input_image)`` pair: the result of applying the selected
        constructor to the placeholder, and the placeholder itself
        (``float32``, shape ``[None, side, side, 3]``).

    Raises:
        KeyError: if ``network`` is not a known model name.
    """
    builders = {
        'vggf': vggf,
        'caffenet': caffenet,
        'vgg16': vgg16,
        'vgg19': vgg19,
        'googlenet': googlenet,
        'resnet50': resnet50,
        'resnet152': resnet152,
        'inceptionv3': inceptionv3,
    }
    # Square input resolution per model; anything not listed uses 224.
    side = {'caffenet': 227, 'inceptionv3': 299}.get(network, 224)
    # Placeholder through which preprocessed images are fed at run time.
    input_image = tf.placeholder(shape=[None, side, side, 3],
                                 dtype='float32', name='input_image')
    net = builders[network](input_image)
    return net, input_image
def evaluate(net, im_list, in_im, labels, net_name):
    """Print top-1 / top-5 accuracy of ``net`` over a list of images.

    Args:
        net: Mapping from layer names to tensors; only ``net['prob']`` is read.
        im_list: Path to a text file with one image path per line.
        in_im: Input placeholder that preprocessed images are fed into.
        labels: Path to a text file with one integer class index per line,
            aligned line-by-line with ``im_list``.
        net_name: Network name; selects the preprocessing routine.

    Prints running accuracies every 1000 images, then the final top-1 and
    top-5 accuracies, all as percentages of the number of images listed.

    Raises:
        ValueError: if the label file has fewer lines than the image list.
    """
    with open(im_list) as f:
        imgs = f.readlines()
    with open(labels) as f:
        gt_labels = f.readlines()
    if not imgs:
        print('no images to evaluate')
        return
    if len(gt_labels) < len(imgs):
        # Fail before running inference rather than with an IndexError
        # part-way through the list.
        raise ValueError('{} images but only {} labels'.format(
            len(imgs), len(gt_labels)))

    top_1 = 0
    top_5 = 0
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        for i, name in enumerate(imgs):
            path = name.strip()
            if net_name == 'caffenet':
                im = img_preprocess(path, size=227)
            elif net_name == 'inceptionv3':
                im = v3_preprocess(path)
            else:
                im = img_preprocess(path)
            softmax_scores = sess.run(net['prob'], feed_dict={in_im: im})
            # Indices of the five highest scores, best first.
            inds = np.argsort(softmax_scores[0])[::-1][:5]
            if i != 0 and i % 1000 == 0:
                # Counts cover images 0..i-1 at this point.
                print('iter: {:5d}\ttop-1: {:04.2f}\ttop-5: {:04.2f}'.format(
                    i, top_1 / float(i) * 100, top_5 / float(i) * 100))
            # NOTE(review): assumes label indices use the same class ordering
            # as the network's output vector.
            gt = int(gt_labels[i].strip())
            if inds[0] == gt:
                top_1 += 1
            # A top-1 hit is always also a top-5 hit.
            if gt in inds:
                top_5 += 1

    # Previously divided by a hard-coded 500.0, which is a percentage only
    # for exactly 50,000 images; use the real image count instead.
    total = float(len(imgs))
    print('Top-1 Accuracy = {:.2f}'.format(top_1 / total * 100))
    print('Top-5 Accuracy = {:.2f}'.format(top_5 / total * 100))
def predict(net, im_path, in_im, net_name):
    """Classify one image and print its five highest-scoring classes.

    Args:
        net: Mapping from layer names to tensors; only ``net['prob']`` is read.
        im_path: Path to the image file.
        in_im: Input placeholder that the preprocessed image is fed into.
        net_name: Network name; selects the preprocessing routine.

    Class names come from ``misc/ilsvrc_synsets.txt``, resolved relative to
    the current working directory. Line ``k`` of that file names output
    index ``k``, and only the text before the first comma is printed.
    """
    # Read the class names with a context manager so the file is closed.
    with open('misc/ilsvrc_synsets.txt') as f:
        synset = f.readlines()
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if net_name == 'caffenet':
            im = img_preprocess(im_path, size=227)
        elif net_name == 'inceptionv3':
            im = v3_preprocess(im_path)
        else:
            im = img_preprocess(im_path)
        softmax_scores = sess.run(net['prob'], feed_dict={in_im: im})
        # Indices of the five highest scores, best first.
        inds = np.argsort(softmax_scores[0])[::-1][:5]
        print('{:}\t{:}'.format('Score', 'Class'))
        for i in inds:
            print('{:.4f}\t{:}'.format(softmax_scores[0, i],
                                       synset[i].strip().split(',')[0]))
def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse reports it as a usage error).
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def main():
    """Parse command-line options, then run single-image prediction or
    full validation-set evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--network', default='googlenet', help='The network eg. googlenet')
    parser.add_argument('--img_path', default='misc/adv_test.jpg', help='Path to input image')
    # Parsed as a boolean so that `--evaluate False` / `--evaluate 0` really
    # disable evaluation (a raw string such as 'False' would be truthy).
    # A bare `--evaluate` with no value enables it.
    parser.add_argument('--evaluate', type=_str2bool, nargs='?', const=True,
                        default=False, help='Flag to evaluate over full validation set')
    parser.add_argument('--img_list', help='Path to the validation image list')
    parser.add_argument('--gt_labels', help='Path to the ground truth validation labels')
    args = parser.parse_args()
    validate_arguments(args)
    net, inp_im = choose_net(args.network)
    if args.evaluate:
        evaluate(net, args.img_list, inp_im, args.gt_labels, args.network)
    else:
        predict(net, args.img_path, inp_im, args.network)


if __name__ == '__main__':
    main()