dataset.py
import os

import numpy as np
import torch
from PIL import Image
from torch.utils import data


def randomCrop(image, label, flow, depth):
    """Apply the same random, centred crop to the image, label, flow and depth maps."""
    border = 30
    image_width = image.size[0]
    image_height = image.size[1]
    label = Image.fromarray(label)
    # The crop window keeps all but at most `border` pixels along each axis.
    crop_win_width = np.random.randint(image_width - border, image_width)
    crop_win_height = np.random.randint(image_height - border, image_height)
    # (left, upper, right, lower) box centred in the frame; >> 1 halves the margins.
    random_region = (
        (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1,
        (image_width + crop_win_width) >> 1, (image_height + crop_win_height) >> 1)
    return image.crop(random_region), np.array(label.crop(random_region)), \
        flow.crop(random_region), depth.crop(random_region)
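
# Worked example of the crop geometry (illustrative numbers, not from the repo):
# for a 640x360 frame, crop_win_width is drawn from [610, 640) and
# crop_win_height from [330, 360); a 620x340 window gives the centred region
# (10, 10, 630, 350), so at most a 30-pixel band is trimmed per axis.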

class Dataset(data.Dataset):
    def __init__(self, datasets, mode='train', transform=None, return_size=True):
        self.return_size = return_size
        if not isinstance(datasets, list):
            datasets = [datasets]
        self.datas_id = []
        self.mode = mode
        for (i, dataset) in enumerate(datasets):
            if mode == 'train':
                data_dir = './dataset/train/{}'.format(dataset)
                imgset_path = data_dir + '/train.txt'
            else:
                data_dir = './dataset/test/{}'.format(dataset)
                imgset_path = data_dir + '/test.txt'
            # Each list-file line holds space-separated paths relative to data_dir:
            # image and ground truth for every dataset, plus flow and depth for
            # every dataset except the static-image DUTS-TR.
            with open(imgset_path) as imgset_file:
                for line in imgset_file:
                    data = {}
                    fields = line.strip("\n").split(" ")
                    img_path = fields[0]
                    gt_path = fields[1]
                    data['img_path'] = data_dir + img_path
                    data['gt_path'] = data_dir + gt_path
                    if dataset == 'DUTS-TR':
                        data['split'] = dataset
                    else:
                        data['flow_path'] = data_dir + fields[2]
                        data['depth_path'] = data_dir + fields[3]
                        data['split'] = img_path.split('/')[-3]
                    data['dataset'] = dataset
                    self.datas_id.append(data)
        self.transform = transform
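
    # Each self.datas_id entry ends up as a dict with keys 'img_path', 'gt_path',
    # 'split' and 'dataset' (plus 'flow_path' and 'depth_path' for every dataset
    # other than DUTS-TR); __getitem__ below consumes exactly these keys.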

    def __getitem__(self, item):
        assert os.path.exists(self.datas_id[item]['img_path']), (
            '{} does not exist'.format(self.datas_id[item]['img_path']))
        assert os.path.exists(self.datas_id[item]['gt_path']), (
            '{} does not exist'.format(self.datas_id[item]['gt_path']))
        if self.datas_id[item]['dataset'] != 'DUTS-TR':
            assert os.path.exists(self.datas_id[item]['depth_path']), (
                '{} does not exist'.format(self.datas_id[item]['depth_path']))
            assert os.path.exists(self.datas_id[item]['flow_path']), (
                '{} does not exist'.format(self.datas_id[item]['flow_path']))

        image = Image.open(self.datas_id[item]['img_path']).convert('RGB')
        label = Image.open(self.datas_id[item]['gt_path']).convert('L')
        label = np.array(label)

        if self.datas_id[item]['dataset'] == 'DUTS-TR':
            # DUTS-TR provides no flow or depth, so use all-zero placeholder images.
            flow = Image.fromarray(np.uint8(np.zeros((image.size[1], image.size[0], 3))))
            depth = Image.fromarray(np.uint8(np.zeros((image.size[1], image.size[0], 3))))
        else:
            flow = Image.open(self.datas_id[item]['flow_path']).convert('RGB')
            depth = Image.open(self.datas_id[item]['depth_path']).convert('RGB')

        # Ground-truth masks are stored as 0/255; normalise them to [0, 1].
        if label.max() > 0:
            label = label / 255

        w, h = image.size
        size = (h, w)
        sample = {'image': image, 'label': label, 'flow': flow, 'depth': depth}

        if self.mode == 'train':
            sample['image'], sample['label'], sample['flow'], sample['depth'] = randomCrop(
                sample['image'], sample['label'], sample['flow'], sample['depth'])

        if self.transform:
            sample = self.transform(sample)
        if self.return_size:
            sample['size'] = torch.tensor(size)
        if self.datas_id[item]['dataset'] == 'DUTS-TR':
            # Replace the placeholder modalities with fixed-size zero tensors.
            sample['flow'] = torch.zeros((3, 448, 448))
            sample['depth'] = torch.zeros((3, 448, 448))

        name = self.datas_id[item]['gt_path'].split('/')[-1]
        sample['dataset'] = self.datas_id[item]['dataset']
        sample['split'] = self.datas_id[item]['split']
        sample['name'] = name
        return sample

    def __len__(self):
        return len(self.datas_id)
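

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, assumed example of
# wrapping this Dataset in a DataLoader. The dataset name 'DAVIS' and the
# ToTensor448 transform are hypothetical placeholders; the repository
# presumably ships its own transform that resizes every modality to a fixed
# resolution and converts it to tensors before batching.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    class ToTensor448(object):
        """Hypothetical transform: resize all modalities to 448x448 tensors."""

        def __init__(self, size=(448, 448)):
            self.resize = transforms.Resize(size)
            self.to_tensor = transforms.ToTensor()

        def __call__(self, sample):
            for key in ('image', 'flow', 'depth'):
                sample[key] = self.to_tensor(self.resize(sample[key]))
            # The label arrives as a numpy array (float in [0, 1] or uint8 zeros);
            # convert it back to a PIL image so it can be resized the same way.
            label = Image.fromarray((np.asarray(sample['label']) * 255).astype(np.uint8))
            sample['label'] = self.to_tensor(self.resize(label))
            return sample

    # Assumes ./dataset/train/DAVIS/train.txt exists with the format parsed above.
    train_set = Dataset(datasets=['DAVIS'], mode='train', transform=ToTensor448())
    loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=0)

    batch = next(iter(loader))
    # Expected shapes: image/flow/depth -> (4, 3, 448, 448), label -> (4, 1, 448, 448).
    print(batch['image'].shape, batch['label'].shape, batch['size'])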