This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 442
/
iterator.py
294 lines (260 loc) · 11.3 KB
/
iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""Utilities for real-time data augmentation on image data.
"""
import os
import threading
import numpy as np
from keras_preprocessing import get_keras_submodule
# Base class for the iterators in this module: the `Sequence` attribute of
# the submodule returned by `get_keras_submodule('utils')`, or plain
# `object` when that lookup raises ImportError (presumably meaning no Keras
# installation is available -- confirm against `get_keras_submodule`).
try:
    IteratorType = get_keras_submodule('utils').Sequence
except ImportError:
    IteratorType = object
from .utils import array_to_img, img_to_array, load_img
class Iterator(IteratorType):
    """Base class for image data iterators.

    Every `Iterator` must implement the `_get_batches_of_transformed_samples`
    method.

    Batches can be obtained either by index (`iterator[i]`, including
    negative indices counted from the end) or by iterating
    (`next(iterator)`), which loops over the data forever.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """
    white_list_formats = ('png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff')

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
        self.batch_index = 0
        self.total_batches_seen = 0
        # Guards `index_generator`, which is advanced in `next()`.
        self.lock = threading.Lock()
        # Built lazily; see `_set_index_array`.
        self.index_array = None
        self.index_generator = self._flow_index()

    def _set_index_array(self):
        """(Re)builds `self.index_array`, the order in which samples are served.

        A random permutation of `range(n)` when `shuffle` is set, otherwise
        the identity order.
        """
        if self.shuffle:
            self.index_array = np.random.permutation(self.n)
        else:
            self.index_array = np.arange(self.n)

    def __getitem__(self, idx):
        """Returns batch number `idx`.

        Negative indices count from the end, as for Python sequences.

        # Raises
            ValueError: if `idx` is outside `[-len(self), len(self))`.
        """
        length = len(self)
        if idx >= length or idx < -length:
            raise ValueError('Asked to retrieve element {idx}, '
                             'but the Sequence '
                             'has length {length}'.format(idx=idx,
                                                          length=length))
        if idx < 0:
            # Without this, a negative idx produced a slice such as
            # [-batch_size:0], i.e. an empty or misplaced batch.
            idx += length
        if self.seed is not None:
            # Reseed per batch so a given call sequence is reproducible.
            np.random.seed(self.seed + self.total_batches_seen)
        self.total_batches_seen += 1
        if self.index_array is None:
            self._set_index_array()
        start = self.batch_size * idx
        index_array = self.index_array[start:start + self.batch_size]
        return self._get_batches_of_transformed_samples(index_array)

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return (self.n + self.batch_size - 1) // self.batch_size  # round up

    def on_epoch_end(self):
        """Rebuilds the sample order (reshuffling if `shuffle` is set)."""
        self._set_index_array()

    def reset(self):
        """Restarts `next()`-style iteration from the first batch."""
        self.batch_index = 0

    def _flow_index(self):
        """Infinite generator of index slices, one per batch.

        The sample order is rebuilt whenever a pass over the data restarts
        (`batch_index == 0`); the final batch of a pass may be shorter
        than `batch_size`.
        """
        # Ensure self.batch_index is 0.
        self.reset()
        while True:
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            if self.batch_index == 0:
                self._set_index_array()

            if self.n == 0:
                # Avoiding modulo by zero error
                current_index = 0
            else:
                current_index = (self.batch_index * self.batch_size) % self.n
            if self.n > current_index + self.batch_size:
                self.batch_index += 1
            else:
                # This batch reaches the end of the data: wrap around.
                self.batch_index = 0
            self.total_batches_seen += 1
            yield self.index_array[current_index:
                                   current_index + self.batch_size]

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        # Arguments
            index_array: Array of sample indices to include in batch.

        # Returns
            A batch of transformed samples.
        """
        raise NotImplementedError
class BatchFromFilesMixin():
    """Adds methods related to getting batches from filenames.

    It includes the logic to transform image files to batches.

    `_get_batches_of_transformed_samples` reads `dtype`, `class_mode` and,
    depending on `class_mode`, `classes` / `class_indices` from `self`;
    these are expected to be set by the concrete iterator class, which must
    also override the `filepaths`, `labels` and `sample_weight` properties.
    """

    def set_processing_attrs(self,
                             image_data_generator,
                             target_size,
                             color_mode,
                             data_format,
                             save_to_dir,
                             save_prefix,
                             save_format,
                             subset,
                             interpolation,
                             keep_aspect_ratio):
        """Sets attributes to use later for processing files into a batch.

        # Arguments
            image_data_generator: Instance of `ImageDataGenerator`
                to use for random transformations and normalization.
            target_size: tuple of integers, dimensions to resize input images to.
            color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
                Color mode to read images.
            data_format: String, one of `channels_first`, `channels_last`.
            save_to_dir: Optional directory where to save the pictures
                being yielded, in a viewable format. This is useful
                for visualizing the random transformations being
                applied, for debugging purposes.
            save_prefix: String prefix to use for saving sample
                images (if `save_to_dir` is set).
            save_format: Format to use for saving sample images
                (if `save_to_dir` is set).
            subset: Subset of data (`"training"` or `"validation"`) if
                validation_split is set in ImageDataGenerator.
            interpolation: Interpolation method used to resample the image if the
                target size is different from that of the loaded image.
                Supported methods are "nearest", "bilinear", and "bicubic".
                If PIL version 1.1.3 or newer is installed, "lanczos" is also
                supported. If PIL version 3.4.0 or newer is installed, "box" and
                "hamming" are also supported. By default, "nearest" is used.
            keep_aspect_ratio: Stored on `self` and forwarded unchanged to
                `load_img(keep_aspect_ratio=...)` when images are loaded.

        # Raises
            ValueError: on an unknown `color_mode` or `subset`.
        """
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        self.keep_aspect_ratio = keep_aspect_ratio

        # Number of image channels produced by each supported color mode.
        channels_by_mode = {'rgb': 3, 'rgba': 4, 'grayscale': 1}
        if color_mode not in channels_by_mode:
            # A single formatted string; passing several positional args to
            # ValueError rendered the message as a tuple.
            raise ValueError('Invalid color mode: {}; expected "rgb", "rgba", '
                             'or "grayscale".'.format(color_mode))
        self.color_mode = color_mode
        self.data_format = data_format
        channels = (channels_by_mode[color_mode],)
        # Any value other than 'channels_last' is laid out channels-first.
        if self.data_format == 'channels_last':
            self.image_shape = self.target_size + channels
        else:
            self.image_shape = channels + self.target_size

        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation

        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            if subset == 'validation':
                split = (0, validation_split)
            elif subset == 'training':
                split = (validation_split, 1)
            else:
                raise ValueError(
                    'Invalid subset name: %s; '
                    'expected "training" or "validation"' % (subset,))
        else:
            split = None
        self.split = split
        self.subset = subset

    def _get_batches_of_transformed_samples(self, index_array):
        """Gets a batch of transformed samples.

        # Arguments
            index_array: Array of sample indices to include in batch.

        # Returns
            `batch_x` alone when `class_mode` is not one of `"input"`,
            `"binary"`, `"sparse"`, `"categorical"`, `"multi_output"` or
            `"raw"`; otherwise `(batch_x, batch_y)` when `sample_weight` is
            None, else `(batch_x, batch_y, batch_sample_weight)`.
        """
        batch_x = np.zeros((len(index_array),) + self.image_shape,
                           dtype=self.dtype)
        # build batch of image data
        # self.filepaths is dynamic, is better to call it once outside the loop
        filepaths = self.filepaths
        for i, j in enumerate(index_array):
            img = load_img(filepaths[j],
                           color_mode=self.color_mode,
                           target_size=self.target_size,
                           interpolation=self.interpolation,
                           keep_aspect_ratio=self.keep_aspect_ratio)
            x = img_to_array(img, data_format=self.data_format)
            # Pillow images should be closed after `load_img`,
            # but not PIL images.
            if hasattr(img, 'close'):
                img.close()
            if self.image_data_generator:
                params = self.image_data_generator.get_random_transform(x.shape)
                x = self.image_data_generator.apply_transform(x, params)
                x = self.image_data_generator.standardize(x)
            batch_x[i] = x

        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = array_to_img(batch_x[i], self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=j,
                    # Integer bound (same draws as the former float 1e7).
                    hash=np.random.randint(10 ** 7),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))

        # build batch of labels
        if self.class_mode == 'input':
            batch_y = batch_x.copy()
        elif self.class_mode in {'binary', 'sparse'}:
            batch_y = np.empty(len(batch_x), dtype=self.dtype)
            for i, n_observation in enumerate(index_array):
                batch_y[i] = self.classes[n_observation]
        elif self.class_mode == 'categorical':
            # One-hot encoding over `class_indices`.
            batch_y = np.zeros((len(batch_x), len(self.class_indices)),
                               dtype=self.dtype)
            for i, n_observation in enumerate(index_array):
                batch_y[i, self.classes[n_observation]] = 1.
        elif self.class_mode == 'multi_output':
            batch_y = [output[index_array] for output in self.labels]
        elif self.class_mode == 'raw':
            batch_y = self.labels[index_array]
        else:
            return batch_x

        if self.sample_weight is None:
            return batch_x, batch_y
        return batch_x, batch_y, self.sample_weight[index_array]

    @property
    def filepaths(self):
        """List of absolute paths to image files"""
        raise NotImplementedError(
            '`filepaths` property method has not been implemented in {}.'
            .format(type(self).__name__)
        )

    @property
    def labels(self):
        """Class labels of every observation"""
        raise NotImplementedError(
            '`labels` property method has not been implemented in {}.'
            .format(type(self).__name__)
        )

    @property
    def sample_weight(self):
        """Per-observation sample weights, or None (indexed by sample id)."""
        raise NotImplementedError(
            '`sample_weight` property method has not been implemented in {}.'
            .format(type(self).__name__)
        )