import argparse
import multiprocessing
from functools import partial
from io import BytesIO
import lmdb
from PIL import Image
from tqdm import tqdm
import torch
import numpy as np
import pandas as pd
import cv2
import sys
import json
import os
from glob import glob
from utils.CUB_data_utils import square_bbox, perturb_bbox, acsm_crop
from utils.download import download_spair, download_lsun, download_cub, download_cub_metadata
# When an image is mirrored horizontally, any key points with a left/right distinction need to be swapped.
# These are the permutations of key point indices that accomplish the swap:
CUB_PERMUTATION = [0, 1, 2, 3, 4, 5, 10, 11, 12, 9, 6, 7, 8, 13, 14]
SPAIR_PERMUTATIONS = {
'bicycle': [0, 1, 3, 2, 4, 5, 7, 6, 8, 10, 9, 11],
'cat': [1, 0, 3, 2, 5, 4, 7, 6, 8, 10, 9, 12, 11, 13, 14],
'dog': [1, 0, 3, 2, 5, 4, 6, 7, 8, 10, 9, 12, 11, 13, 14, 15],
'tvmonitor': [2, 1, 0, 7, 6, 5, 4, 3, 10, 9, 8, 15, 14, 13, 12, 11],
}
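# A minimal sketch (hypothetical helper, not used elsewhere in this file) of how the
# permutations above are meant to be applied: when an image is flipped horizontally,
# key point x-coordinates are mirrored and the rows are re-indexed so that "left" and
# "right" key points trade places.
def _example_mirror_kps(kps, image_width, permutation=CUB_PERMUTATION):
    # kps: (num_kps, 3) tensor of (x, y, visibility) in the un-flipped image.
    flipped = kps.clone()
    flipped[:, 0] = image_width - 1 - flipped[:, 0]  # mirror x-coordinates
    return flipped[permutation]  # swap left <-> right key point indices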
def black_bar_pad(img, target_res, resize=True, to_pil=True):
canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8)
original_width, original_height = img.size
if original_height <= original_width:
if resize:
            img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), Image.LANCZOS)
width, height = img.size
img = np.asarray(img)
canvas[(width - height) // 2: (width + height) // 2] = img
else:
if resize:
            img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), Image.LANCZOS)
width, height = img.size
img = np.asarray(img)
canvas[:, (height - width) // 2: (height + width) // 2] = img
if to_pil:
canvas = Image.fromarray(canvas)
return canvas
def border_pad(img, target_res, resize=True, to_pil=True):
original_width, original_height = img.size
if original_height <= original_width:
if resize:
            img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), Image.LANCZOS)
width, height = img.size
img = np.asarray(img)
        pad_total = target_res - height
        lh = pad_total // 2  # rows of edge padding added on top
        rh = pad_total - lh  # rows added on the bottom (one extra when pad_total is odd)
        img = np.pad(img, mode='edge', pad_width=[(lh, rh), (0, 0), (0, 0)])
else:
if resize:
            img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), Image.LANCZOS)
width, height = img.size
img = np.asarray(img)
        pad_total = target_res - width
        lw = pad_total // 2  # columns of edge padding added on the left
        rw = pad_total - lw  # columns added on the right (one extra when pad_total is odd)
        img = np.pad(img, mode='edge', pad_width=[(0, 0), (lw, rw), (0, 0)])
if to_pil:
img = Image.fromarray(img)
return img
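# Usage sketch for the two padding helpers above ('bird.jpg' is a placeholder path).
# Both return a square target_res x target_res image; black_bar_pad fills the margins
# with zeros, while border_pad replicates the image's edge pixels instead.
def _example_pad(path='bird.jpg', target_res=256):
    img = Image.open(path).convert('RGB')
    zero_padded = black_bar_pad(img, target_res)
    edge_padded = border_pad(img, target_res)
    return zero_padded, edge_padded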
def center_crop(img, target_res):
# From official StyleGAN2 create_lsun method:
img = np.asarray(img)
crop = np.min(img.shape[:2])
img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
(img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
img = Image.fromarray(img, 'RGB')
    img = img.resize((target_res, target_res), Image.LANCZOS)
return img
def nchw_center_crop(img):
    # Essentially the same as center_crop above, but for (N, C, H, W) PyTorch tensors; crops without resizing.
assert img.dim() == 4
H, W = img.size(2), img.size(3)
crop = min(H, W)
top_start = (H - crop) // 2
left_start = (W - crop) // 2
img = img[:, :, top_start: (H + crop) // 2, left_start: (W + crop) // 2]
return img, (top_start, left_start)
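# Worked example for nchw_center_crop: a (1, 3, 480, 640) batch becomes (1, 3, 480, 480)
# with (top_start, left_start) = (0, 80), since crop = min(480, 640) = 480 and
# left_start = (640 - 480) // 2 = 80.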
def resize_small_side(img, target_res):
# This function resizes img while maintaining its aspect ratio.
# The smaller side of img gets resized to target_res.
w, h = img.size
if w < h:
new_w = target_res
new_h = h * target_res // w
else:
new_h = target_res
new_w = w * target_res // h
    img = img.resize((new_w, new_h), Image.LANCZOS)
return img
def cub_crop(img, target_res, bbox):
# This function mimics ACSM's pre-processing used for the CUB dataset (up to image resampling and padding color)
img = np.asarray(img)
img = acsm_crop(img, bbox, 0, border=True)
    return Image.fromarray(img).resize((target_res, target_res), Image.LANCZOS)
def preprocess_kps_pad(kps, img_width, img_height, size):
# Once an image has been pre-processed via border (or zero) padding,
# the location of key points needs to be updated. This function applies
# that pre-processing to the key points so they are correctly located
# in the border-padded (or zero-padded) image.
kps = kps.clone()
scale = size / max(img_width, img_height)
kps[:, [0, 1]] *= scale
if img_height < img_width:
new_h = int(np.around(size * img_height / img_width))
offset_y = int((size - new_h) / 2)
offset_x = 0
kps[:, 1] += offset_y
elif img_width < img_height:
new_w = int(np.around(size * img_width / img_height))
offset_x = int((size - new_w) / 2)
offset_y = 0
kps[:, 0] += offset_x
else:
offset_x = 0
offset_y = 0
kps *= kps[:, 2:3] # zero-out any non-visible key points
return kps, offset_x, offset_y, scale
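# Worked example for preprocess_kps_pad (assuming a 200(W) x 100(H) image and size=256):
#   scale = 256 / 200 = 1.28, so a visible key point at (100, 50) first maps to (128, 64);
#   new_h = round(256 * 100 / 200) = 128 and offset_y = (256 - 128) // 2 = 64,
#   so its final location in the padded image is (128, 64 + 64) = (128, 128).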
def preprocess_kps_box_crop(kps, bbox, size):
# Once an image has been pre-processed via a box crop,
# the location of key points needs to be updated. This function applies
# that pre-processing to the key points so they are correctly located
# in the cropped image.
kps = kps.clone()
kps[:, 0] -= bbox[0] + 1
kps[:, 1] -= bbox[1] + 1
w = 1 + bbox[2] - bbox[0]
h = 1 + bbox[3] - bbox[1]
assert w == h
kps[:, [0, 1]] *= size / float(w)
return kps
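# Worked example for preprocess_kps_box_crop (assuming bbox = [10, 10, 110, 110] and size=256):
#   the square crop is w = h = 1 + 110 - 10 = 101 pixels wide, and a key point at (60, 60)
#   shifts to (60 - 11, 60 - 11) = (49, 49), then scales by 256 / 101 to roughly (124.2, 124.2).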
def load_CUB_keypoints(path):
names = ['img_index', 'kp_index', 'x', 'y', 'visible']
    landmarks = pd.read_table(path, header=None, names=names, sep=r'\s+', engine='python')
landmarks = landmarks.to_numpy().reshape((11788, 15, 5))[..., [2, 3, 4]] # (num_images, num_kps, 3)
landmarks = torch.from_numpy(landmarks).float()
return landmarks
def load_acsm_data(path, mat_path='data/val_cub_cleaned.mat', size=256, out_path=None):
from scipy.io import loadmat
mat = loadmat(mat_path)
files = [f'data/CUB_200_2011/images/{file[0]}' for file in mat['images']['rel_path'][0]]
# These are the indices retained by ACSM (others are filtered):
indices = [i[0, 0] - 1 for i in mat['images']['id'][0]]
kps = load_CUB_keypoints(f'{path}/parts/part_locs.txt')[indices]
b = mat['images']['bbox'][0]
bboxes = []
kps_out = []
for ix, row in enumerate(b):
x1, y1, x2, y2 = row[0, 0]
bbox = np.array([x1[0, 0], y1[0, 0], x2[0, 0], y2[0, 0]]) - 1
bbox = perturb_bbox(bbox, 0.05, 0)
bbox = square_bbox(bbox)
bboxes.append(bbox)
kps_out.append(preprocess_kps_box_crop(kps[ix], bbox, size))
bboxes = np.stack(bboxes)
kps_out = torch.stack(kps_out)
torch.save(kps_out, f'{out_path}/keypoints.pt')
# When an image is mirrored horizontally, the designation between key points with a left versus right distinction
# needs to be swapped. This is the permutation of CUB key points which accomplishes this swap:
torch.save(CUB_PERMUTATION, f'{out_path}/permutation.pt')
assert bboxes.shape[0] == len(files)
return files, bboxes
def load_spair_data(path, size, out_path, category='cat', split='test'):
pairs = sorted(glob(f'{path}/PairAnnotation/{split}/*:{category}.json'))
files = []
thresholds = []
inverse = []
category_anno = list(glob(f'{path}/ImageAnnotation/{category}/*.json'))[0]
with open(category_anno) as f:
num_kps = len(json.load(f)['kps'])
print(f'Number of SPair key points for {category} <= {num_kps}')
kps = []
blank_kps = torch.zeros(num_kps, 3)
for pair in pairs:
with open(pair) as f:
data = json.load(f)
assert category == data["category"]
assert data["mirror"] == 0
source_fn = f'{path}/JPEGImages/{category}/{data["src_imname"]}'
target_fn = f'{path}/JPEGImages/{category}/{data["trg_imname"]}'
source_bbox = np.asarray(data["src_bndbox"])
target_bbox = np.asarray(data["trg_bndbox"])
# The source thresholds aren't actually used to evaluate PCK on SPair-71K, but for completeness
# they are computed as well:
thresholds.append(max(source_bbox[3] - source_bbox[1], source_bbox[2] - source_bbox[0]))
thresholds.append(max(target_bbox[3] - target_bbox[1], target_bbox[2] - target_bbox[0]))
source_size = data["src_imsize"][:2] # (W, H)
target_size = data["trg_imsize"][:2] # (W, H)
kp_ixs = torch.tensor([int(id) for id in data["kps_ids"]]).view(-1, 1).repeat(1, 3)
source_raw_kps = torch.cat([torch.tensor(data["src_kps"], dtype=torch.float), torch.ones(kp_ixs.size(0), 1)], 1)
source_kps = blank_kps.scatter(dim=0, index=kp_ixs, src=source_raw_kps)
source_kps, src_x, src_y, src_scale = preprocess_kps_pad(source_kps, source_size[0], source_size[1], size)
target_raw_kps = torch.cat([torch.tensor(data["trg_kps"], dtype=torch.float), torch.ones(kp_ixs.size(0), 1)], 1)
target_kps = blank_kps.scatter(dim=0, index=kp_ixs, src=target_raw_kps)
target_kps, trg_x, trg_y, trg_scale = preprocess_kps_pad(target_kps, target_size[0], target_size[1], size)
kps.append(source_kps)
kps.append(target_kps)
files.append(source_fn)
files.append(target_fn)
inverse.append([src_x, src_y, src_scale])
inverse.append([trg_x, trg_y, trg_scale])
kps = torch.stack(kps)
used_kps, = torch.where(kps[:, :, 2].any(dim=0))
kps = kps[:, used_kps, :]
print(f'Final number of used key points: {kps.size(1)}')
num_imgs = len(thresholds) # Total number of images (= 2 * number of pairs)
torch.save(torch.arange(num_imgs).view(num_imgs // 2, 2), f'{out_path}/pairs.pt')
torch.save(torch.tensor(thresholds, dtype=torch.float), f'{out_path}/pck_thresholds.pt')
torch.save(torch.tensor(inverse), f'{out_path}/inverse_coordinates.pt')
torch.save(kps, f'{out_path}/keypoints.pt')
torch.save(SPAIR_PERMUTATIONS[category], f'{out_path}/permutation.pt')
return files, [None] * len(files) # No bounding boxes are used
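# Sketch of the scatter trick used in load_spair_data above (assuming num_kps = 5 and
# annotated key point ids [0, 3]): kp_ixs = [[0, 0, 0], [3, 3, 3]], so
# blank_kps.scatter(dim=0, index=kp_ixs, src=raw_kps) writes raw_kps row 0 into row 0
# and raw_kps row 1 into row 3, leaving rows 1, 2 and 4 all-zero (visibility 0).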
def load_image_folder(path, pattern):
files = sorted(glob(f'{path}/{pattern}'))
bboxes = [None] * len(files) # This means no bounding boxes are used
return files, bboxes
def resize_and_convert(img, size, pad, quality=100, format='jpeg', bbox=None):
if pad == 'zero':
img = black_bar_pad(img, size)
elif pad == 'border':
img = border_pad(img, size)
elif pad == 'center':
img = center_crop(img, size)
elif pad == 'none':
pass
elif pad == 'resize_small_side':
img = resize_small_side(img, size)
elif pad == 'cub_crop':
img = cub_crop(img, size, bbox)
else:
raise NotImplementedError
buffer = BytesIO()
img.save(buffer, format=format, quality=quality)
val = buffer.getvalue()
return val
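# Round-trip sketch (hypothetical helper, not used elsewhere in this file): the encoded
# bytes returned above, and stored in LMDB by prepare() below, decode back into a PIL
# image like this:
def _example_decode(val):
    return Image.open(BytesIO(val)).convert('RGB')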
def resize_multiple(img, sizes=(128, 256, 512, 1024), quality=100, pad='zero', format='jpeg', bbox=None):
imgs = []
for size in sizes:
imgs.append(resize_and_convert(img, size, pad, quality, format, bbox))
return imgs
def resize_worker(img_file, sizes, pad, format, lmdb_path):
i, file, bbox = img_file
if lmdb_path is not None: # Load image from LMDB (useful for LSUN datasets, etc.)
input_env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with input_env.begin(write=False) as txn:
img_bytes = txn.get(file)
        try:  # https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/dataset_tool.py
try:
img = cv2.imdecode(np.frombuffer(img_bytes, dtype=np.uint8), 1)
if img is None:
raise IOError('cv2.imdecode failed')
img = img[:, :, ::-1] # BGR => RGB
img = Image.fromarray(img)
except IOError:
img = Image.open(BytesIO(img_bytes))
out = resize_multiple(img, sizes=sizes, pad=pad, format=format, bbox=bbox)
        except Exception:  # Corrupted image that should be skipped
            out = [None] * len(sizes)
            print(sys.exc_info()[1])
else: # Directly open a saved image:
img = Image.open(file).convert('RGB')
out = resize_multiple(img, sizes=sizes, pad=pad, format=format, bbox=bbox)
return i, out
def prepare(
env, path, out, n_worker, sizes=(128, 256, 512, 1024), pad='zero', format='jpeg', input_is_lmdb=False,
pattern='*.png', max_images=None, spair_category=None, spair_split=None, cub_acsm=False
):
if input_is_lmdb:
lmdb_path = path
input_env = lmdb.open(lmdb_path, readonly=True, lock=False)
print('Loading LMDB keys (this might take a bit)...')
with input_env.begin(write=False) as inp_txn:
key_list = list(inp_txn.cursor().iternext(values=False)) # https://stackoverflow.com/a/65663873
if max_images is not None:
key_list = key_list[:max_images]
num_files = len(key_list)
print(f'LMDB keys loaded! Found {num_files} keys.')
files = [(i, key, None) for i, key in enumerate(key_list)]
else:
lmdb_path = None
if cub_acsm: # Load CUB using ACSM pre-processing (this is the only dataset that uses bboxes in pre-processing)
files, bboxes = load_acsm_data(path, size=int(sizes[0]), out_path=out)
elif spair_category is not None: # Load SPair-71K (bboxes = None)
files, bboxes = load_spair_data(path, size=int(sizes[0]), out_path=out,
category=spair_category, split=spair_split)
else: # Load images from a folder (or hierarchy of folders); bboxes = None
files, bboxes = load_image_folder(path, pattern)
if max_images is not None:
files = files[:max_images]
bboxes = bboxes[:max_images]
num_files = len(files)
print(f'Found {num_files} files')
print(f'Example file being loaded: {files[0]}')
files = [(i, file, bbox) for i, (file, bbox) in enumerate(zip(files, bboxes))]
resize_fn = partial(resize_worker, sizes=sizes, pad=pad, format=format, lmdb_path=lmdb_path)
total = 0
skipped = 0
with multiprocessing.Pool(n_worker) as pool:
        for i, imgs in tqdm(pool.imap(resize_fn, files), total=num_files):
            if any(img is None for img in imgs):  # resize_worker returns all-None for a corrupted image
                print('skipping image')
                skipped += 1  # count skips per-image so the key indices stay aligned across all sizes
                continue
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i - skipped).zfill(5)}".encode("utf-8")
                with env.begin(write=True) as txn:
                    txn.put(key, img)
            total += 1
with env.begin(write=True) as txn:
txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
print(f'Final dataset size: {total}')
def create_dataset(out, path, size, pad='zero', n_worker=8, format='jpeg', input_is_lmdb=False, pattern='*.png',
max_images=None, spair_category=None, spair_split=None, cub_acsm=False):
    sizes = [int(s.strip()) for s in str(size).split(",")]
    print("Making dataset at image sizes:", ", ".join(str(s) for s in sizes))
with lmdb.open(out, map_size=2048 ** 4, readahead=False) as env:
prepare(env, path, out, n_worker, sizes=sizes, pad=pad, format=format,
input_is_lmdb=input_is_lmdb, pattern=pattern, max_images=max_images,
spair_category=spair_category, spair_split=spair_split, cub_acsm=cub_acsm)
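# Programmatic usage sketch ('datasets/my_lmdb' and 'my_images' are placeholder paths):
#   create_dataset('datasets/my_lmdb', 'my_images', size='256,512', pad='center', pattern='*.jpg')
# builds one LMDB holding every image at both 256x256 and 512x512 under keys like
# b'256-00000' and b'512-00000', plus a b'length' entry recording the dataset size.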
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create image datasets")
parser.add_argument("--out", type=str, help="filename of the output lmdb dataset")
parser.add_argument(
"--size",
type=str,
default="256",
help="resolutions of images for the dataset",
)
parser.add_argument(
"--n_worker",
type=int,
default=8,
help="number of workers for preparing dataset",
)
parser.add_argument(
"--format",
type=str,
choices=['png', 'jpeg'],
default='png',
help="format to store images in the database",
)
parser.add_argument("--pad", choices=['zero', 'border', 'center', 'resize_small_side', 'none'], default='center',
                    help='Algorithm to pad or crop input images to square. zero = pad with zeros, '
                         'border = border padding, center = center crop, resize_small_side = resize while preserving '
                         'aspect ratio (only mixed_reality.py supports this option), none = no pre-processing')
parser.add_argument("--path", type=str, help="path to the image dataset", default=None)
parser.add_argument("--input_is_lmdb", action='store_true',
help='If true, path input points to an LMDB dataset. This is useful for, e.g., creating '
'LSUN datasets. If you use this you can ignore --pattern')
parser.add_argument("--pattern", type=str, default='*.png', help='Specify the pattern glob uses to find images')
parser.add_argument("--max_images", type=int, default=None, help='Maximum number of images to include in '
'final dataset (default: include all)')
# Special arguments for loading SPair-71K and CUB for PCK evaluation purposes (and also LSUN). If you use these
# options below, you can ignore --input_is_lmdb, --path and --pattern above.
parser.add_argument("--spair_category", default=None, type=str, choices=list(SPAIR_PERMUTATIONS.keys()),
help='If specified, constructs the SPair-71K dataset for the specified category')
parser.add_argument("--spair_split", default='test', choices=['trn', 'val', 'test'], type=str,
help='The split of SPair that will be constructed (only used if --spair_category is specified)')
parser.add_argument("--lsun_category", default=None, type=str,
help='If specified, constructs the LSUN dataset for the specified category '
'(may take a while to download!)')
parser.add_argument("--cub_acsm", action='store_true',
help='If true, constructs the CUB dataset. This will use the same pre-processing and filtering '
'as the CUB validation split from the ACSM paper.')
args = parser.parse_args()
    if os.path.dirname(args.out):  # makedirs('') raises, so only create parent dirs when one is given
        os.makedirs(os.path.dirname(args.out), exist_ok=True)
# Automatically download some datasets:
if args.cub_acsm: # Download metadata for CUB pre-processing
os.makedirs('data', exist_ok=True)
download_cub_metadata('data')
args.pad = 'cub_crop'
elif args.spair_category is not None:
args.pad = 'border' # zero padding would also be acceptable
if args.cub_acsm and args.path is None: # Download CUB-200-2011 if needed
args.path = download_cub('data')
elif args.spair_category is not None and args.path is None: # Download SPair-71K data if needed
os.makedirs('data', exist_ok=True)
args.path = download_spair('data')
args.pad = 'border'
elif args.lsun_category is not None: # Download LSUN category automatically
os.makedirs('data', exist_ok=True)
args.path = download_lsun('data', args.lsun_category)
args.input_is_lmdb = True
else:
assert args.path is not None
create_dataset(args.out, args.path, args.size, args.pad, args.n_worker, args.format,
args.input_is_lmdb, args.pattern, args.max_images, args.spair_category, args.spair_split,
args.cub_acsm)
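# Example invocations (output paths and image folder names are placeholders):
#   python prepare_data.py --out datasets/cub_val --cub_acsm --size 256
#   python prepare_data.py --out datasets/spair_cat --spair_category cat --spair_split test --size 256
#   python prepare_data.py --out datasets/my_set --path my_images --pattern "*.jpg" --pad center --size 256,512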