-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathimagenet_data.py
executable file
·118 lines (88 loc) · 3.58 KB
/
imagenet_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fuel.streams import AbstractDataStream
#from fuel.iterator import DataIterator
import numpy as np
import theano
class IMAGENET(AbstractDataStream):
    """A fuel data stream over a directory of ImageNet JPEG files.

    Each example is one image, converted to RGB, padded with black borders to
    a square, resized to ``width`` x ``width`` with bicubic interpolation and
    returned in channels-first layout ``(3, width, width)`` with dtype
    ``theano.config.floatX``. Pixel values are not rescaled (0-255).

    From the fuel documentation: a data stream is an iterable stream of
    examples/minibatches. It shares similarities with Python file handles
    returned by the ``open`` method. Data streams can be closed using the
    :meth:`close` method and reset using :meth:`reset` (similar to
    ``f.seek(0)``).
    """

    def __init__(self, partition_label='train',
                 datadir='/home/jascha/data/imagenet/JPEG/', seed=12345,
                 fraction=0.9, width=256, **kwargs):
        """Collect and split the image files for one partition.

        Parameters
        ----------
        partition_label : str
            If it contains ``'train'``, the stream serves the first
            ``round(fraction * N)`` shuffled files. Otherwise, if it contains
            ``'test'``, the stream serves the remaining files.
        datadir : str
            Directory searched (non-recursively) for ``*.JPEG`` files.
        seed : int
            Seed for the shuffle that decides the train/test split.
        fraction : float
            Fraction of all files assigned to the training partition.
        width : int
            Side length, in pixels, of the square output images.
        **kwargs
            Forwarded to ``AbstractDataStream.__init__``. ``axis_labels``
            defaults to ``''``.

        Raises
        ------
        ValueError
            If ``partition_label`` contains neither ``'train'`` nor ``'test'``.
        """
        import glob
        import os

        # ignore axis labels if not given
        kwargs.setdefault('axis_labels', '')
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as IMAGENET is subclassed.
        super(IMAGENET, self).__init__(**kwargs)

        print("getting imagenet images")
        # glob returns files in filesystem order, which is arbitrary. Sort so
        # the seeded shuffle (and hence the train/test split) is reproducible.
        image_files = sorted(glob.glob(os.path.join(datadir, "*.JPEG")))
        print("filenames loaded")

        self.sources = ('features',)
        self.width = width

        # Shuffle with a private RNG so that building the stream does not
        # reseed numpy's global random state for the rest of the program.
        np.random.RandomState(seed).shuffle(image_files)
        num_train = int(np.round(fraction * len(image_files)))
        if 'train' in partition_label:
            self.X = image_files[:num_train]
        elif 'test' in partition_label:
            self.X = image_files[num_train:]
        else:
            raise ValueError(
                "partition_label must contain 'train' or 'test', got %r"
                % (partition_label,))
        self.num_examples = len(self.X)
        self.current_index = 0

    def get_data(self, data_state=None, request=None):
        """Return one batch of images.

        Parameters
        ----------
        data_state : object, optional
            Ignored by this stream.
        request : sequence of int, optional
            Indices into this partition's file list. If omitted, the next
            sequential example is returned and the internal cursor advances.

        Returns
        -------
        list
            A one-element list holding an array of shape
            ``(len(request), 3, width, width)``.

        Raises
        ------
        StopIteration
            If ``request`` is omitted and every example has already been
            returned since the last reset.
        """
        if request is None:
            if self.current_index >= len(self.X):
                raise StopIteration
            request = [self.current_index]
            self.current_index += 1
        return self.load_images(request)

    def apply_default_transformers(self, data_stream):
        """Return ``data_stream`` unchanged (no default transformers)."""
        return data_stream

    def open(self):
        """No-op: nothing needs to be opened up front. Returns None."""
        return None

    def close(self):
        """No-op: this stream holds no persistent file handles."""
        pass

    def reset(self):
        """Reset the sequential read cursor to the first example."""
        self.current_index = 0

    def get_epoch_iterator(self, **kwargs):
        """Delegate to ``AbstractDataStream.get_epoch_iterator``."""
        return super(IMAGENET, self).get_epoch_iterator(**kwargs)

    def next_epoch(self, *args, **kwargs):
        """Rewind to the first example, then defer to the base class.

        Positional arguments are accepted but not forwarded.
        """
        self.current_index = 0
        return super(IMAGENET, self).next_epoch(**kwargs)

    def load_images(self, inds):
        """Load the images at ``inds`` into a single batch array.

        Returns a one-element list holding an array of shape
        ``(len(inds), 3, width, width)`` and dtype ``theano.config.floatX``.
        """
        import sys

        # progress marker: one dot per batch
        sys.stdout.write(".")
        output = np.zeros((len(inds), 3, self.width, self.width),
                          dtype=theano.config.floatX)
        for ii, idx in enumerate(inds):
            output[ii] = self.load_image(idx)
        return [output]

    def load_image(self, idx):
        """Load image ``idx`` as a ``(3, width, width)`` floatX array.

        The image is converted to RGB, so grayscale and CMYK files also yield
        three channels. It is padded with black borders on its shorter axis
        so the content is centred in a square, then resized to
        ``width`` x ``width`` with bicubic interpolation.
        """
        try:
            from PIL import Image, ImageOps
        except ImportError:  # legacy standalone PIL installs
            import Image
            import ImageOps

        # Read through an explicitly closed handle. convert() forces the
        # pixel data to load, so the image no longer needs the file after the
        # with-block. Converting to RGB replaces the old bare-except fallback
        # that substituted a random image (and could recurse without bound)
        # whenever a non-RGB file failed to reshape.
        with open(self.X[idx], 'rb') as handle:
            image = Image.open(handle).convert('RGB')

        width, height = image.size
        delta = abs(width - height)
        # Split the padding so that an odd difference still gives a square.
        before = delta // 2
        after = delta - before
        if width > height:
            border = (0, before, 0, after)  # (left, top, right, bottom)
        else:
            border = (before, 0, after, 0)
        image = ImageOps.expand(image, border=border)
        image = image.resize((self.width, self.width), resample=Image.BICUBIC)

        # (rows, cols, channels) -> (channels, rows, cols)
        imagenp = np.asarray(image).transpose((2, 0, 1))
        return imagenp.astype(theano.config.floatX)