-
Notifications
You must be signed in to change notification settings - Fork 16
/
03-infer-with-pb-tf-records.py
84 lines (75 loc) · 3.44 KB
/
03-infer-with-pb-tf-records.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import tensorflow as tf
import os
import sys
from tensorflow.python.platform import gfile
import numpy as np
from scipy.misc import imread
import glob
# Read the class names, one per line, from the labels file. The list order
# defines the class index used for the one-hot label width (NCLASS) below.
with open("./images/labels.txt") as label_file:
    # Iterating the file yields the same lines as readlines(); only the "\n"
    # characters are removed, so any "\r" or other whitespace is kept.
    labels = [raw_line.replace("\n", "") for raw_line in label_file]
print(labels)

# Number of classes (one-hot width) and the input image geometry.
NCLASS = len(labels)
NCHANNEL = 3
WIDTH = 224
HEIGHT = 224
def getImageBatch(filenames, batch_size, capacity, min_after_dequeue, enqueue_many=True):
    """Build a shuffled (image, one-hot label) batch pipeline over TFRecord files.

    Each record is parsed with the ``image/*`` feature schema below. The JPEG
    in ``image/encoded`` is decoded to ``NCHANNEL`` channels, converted to
    float32 via ``tf.image.convert_image_dtype``, resized to ``HEIGHT`` x
    ``WIDTH`` and flattened to a vector of ``HEIGHT * WIDTH * NCHANNEL`` values.

    Args:
        filenames: list of TFRecord file paths, cycled indefinitely
            (``num_epochs=None``).
        batch_size: number of examples per batch.
        capacity: maximum number of elements held in the shuffle queue.
        min_after_dequeue: minimum number of elements kept in the queue after
            a dequeue (controls how thoroughly examples are shuffled).
        enqueue_many: forwarded to ``tf.train.shuffle_batch``. The tensors
            built here describe ONE example, so callers should pass ``False``.
            The default ``True`` is kept only for backward compatibility.

    Returns:
        ``(imageBatch, labelBatch)``. With ``enqueue_many=False`` these are
        float32 tensors of shape ``[batch_size, HEIGHT * WIDTH * NCHANNEL]``
        and ``[batch_size, NCLASS]``.
    """
    filenameQ = tf.train.string_input_producer(filenames, num_epochs=None)
    recordReader = tf.TFRecordReader()
    # The record key is not needed.
    _, fullExample = recordReader.read(filenameQ)
    features = tf.parse_single_example(
        fullExample,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/channels': tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
        })
    label = features['image/class/label']
    image_buffer = features['image/encoded']

    # name_scope's signature is (name, default_name=None, values=None), so
    # pass the tensor list by keyword rather than in the default_name slot.
    with tf.name_scope('decode_jpeg', values=[image_buffer]):
        image = tf.image.decode_jpeg(image_buffer, channels=NCHANNEL)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # Source JPEGs can be any size. Resize so the flattened vector has the
    # fixed, static length that shuffle_batch requires.
    image = tf.image.resize_images(image, [HEIGHT, WIDTH])
    # Keep all NCHANNEL channels. rgb_to_grayscale would leave one channel,
    # which can never be reshaped to HEIGHT * WIDTH * NCHANNEL elements.
    image = tf.reshape(image, [HEIGHT * WIDTH * NCHANNEL])

    # Stored labels are shifted down by one before one-hot encoding.
    # NOTE(review): this presumes the TFRecord converter writes 1-based class
    # ids (index 0 = first entry of labels.txt); confirm against the writer.
    label = tf.one_hot(label - 1, NCLASS)

    imageBatch, labelBatch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue, enqueue_many=enqueue_many)
    return imageBatch, labelBatch
# Load the frozen inference graph (a serialized GraphDef) from disk.
with gfile.FastGFile("./output_graph_510.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # tf.Session() uses the default graph. import_graph_def (default name
    # prefix "import/") and the input-pipeline ops built below therefore all
    # land in sess.graph. No explicit as_default() is needed.
    tf.import_graph_def(graph_def)

    cwd = os.getcwd()
    file_list = [os.path.join(cwd, x)
                 for x in glob.glob("images/tf_records/validation*")]
    print(file_list)
    if not file_list:
        # Fail fast with a readable message instead of an opaque queue error.
        raise FileNotFoundError(
            "no TFRecord files match images/tf_records/validation* under " + cwd)

    # batch_size=2, capacity=4, min_after_dequeue=2. Each record is a single
    # example, so enqueue_many must be False.
    image_tensor, label_batch = getImageBatch(file_list, 2, 4, 2, enqueue_many=False)
    # Initialize only after the pipeline ops exist so the initializers cover them.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        image_data = sess.run(image_tensor)
        print(image_data.shape)
        # TODO: run the imported classifier on the batch.
        # NOTE(review): image_data is a flattened [batch, pixels] array. If
        # 'import/input:0' expects [batch, HEIGHT, WIDTH, NCHANNEL], reshape
        # it before feeding; confirm against the exported graph.
        # softmax_tensor = sess.graph.get_tensor_by_name('import/final_result:0')
        # predictions = sess.run(softmax_tensor, {'import/input:0': image_data})
        # predictions = np.squeeze(predictions)
        # print(predictions)
    finally:
        # Always stop and join the queue-runner threads, even if sess.run raises.
        coord.request_stop()
        coord.join(threads)

    # Draft post-processing. `return` requires an enclosing function.
    # top_k = predictions.argsort()[:][::-1]  # All classes, highest score first
    # for node_id in top_k:
    #     human_string = labels[node_id]
    #     score = predictions[node_id]
    #     print('%s (score = %.5f)' % (human_string, score))
    # answer = labels[top_k[0]]
    # return answer