-
Notifications
You must be signed in to change notification settings - Fork 21
/
ds_tensor.py
57 lines (42 loc) · 1.55 KB
/
ds_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import tensorflow as tf
# Maps the textual class label stored in the CSV to its integer class id.
label_map = {'猫': 0, '狗': 1}

# Load train.csv into `lines`: one list of comma-separated fields per row.
# The rest of the script reads field 0 as the image path and field 1 as the
# textual label. Blank (or whitespace-only) lines are skipped; otherwise they
# would become [''] and fail later with an IndexError on the label column.
# NOTE(review): the file is opened with the platform default encoding; the
# label keys are non-ASCII, so confirm it matches the CSV's encoding.
with open('train.csv') as f:
    lines = [row.split(',') for row in (raw.strip() for raw in f) if row]
def _parse_function(filename, label):
    """Read one JPEG from disk and turn it into a fixed-size float image.

    Args:
        filename: path of the image file, passed to ``tf.read_file``.
        label: value associated with the image; passed through untouched.

    Returns:
        A ``(image, filename, label)`` tuple where ``image`` is the decoded
        3-channel picture cast to ``tf.float32`` and resized to 224x224.
        The cast does not rescale pixel values.
    """
    raw_bytes = tf.read_file(filename)
    # channels=3 forces a 3-channel decode regardless of how the file is stored.
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize_images(tf.cast(decoded, tf.float32), [224, 224])
    return resized, filename, label
def training_preprocess(image, filename, label):
    """Apply training-time augmentation: a random left/right flip.

    Args:
        image: image tensor to augment.
        filename: passed through untouched.
        label: passed through untouched.

    Returns:
        A ``(image, filename, label)`` tuple in which only ``image`` may have
        been mirrored horizontally.
    """
    return tf.image.random_flip_left_right(image), filename, label
# Split the parsed CSV rows into parallel lists: image paths (column 0) and
# integer class ids (column 1 translated through label_map).
image_paths = [row[0] for row in lines]
class_ids = [label_map[row[1]] for row in lines]

images = tf.constant(image_paths)
labels = tf.constant(class_ids)

# Shuffle (path, label) pairs as a unit inside the dataset. Shuffling the two
# tensors with separate random ops keeps them aligned only as long as both ops
# happen to draw the same permutation from their shared seed.
data = tf.data.Dataset.from_tensor_slices((images, labels))
data = data.shuffle(buffer_size=len(image_paths), seed=0)
data = data.map(_parse_function, num_parallel_calls=4)
data = data.prefetch(buffer_size=2 * 10)
batched_data = data.batch(2)

# Reinitializable iterator: running init_op (re)starts it on batched_data.
iterator = tf.data.Iterator.from_structure(batched_data.output_types,
                                           batched_data.output_shapes)
init_op = iterator.make_initializer(batched_data)

# Build the "next batch" op exactly once. Calling get_next() inside the loop
# would add a new op to the graph on every iteration. Distinct names also
# avoid rebinding the `images` / `labels` tensors defined above.
next_images, next_filenames, next_labels = iterator.get_next()

tt = time.time()
with tf.Session() as sess:
    sess.run(init_op)
    for i in range(100):
        try:
            print('{} -> {}'.format(i, sess.run(next_labels)))
        except tf.errors.OutOfRangeError:
            # Dataset exhausted: rewind so the next step starts a new pass.
            sess.run(init_op)
# Wall-clock seconds elapsed since `tt` was taken (just before the session).
print(time.time() - tt)