This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
forked from Yukariin/CSA_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
94 lines (72 loc) · 2.93 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils import data
class InfiniteSampler(data.sampler.Sampler):
    """Sampler that yields dataset indices forever, one shuffled pass at a time.

    Each pass is a fresh random permutation of ``range(num_samples)``, so every
    index appears exactly once per pass before the next pass starts.

    Args:
        num_samples: Number of indices to draw from; must be positive.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """

    def __init__(self, num_samples):
        # An empty index range would otherwise only fail later, with an
        # IndexError on the very first draw inside the data loader.
        if num_samples <= 0:
            raise ValueError(
                "num_samples must be a positive integer, got {!r}".format(num_samples))
        self.num_samples = num_samples

    def __iter__(self):
        return iter(self.loop())

    def __len__(self):
        # The sampler never terminates; report a very large stand-in length.
        return 2 ** 31

    def loop(self):
        """Yield indices endlessly, drawing a new permutation after each pass.

        NOTE: before every permutation after the first, the global NumPy RNG is
        reseeded from fresh entropy (``np.random.seed()`` with no argument).
        A seed set by the caller therefore only governs the first pass. The
        reseed also affects every other user of the global ``np.random`` state.
        """
        order = np.random.permutation(self.num_samples)
        while True:
            yield from order
            np.random.seed()
            order = np.random.permutation(self.num_samples)
class DS(data.Dataset):
    """Image dataset that pairs every image with a freshly generated random mask.

    Every file found recursively under ``root`` is treated as a sample; no
    extension filtering is done. Samples are listed in a deterministic, sorted
    order.

    Args:
        root: Directory walked recursively for sample files.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        RuntimeError: If no files are found under ``root``.
    """

    def __init__(self, root, transform=None):
        self.samples = []
        # Use a distinct loop name: reusing ``root`` here would make the error
        # below report the last walked subdirectory instead of the requested one.
        for dirpath, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                self.samples.append(os.path.join(dirpath, fname))
        if not self.samples:
            raise RuntimeError(f"Found 0 files in subfolders of: {root}")
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index``.

        The mask comes from ``random_mask()`` with its default arguments. It is
        therefore regenerated on every access and does not depend on the image
        size.
        """
        with Image.open(self.samples[index]) as img:
            # convert() returns a new, fully loaded image, so it remains usable
            # after the underlying file has been closed.
            sample = img.convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)
        mask = torch.from_numpy(self.random_mask())
        return sample, mask

    @staticmethod
    def random_mask(height=256, width=256, pad=50,
                    min_stroke=2, max_stroke=5,
                    min_vertex=2, max_vertex=12,
                    min_brush_width=7, max_brush_width=20,
                    min_lenght=10, max_length=50):
        """Generate a random free-form brush-stroke mask.

        The mask is built as follows:

        - Draw between ``min_stroke`` and ``max_stroke`` polylines.
        - Each polyline starts at a random point at least ``pad`` pixels from
          the left/top border.
        - Each polyline has between ``min_vertex`` and ``max_vertex`` segments.
          Each segment has a uniformly random direction and an integer length
          in ``[min_lenght, max_length]``.
        - Strokes use a brush width in ``[min_brush_width, max_brush_width]``.
        - The result is randomly flipped horizontally and/or vertically, each
          with probability 0.5.

        ``min_lenght`` keeps its historical spelling so existing keyword
        callers keep working.

        Returns:
            ``float32`` array of shape ``(1, height, width)``. Stroke pixels are
            1.0 and everything else is 0.0.
        """
        mask = np.zeros((height, width))
        max_angle = 2 * np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke + 1)
        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex + 1)
            brush_width = np.random.randint(min_brush_width, max_brush_width + 1)
            # OpenCV points are (x, y) = (column, row): x spans the width and
            # y spans the height of the (height, width) mask.
            start_x = np.random.randint(pad, width - pad)
            start_y = np.random.randint(pad, height - pad)
            for _ in range(num_vertex):
                # uniform(low, high): a single positional argument would be taken
                # as ``low`` with high=1.0, which does not cover the full circle.
                angle = np.random.uniform(0, max_angle)
                length = np.random.randint(min_lenght, max_length + 1)
                # Plain ints (truncated toward zero) are what cv2.line expects.
                end_x = int(start_x + length * np.sin(angle))
                end_y = int(start_y + length * np.cos(angle))
                end_x = max(0, min(end_x, width))
                end_y = max(0, min(end_y, height))
                cv2.line(mask, (start_x, start_y), (end_x, end_y), 1., brush_width)
                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        # astype() copies, so the flipped views come back as a fresh array
        # that torch.from_numpy can wrap.
        return mask.reshape((1,) + mask.shape).astype(np.float32)