forked from hughplay/DFNet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata.py
69 lines (54 loc) · 2.38 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
class DS(Dataset):
    """Image dataset pairing every file under ``root`` with a random stroke mask.

    Every file found recursively under ``root`` is treated as an image sample
    (no extension filtering is performed). Each ``__getitem__`` call draws a
    fresh mask from :meth:`random_mask`.

    Args:
        root: Directory scanned recursively for sample files.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.

    Raises:
        RuntimeError: If no files are found under ``root``.
    """

    def __init__(self, root, transform=None):
        self.samples = []
        # A dedicated loop variable keeps ``root`` bound to the caller's
        # directory, so the error message below names the right path.
        # Sorting the walk output gives a deterministic sample order.
        for dirpath, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                self.samples.append(os.path.join(dirpath, fname))
        if not self.samples:
            raise RuntimeError("Found 0 files in subfolders of: " + str(root))
        self.transform = transform

    def __len__(self):
        """Return the number of sample files found under ``root``."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load sample ``index`` as RGB and pair it with a fresh random mask.

        Returns:
            ``(sample, mask)``: the (optionally transformed) RGB image and a
            float32 tensor of shape ``(1, 256, 256)`` that is 1 off the drawn
            strokes and 0 on them.
        """
        sample_path = self.samples[index]
        # The context manager closes the file handle deterministically;
        # ``convert`` returns an independent, fully loaded image.
        with Image.open(sample_path) as img:
            sample = img.convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        # NOTE(review): the mask always uses random_mask's default 256x256 size
        # and does not follow the transformed sample's size — confirm the
        # transform yields 256x256 images.
        # NOTE(review): random_mask uses NumPy's global RNG; with a multi-worker
        # DataLoader, workers may need distinct seeds (worker_init_fn) to
        # avoid duplicated masks — verify.
        mask = torch.from_numpy(DS.random_mask())
        return sample, mask

    @staticmethod
    def random_mask(height=256, width=256,
                    min_stroke=1, max_stroke=4,
                    min_vertex=1, max_vertex=12,
                    min_brush_width_divisor=16, max_brush_width_divisor=10):
        """Draw a random free-form stroke mask.

        Draws ``min_stroke``..``max_stroke`` poly-line strokes. Each stroke has
        ``min_vertex``..``max_vertex`` segments with random direction, length
        and brush width. The result is then randomly flipped horizontally
        and/or vertically.

        Args:
            height, width: Mask size in pixels.
            min_stroke, max_stroke: Inclusive range for the number of strokes.
            min_vertex, max_vertex: Inclusive range for segments per stroke.
            min_brush_width_divisor, max_brush_width_divisor: The brush width
                is drawn from ``[height // min_div, height // max_div]``. It is
                clamped to at least 1 px because OpenCV rejects a zero
                thickness.

        Returns:
            float32 array of shape ``(1, height, width)``: 1 off the strokes,
            0 on them.
        """
        mask = np.ones((height, width))
        min_brush_width = max(1, height // min_brush_width_divisor)
        max_brush_width = max(min_brush_width, height // max_brush_width_divisor)
        max_angle = 2 * np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke + 1)
        average_length = np.sqrt(height * height + width * width) / 8
        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex + 1)
            start_x = np.random.randint(width)
            start_y = np.random.randint(height)
            for _ in range(num_vertex):
                # uniform's first positional argument is ``low``; pass both
                # bounds so the direction covers the full [0, 2*pi) range.
                angle = np.random.uniform(0, max_angle)
                length = np.clip(np.random.normal(average_length, average_length // 2),
                                 0, 2 * average_length)
                brush_width = np.random.randint(min_brush_width, max_brush_width + 1)
                # int() truncates toward zero, like the previous astype(int32),
                # but yields plain ints that OpenCV accepts as point coordinates.
                end_x = int(start_x + length * np.sin(angle))
                end_y = int(start_y + length * np.cos(angle))
                # OpenCV points are (x, y) = (column, row): x spans ``width``,
                # y spans ``height``.
                cv2.line(mask, (start_x, start_y), (end_x, end_y), 0., brush_width)
                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        # Flips return negative-stride views; torch.from_numpy rejects those,
        # so return a C-contiguous float32 copy.
        return np.ascontiguousarray(mask.reshape((1,) + mask.shape), dtype=np.float32)