utils.py
# Copyright 2019 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SketchRNN data loading and image manipulation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from PIL import Image

def get_bounds(data, factor=10):
  """Return bounds of data."""
  min_x = 0
  max_x = 0
  min_y = 0
  max_y = 0

  abs_x = 0
  abs_y = 0
  for i in range(len(data)):
    x = float(data[i, 0]) / factor
    y = float(data[i, 1]) / factor
    abs_x += x
    abs_y += y
    min_x = min(min_x, abs_x)
    min_y = min(min_y, abs_y)
    max_x = max(max_x, abs_x)
    max_y = max(max_y, abs_y)
  return min_x, max_x, min_y, max_y

def slerp(p0, p1, t):
  """Spherical interpolation."""
  omega = np.arccos(np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1)))
  so = np.sin(omega)
  return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1

def lerp(p0, p1, t):
  """Linear interpolation."""
  return (1.0 - t) * p0 + t * p1
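
# A minimal illustrative example of the interpolation helpers above (the
# vectors here are made up, not taken from any dataset or model):
#
#   z = slerp(np.array([1.0, 0.0]), np.array([0.0, 1.0]), 0.5)
#   # z is approximately [0.7071, 0.7071]; lerp with the same arguments
#   # gives [0.5, 0.5].
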
# A note on formats:
# Sketches are encoded as a sequence of strokes. stroke-3 and stroke-5 are
# different stroke encodings.
# stroke-3 uses 3-tuples, consisting of x-offset, y-offset, and a binary
# variable which is 1 if the pen is lifted between this position and
# the next, and 0 otherwise.
# stroke-5 consists of x-offset, y-offset, and p_1, p_2, p_3, a binary
# one-hot vector of 3 possible pen states: pen down, pen up, end of sketch.
# See section 3.1 of https://arxiv.org/abs/1704.03477 for more detail.
# Sketch-RNN takes input in stroke-5 format, with sketches padded to a common
# maximum length and prefixed by the special start token [0, 0, 1, 0, 0].
# The QuickDraw dataset is stored using stroke-3.
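#
# As a small illustration (values chosen arbitrarily), a sketch made of one
# two-point line would be, in stroke-3:
#   [[5, 0, 0], [0, 5, 1]]
# and in stroke-5, once padded to length 3 with the end-of-sketch state:
#   [[5, 0, 1, 0, 0], [0, 5, 0, 1, 0], [0, 0, 0, 0, 1]]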

def strokes_to_lines(strokes):
  """Convert stroke-3 format to polyline format."""
  x = 0
  y = 0
  lines = []
  line = []
  for i in range(len(strokes)):
    if strokes[i, 2] == 1:
      x += float(strokes[i, 0])
      y += float(strokes[i, 1])
      line.append([x, y])
      lines.append(line)
      line = []
    else:
      x += float(strokes[i, 0])
      y += float(strokes[i, 1])
      line.append([x, y])
  return lines

def lines_to_strokes(lines):
  """Convert polyline format to stroke-3 format."""
  eos = 0
  strokes = [[0, 0, 0]]
  for line in lines:
    linelen = len(line)
    for i in range(linelen):
      eos = 0 if i < linelen - 1 else 1
      strokes.append([line[i][0], line[i][1], eos])
  strokes = np.array(strokes)
  strokes[1:, 0:2] -= strokes[:-1, 0:2]
  return strokes[1:, :]
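
# A minimal round-trip sketch (the numbers are arbitrary): converting a tiny
# stroke-3 sequence to polylines and back recovers the original offsets.
#
#   lines = strokes_to_lines(np.array([[5, 0, 0], [0, 5, 1], [3, 3, 1]]))
#   # lines == [[[5.0, 0.0], [5.0, 5.0]], [[8.0, 8.0]]]
#   strokes = lines_to_strokes(lines)
#   # strokes == [[5., 0., 0.], [0., 5., 1.], [3., 3., 1.]]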

def augment_strokes(strokes, prob=0.0):
  """Perform data augmentation by randomly dropping out strokes."""
  # Drop each point within a line segment with probability prob, merging the
  # dropped point's offset into the previously kept point. Note that the
  # logic in the loop prevents points at the ends of a segment from being
  # dropped.
  result = []
  prev_stroke = [0, 0, 1]
  count = 0
  stroke = [0, 0, 1]  # Added to be safe.
  for i in range(len(strokes)):
    candidate = [strokes[i][0], strokes[i][1], strokes[i][2]]
    if candidate[2] == 1 or prev_stroke[2] == 1:
      count = 0
    else:
      count += 1
    urnd = np.random.rand()  # uniform random variable
    if candidate[2] == 0 and prev_stroke[2] == 0 and count > 2 and urnd < prob:
      stroke[0] += candidate[0]
      stroke[1] += candidate[1]
    else:
      stroke = candidate
      prev_stroke = stroke
      result.append(stroke)
  return np.array(result)

def scale_bound(stroke, average_dimension=10.0):
  """Scale an entire image to be less than a certain size."""
  # stroke is a numpy array of [dx, dy, pstate], average_dimension is a float.
  # Modifies stroke in place.
  bounds = get_bounds(stroke, 1)
  max_dimension = max(bounds[1] - bounds[0], bounds[3] - bounds[2])
  stroke[:, 0:2] /= (max_dimension / average_dimension)

def to_normal_strokes(big_stroke):
  """Convert from stroke-5 format (from sketch-rnn paper) back to stroke-3."""
  l = 0
  for i in range(len(big_stroke)):
    if big_stroke[i, 4] > 0:
      l = i
      break
  if l == 0:
    l = len(big_stroke)
  result = np.zeros((l, 3))
  result[:, 0:2] = big_stroke[0:l, 0:2]
  result[:, 2] = big_stroke[0:l, 3]
  return result

def clean_strokes(sample_strokes, factor=100):
  """Cut irrelevant end points, scale to pixel space and store as integer."""
  # Useful function for exporting data to .json format.
  copy_stroke = []
  added_final = False
  for j in range(len(sample_strokes)):
    finish_flag = int(sample_strokes[j][4])
    if finish_flag == 0:
      copy_stroke.append([
          int(round(sample_strokes[j][0] * factor)),
          int(round(sample_strokes[j][1] * factor)),
          int(sample_strokes[j][2]),
          int(sample_strokes[j][3]), finish_flag
      ])
    else:
      copy_stroke.append([0, 0, 0, 0, 1])
      added_final = True
      break
  if not added_final:
    copy_stroke.append([0, 0, 0, 0, 1])
  return copy_stroke

def to_big_strokes(stroke, max_len=250):
  """Converts from stroke-3 to stroke-5 format and pads to given length."""
  # (But does not insert special start token).
  result = np.zeros((max_len, 5), dtype=float)
  l = len(stroke)
  assert l <= max_len
  result[0:l, 0:2] = stroke[:, 0:2]
  result[0:l, 3] = stroke[:, 2]
  result[0:l, 2] = 1 - result[0:l, 3]
  result[l:, 4] = 1
  return result
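
# A minimal round-trip sketch of the two converters above (arbitrary values):
#
#   big = to_big_strokes(np.array([[1, 2, 0], [3, 4, 1]]), max_len=5)
#   # big == [[1, 2, 1, 0, 0], [3, 4, 0, 1, 0], [0, 0, 0, 0, 1],
#   #         [0, 0, 0, 0, 1], [0, 0, 0, 0, 1]]
#   small = to_normal_strokes(big)
#   # small == [[1, 2, 0], [3, 4, 1]]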

def get_max_len(strokes):
  """Return the maximum length of an array of strokes."""
  max_len = 0
  for stroke in strokes:  # stroke: [N_points, 3]
    ml = len(stroke)
    max_len = ml if ml > max_len else max_len
  return max_len

class DataLoader(object):
  """Class for loading data."""

  def __init__(self,
               strokes,
               png_paths,
               img_h,
               img_w,
               batch_size=100,
               max_seq_length=250,
               scale_factor=1.0,
               random_scale_factor=0.0,
               augment_stroke_prob=0.0,
               limit=1000):
    self.batch_size = batch_size  # minibatch size
    self.max_seq_length = max_seq_length  # N_max in sketch-rnn paper
    self.scale_factor = scale_factor  # divide offsets by this factor
    self.random_scale_factor = random_scale_factor  # data augmentation method
    # Removes large gaps in the data. x and y offsets are clamped to have
    # absolute value no greater than this limit.
    self.limit = limit
    self.augment_stroke_prob = augment_stroke_prob  # data augmentation method
    self.start_stroke_token = [0, 0, 1, 0, 0]  # S_0 in sketch-rnn paper
    # Sets self.strokes (list of ndarrays, one per sketch, in stroke-3 format,
    # sorted by size) and self.png_paths (matching list of image paths).
    self.preprocess(strokes, png_paths)
    self.img_h = img_h
    self.img_w = img_w

  def preprocess(self, strokes, img_paths):
    """Remove entries from strokes having > max_seq_length points."""
    raw_data = []
    seq_len = []
    raw_data_paths = []
    count_data = 0
    for i in range(len(strokes)):
      data = strokes[i]
      img_path = img_paths[i]
      if len(data) <= self.max_seq_length:
        count_data += 1
        # removes large gaps from the data
        data = np.minimum(data, self.limit)
        data = np.maximum(data, -self.limit)
        data = np.array(data, dtype=np.float32)
        data[:, 0:2] /= self.scale_factor
        raw_data.append(data)
        seq_len.append(len(data))
        raw_data_paths.append(img_path)
    seq_len = np.array(seq_len)  # nstrokes for each sketch
    self.sorted_idx = np.argsort(seq_len)
    self.strokes = []
    self.png_paths = []
    for i in range(len(seq_len)):
      self.strokes.append(raw_data[self.sorted_idx[i]])
      self.png_paths.append(raw_data_paths[self.sorted_idx[i]])
    print("total images <= max_seq_len is %d" % count_data)
    self.num_batches = int(count_data / self.batch_size)

  def random_sample(self):
    """Return a random sample, in stroke-3 format as used by draw_strokes."""
    rand_idx = random.randint(0, len(self.strokes) - 1)
    print('## rand_idx', rand_idx)
    sample = np.copy(self.strokes[rand_idx])
    image = self.load_images([self.png_paths[rand_idx]])
    return sample, rand_idx, image

  def random_scale(self, data):
    """Augment data by stretching x and y axis randomly [1-e, 1+e]."""
    x_scale_factor = (
        np.random.random() - 0.5) * 2 * self.random_scale_factor + 1.0
    y_scale_factor = (
        np.random.random() - 0.5) * 2 * self.random_scale_factor + 1.0
    result = np.copy(data)
    result[:, 0] *= x_scale_factor
    result[:, 1] *= y_scale_factor
    return result

  def calculate_normalizing_scale_factor(self):
    """Calculate the normalizing factor explained in appendix of sketch-rnn."""
    data = []
    for i in range(len(self.strokes)):
      if len(self.strokes[i]) > self.max_seq_length:
        continue
      for j in range(len(self.strokes[i])):
        data.append(self.strokes[i][j, 0])
        data.append(self.strokes[i][j, 1])
    data = np.array(data)
    return np.std(data)

  def normalize(self, scale_factor=None):
    """Normalize entire dataset (delta_x, delta_y) by the scaling factor."""
    if scale_factor is None:
      scale_factor = self.calculate_normalizing_scale_factor()
    self.scale_factor = scale_factor
    for i in range(len(self.strokes)):
      self.strokes[i][:, 0:2] /= self.scale_factor

  def _get_batch_from_indices(self, indices):
    """Given a list of indices, return the potentially augmented batch."""
    x_batch = []
    seq_len = []
    img_paths = []
    for idx in range(len(indices)):
      i = indices[idx]
      data = self.random_scale(self.strokes[i])
      data_copy = np.copy(data)
      if self.augment_stroke_prob > 0:
        data_copy = augment_strokes(data_copy, self.augment_stroke_prob)
      x_batch.append(data_copy)
      length = len(data_copy)
      seq_len.append(length)
      img_paths.append(self.png_paths[i])
    seq_len = np.array(seq_len, dtype=int)
    # We return four things: stroke-3 format, stroke-5 format, list of
    # seq_len, and the corresponding image batch.
    return (x_batch, self.pad_batch(x_batch), seq_len,
            self.load_images(img_paths))

  def random_batch(self):
    """Return a randomised portion of the training data."""
    idx = np.random.permutation(range(0, len(self.strokes)))[0:self.batch_size]
    return self._get_batch_from_indices(idx)

  def get_batch(self, idx):
    """Get the idx'th batch from the dataset."""
    assert idx >= 0, "idx must be non negative"
    assert idx < self.num_batches, "idx must be less than the number of batches"
    start_idx = idx * self.batch_size
    indices = range(start_idx, start_idx + self.batch_size)
    return self._get_batch_from_indices(indices)

  def pad_batch(self, batch):
    """Pad the batch to be stroke-5 bigger format as described in paper."""
    result = np.zeros(
        (self.batch_size, self.max_seq_length + 1, 5), dtype=float)
    assert len(batch) == self.batch_size
    for i in range(self.batch_size):
      l = len(batch[i])
      assert l <= self.max_seq_length
      result[i, 0:l, 0:2] = batch[i][:, 0:2]
      result[i, 0:l, 3] = batch[i][:, 2]
      result[i, 0:l, 2] = 1 - result[i, 0:l, 3]
      result[i, l:, 4] = 1
      # put in the first token, as described in sketch-rnn methodology
      result[i, 1:, :] = result[i, :-1, :]
      result[i, 0, :] = 0
      result[i, 0, 2] = self.start_stroke_token[2]  # setting S_0 from paper.
      result[i, 0, 3] = self.start_stroke_token[3]
      result[i, 0, 4] = self.start_stroke_token[4]
    return result

  def load_images(self, image_paths):
    """Load grayscale images as a [len(image_paths), img_h, img_w, 1] array."""
    # Sized by the number of paths so this works both for full batches and for
    # single samples (e.g. the call from random_sample).
    img_batch = np.zeros(
        shape=[len(image_paths), self.img_h, self.img_w, 1], dtype=np.float32)
    for img_idx in range(len(image_paths)):
      image_path = image_paths[img_idx]
      image = Image.open(image_path).convert('L')
      image = np.array(image, dtype=np.float32)
      image = np.expand_dims(image, axis=2)
      img_batch[img_idx] = image
    return img_batch
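
# A minimal usage sketch (the file name, image size, and .npz layout below are
# assumptions for illustration, not part of this module):
#
#   data = np.load('sketches.npz', encoding='latin1', allow_pickle=True)
#   loader = DataLoader(data['train'], train_png_paths, img_h=48, img_w=48,
#                       batch_size=100, max_seq_length=250)
#   loader.normalize()  # rescale offsets by the dataset's standard deviation
#   x_raw, x_padded, seq_len, images = loader.random_batch()
#   # x_padded: [100, 251, 5] stroke-5 batch; images: [100, 48, 48, 1]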