"""Create lmdb files for [General images (291 images/DIV2K) | Vimeo90K | REDS] training datasets"""
import sys
import os.path as osp
import glob
import pickle
from multiprocessing import Pool
import numpy as np
import lmdb
import cv2
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import data.util as data_util # noqa: E402
import utils.util as util # noqa: E402


def main():
    mode = 'GT'  # used for the Vimeo90K and REDS datasets
    vimeo90k(mode)
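    # Optional sanity check once the database has been written, e.g.:
    # test_lmdb('../Video/Vimeo/vimeo90k_train_GT.lmdb', dataset='vimeo90k')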


def read_image_worker(path, key):
    """Read one image and return (key, img) so the async callback can pair results with keys."""
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    return (key, img)


def vimeo90k(mode):
    """Create lmdb for the Vimeo90K dataset, each image with a fixed size.
    GT: [3, 256, 448]
        Only the 4th frame is needed, e.g., 00001_0001_4
    key:
        Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing;
    # set False to limit memory usage
    BATCH = 5000  # if read_all_imgs is False, commit the lmdb transaction every BATCH images
    # mode == 'GT': set the input data files here
    img_folder = '../Video/vimeo_septuplet/sequences/'
    lmdb_save_path = '../Video/Vimeo/vimeo90k_train_GT.lmdb'
    txt_file = '../Video/vimeo_septuplet/sep_trainlist.txt'
    H_dst, W_dst = 256, 448
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(txt_file) as f:
        train_l = f.readlines()
        train_l = [v.strip() for v in train_l]
    all_img_list = []
    keys = []
    for line in train_l:
        folder = line.split('/')[0]
        sub_folder = line.split('/')[1]
        all_img_list.extend(glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
        for j in range(7):
            keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
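    # Note: sorting both lists keeps paths and keys aligned only because the
    # septuplet frames are assumed to be named im1.png ... im7.png, which sort
    # in the same order as the numeric key suffixes.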
    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # async results arrive out of order, so store them in a dict keyed by image key
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """Collect the (key, img) result and update pbar."""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### write data to lmdb
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
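    # lmdb's map_size is the maximum size the memory-mapped file may grow to;
    # a 10x margin over the raw pixel bytes leaves room for lmdb overhead.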
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    txn = env.begin(write=True)
    pbar = util.ProgressBar(len(all_img_list))
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
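        # Commit periodically when streaming from disk so a single huge write
        # transaction is avoided.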
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'Vimeo90K_train_GT'
    channel = 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    key_set = set()
    for key in keys:
        if mode == 'flow':
            a, b, _, _ = key.split('_')
        else:
            a, b, _ = key.split('_')
        key_set.add('{}_{}'.format(a, b))
    meta_info['keys'] = list(key_set)
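    # meta_info.pkl is written inside the .lmdb folder and read back by test_lmdb().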
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')


def test_lmdb(dataroot, dataset='REDS'):
    env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
    meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
    print('Name: ', meta_info['name'])
    print('Resolution: ', meta_info['resolution'])
    print('# keys: ', len(meta_info['keys']))

    # read one image
    if dataset == 'vimeo90k':
        key = '00001_0001_4'
    else:
        key = '000_00000000'
    print('Reading {} for test.'.format(key))
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
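    # Values are the raw HWC uint8 buffers written by vimeo90k(); rebuild the
    # array from the 'C_H_W' resolution string stored in the meta info.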
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
    img = img_flat.reshape(H, W, C)
    cv2.imwrite('test.png', img)


if __name__ == "__main__":
    main()
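
# Example run (a sketch; assumes the Vimeo90K paths configured in vimeo90k() exist):
#   python create_lmdb.py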