-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathimagenet_split.py
81 lines (66 loc) · 2.58 KB
/
imagenet_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import random
import shutil
# Source tree: expected to hold per-class folders under ``train/`` (and
# ``val/``, which only check() reads).
dataset_path = 'data/imagenet'
# Destination tree written by split().
split_dataset_path = 'data/imagenet_split'

# Size of the full dataset (image count and class count). Informational only:
# neither split() nor check() reads these.
images_num = 1281167
CLASSES = 1000

# Fractions of each class's training images copied into the split's train/
# and val/ folders; the two subsets are disjoint.
train_portion = 0.1
val_portion = 0.025

# Fixed seed for the random.shuffle calls made by split().
random.seed(0)
def get_image_list(path, ext='.JPEG'):
    """Return the names of all image files found under *path*.

    The directory tree is walked recursively, but only the base file names
    are returned (not paths relative to *path*). Callers that join *path*
    with a returned name therefore assume a flat directory.

    Args:
        path: Directory to scan. A missing directory yields an empty list.
        ext: Case-sensitive file-name suffix that marks an image.

    Returns:
        A sorted list of matching file names. Sorting removes the dependence
        on the filesystem's (arbitrary) listing order, so a seeded
        ``random.shuffle`` over the result is reproducible.
    """
    imgs = []
    for _root, _dirs, files in os.walk(path):
        # Suffix test rather than substring test, so names such as
        # 'x.JPEG.tmp' are not mistaken for images.
        imgs.extend(name for name in files if name.endswith(ext))
    return sorted(imgs)
def split():
    """Copy a random subset of every training class into a new dataset tree.

    For each class folder under ``<dataset_path>/train`` the image names are
    shuffled. The first ``train_portion`` fraction is copied to
    ``<split_dataset_path>/train/<class>`` and the next ``val_portion``
    fraction to ``<split_dataset_path>/val/<class>``. Both subsets are drawn
    from the original *training* images and do not overlap. The original
    ``val`` folder is not used.

    Raises:
        FileNotFoundError: if ``dataset_path`` does not exist.
    """
    if not os.path.exists(dataset_path):
        # Explicit raise instead of assert: asserts are stripped under -O.
        raise FileNotFoundError('dataset not found: {}'.format(dataset_path))
    train_path = os.path.join(dataset_path, 'train')
    split_train_path = os.path.join(split_dataset_path, 'train')
    split_val_path = os.path.join(split_dataset_path, 'val')
    os.makedirs(split_train_path, exist_ok=True)
    os.makedirs(split_val_path, exist_ok=True)
    # Sorted so the sequence of shuffles -- and hence the split -- does not
    # depend on the filesystem's directory-listing order.
    for fn in sorted(os.listdir(train_path)):
        class_dir = os.path.join(train_path, fn)
        if not os.path.isdir(class_dir):
            # Ignore stray files that sit next to the class folders.
            continue
        imgs = sorted(get_image_list(class_dir))
        train_num_per_class = int(len(imgs) * train_portion)
        val_num_per_class = int(len(imgs) * val_portion)
        print('{}, imgs num: {}, train: {}, val: {}'.format(fn, len(imgs), train_num_per_class, val_num_per_class))
        random.shuffle(imgs)
        train_imgs = imgs[:train_num_per_class]
        val_imgs = imgs[train_num_per_class:train_num_per_class + val_num_per_class]
        _copy_images(class_dir, os.path.join(split_train_path, fn), train_imgs)
        _copy_images(class_dir, os.path.join(split_val_path, fn), val_imgs)


def _copy_images(src_dir, dst_dir, names):
    """Copy each file in *names* from *src_dir* into *dst_dir* (created if missing)."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in names:
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))
def check():
    """Try to fully decode every image under ``<dataset_path>/{train,val}``.

    A running count of readable images is printed every 1000 successes. Each
    file that PIL cannot open or decode is reported as
    ``corrupt img <path> <error>`` and is not counted.
    """
    from PIL import Image
    count = 0
    for split_name in ('train', 'val'):
        path = os.path.join(dataset_path, split_name)
        for fn in os.listdir(path):
            for name in get_image_list(os.path.join(path, fn)):
                img_path = os.path.join(path, fn, name)
                try:
                    # Image.open is lazy (header only); load() forces a full
                    # decode so damaged pixel data is detected. The with-block
                    # closes the file handle.
                    with Image.open(img_path) as im:
                        im.load()
                except Exception as exc:
                    # Not a bare except: Ctrl-C must still stop a long scan.
                    print('corrupt img', img_path, exc)
                    continue
                count += 1
                if count % 1000 == 0:
                    print(count)
def _main():
    """Script entry point: build the reduced train/val split (check() is not run)."""
    split()


if __name__ == "__main__":
    _main()