pargen.py
# Multi-threaded data generator: greatly speeds up data generation and
# optimizes memory usage.
# NOTE: Thread safety of random-number generation needs attention; no issues
# have been observed so far.
"""
NOTE: Some of the random functions (e.g. random.gauss()) are not thread-safe
and can produce identical values across different workers.
The functions used in this project (randint, randn) should be thread-safe.
(See the optional _seed_worker sketch below if explicit re-seeding is ever needed.)
"""
import os
import pathlib
import datetime
from itertools import chain
from multiprocessing import Pool
from typing import Tuple

import torch
import torchvision.transforms as transforms
from torch import Tensor

from mona.config import config
from mona.datagen.datagen import generate_image
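

# Optional sketch of per-worker RNG seeding. `_seed_worker` is a hypothetical
# helper (not called anywhere by default); if identical samples were ever
# observed across workers, it could be wired in via
# Pool(threads, initializer=_seed_worker).
def _seed_worker():
    import random
    # Derive a distinct seed per worker from OS entropy mixed with the PID.
    seed = int.from_bytes(os.urandom(4), "little") ^ os.getpid()
    random.seed(seed)
    torch.manual_seed(seed)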


def progressBar(current, total, barLength=40):
    """Print a one-line progress bar that overwrites itself via a carriage return."""
    percent = float(current) * 100 / total
    arrow = '-' * int(percent / 100 * barLength - 1) + '>'
    spaces = ' ' * (barLength - len(arrow))
    print('Progress: [%s%s] %d %%' % (arrow, spaces, percent), end='\r')


def fill_data(target_tensor_slice: Tensor) -> list:
    """Fill a given tensor slice with generated images.

    Args:
        target_tensor_slice (Tensor): Tensor slice to store the generated images.

    Returns:
        list: List of labels.
    """
    length = target_tensor_slice.shape[0]
    y = []
    for i in range(length):
        im, text = generate_image()
        tensor = transforms.ToTensor()(im)
        tensor = torch.unsqueeze(tensor, dim=0)
        # NOTE: here tensor.shape == [1, 1, 32, 384]
        target_tensor_slice[i] = tensor
        y.append(text)
        # Each worker prints its own progress.
        if i % 100 == 0:
            progressBar(i, length)
    return y


def gen_dataset_with_label(size, threads=2) -> Tuple[Tensor, list]:
    # Allocate the output tensor and split it into sub-tensors
    # (views, copy-free; see the _check_views sketch after this function).
    x = torch.zeros((size, 1, 32, 384))
    x_split = torch.tensor_split(x, threads, dim=0)
    with Pool(threads) as p:
        print(f"Starting threadpool with {threads} threads.")
        labels = p.map(fill_data, x_split)
        print("\nStopping threadpool.")
    return x, list(chain.from_iterable(labels))
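

# A minimal, self-contained check (not part of the generation pipeline)
# illustrating why the split above is copy-free: torch.tensor_split returns
# views that share storage with the parent tensor, so writes into a chunk are
# visible through the parent. `_check_views` is an illustrative name.
def _check_views():
    parent = torch.zeros((4, 1, 32, 384))
    chunks = torch.tensor_split(parent, 2, dim=0)
    chunks[0][0] = 1.0
    # The write through the view is reflected in the parent tensor.
    assert parent[0].sum().item() == 1 * 32 * 384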


if __name__ == '__main__':
    train_size = config["train_size"]
    validate_size = config["validate_size"]

    folder = pathlib.Path("data")
    if not folder.is_dir():
        os.mkdir(folder)

    # Use roughly the number of physical cores (os.cpu_count() reports logical
    # cores, so halve it assuming two hardware threads per core).
    threads = max(1, (os.cpu_count() or 2) // 2)
    print(f"Train size {train_size}, Val size {validate_size}, Thread count {threads}")

    # Generate and save the training set
    print(f"{datetime.datetime.now()} Generating training data")
    x, y = gen_dataset_with_label(size=train_size, threads=threads)
    print(f"{datetime.datetime.now()} Saving training data")
    torch.save(x, "data/train_x.pt")
    torch.save(y, "data/train_label.pt")
    del x, y

    # Generate and save the validation set
    print(f"{datetime.datetime.now()} Generating validation data")
    x, y = gen_dataset_with_label(size=validate_size, threads=threads)
    print(f"{datetime.datetime.now()} Saving validation data")
    torch.save(x, "data/validate_x.pt")
    torch.save(y, "data/validate_label.pt")

    # Optionally verify the result visually (requires Pillow and numpy):
    # from PIL import Image
    # import numpy as np
    # import time
    # for tensor, label in zip(x, y):
    #     arr = tensor.squeeze()
    #     im = Image.fromarray(np.uint8(arr * 255))
    #     im.show()
    #     print(label)
    #     time.sleep(1)
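
    # Sketch of how the saved files could be reloaded later for inspection,
    # e.g. from a REPL (the variable names here are illustrative only):
    # val_x = torch.load("data/validate_x.pt")
    # val_labels = torch.load("data/validate_label.pt")
    # print(val_x.shape, len(val_labels))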