-
Notifications
You must be signed in to change notification settings - Fork 41
/
dataset.py
104 lines (89 loc) · 2.79 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python
import os
import numpy as np
import scipy.misc
import chainer
import utils
class DatasetMixin(chainer.dataset.DatasetMixin):
    """Image/label conversion helpers shared by segmentation datasets.

    Subclasses are expected to override ``label_names`` and ``mean_bgr``;
    both default to ``None`` here and the helpers below read them through
    ``self``.
    """

    # Sequence of class names. Only its length is used in this class, to
    # size the label colormap.
    label_names = None
    # Per-channel mean in BGR order; subtracted by img_to_datum and added
    # back by datum_to_img.
    mean_bgr = None

    def label_rgb_to_32sc1(self, label_rgb):
        """Convert a color-coded label image into an int32 class-index map.

        Args:
            label_rgb: uint8 array whose last axis holds each pixel's color;
                its first two axes give the shape of the result.

        Returns:
            int32 array of shape ``label_rgb.shape[:2]``. Each pixel holds
            the index of the colormap entry equal to its color, or -1 when
            no entry matches.

        Raises:
            AssertionError: if ``label_rgb`` is not of dtype uint8.
        """
        assert label_rgb.dtype == np.uint8
        # Start from "unlabeled" (-1) so unmatched colors stay marked.
        label = np.full(label_rgb.shape[:2], -1, dtype=np.int32)
        # NOTE(review): scaling by 255 presumes utils.labelcolormap returns
        # colors in [0, 1] -- confirm against utils.
        cmap = utils.labelcolormap(len(self.label_names))
        cmap = (cmap * 255).astype(np.uint8)
        for label_id, rgb in enumerate(cmap):
            mask = np.all(label_rgb == rgb, axis=-1)
            label[mask] = label_id
        return label

    def img_to_datum(self, img):
        """Convert an (H, W, 3) RGB image into a mean-subtracted datum.

        Args:
            img: array indexed as (row, column, channel) with channels in
                RGB order. It is not modified.

        Returns:
            float32 array indexed as (channel, row, column) with channels
            in BGR order and ``self.mean_bgr`` subtracted.
        """
        # astype() already returns a fresh array, so img is left untouched.
        datum = img.astype(np.float32)
        datum = datum[:, :, ::-1]  # RGB -> BGR
        datum -= self.mean_bgr
        return datum.transpose((2, 0, 1))

    def datum_to_img(self, datum):
        """Invert ``img_to_datum``: rebuild a uint8 RGB image from a datum.

        Args:
            datum: array indexed as (channel, row, column) with channels in
                BGR order and the mean subtracted. It is not modified.

        Returns:
            uint8 array indexed as (row, column, channel) in RGB order.
            Values are rounded to the nearest integer and clipped to
            [0, 255] before the cast.
        """
        # Out-of-place add: creates a new array, leaving datum unchanged.
        bgr = datum.transpose((1, 2, 0)) + self.mean_bgr
        rgb = bgr[:, :, ::-1]  # BGR -> RGB
        # A bare astype(np.uint8) truncates toward zero and wraps values
        # outside [0, 255]; round and saturate instead.
        return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)
class PascalVOC2012Dataset(DatasetMixin):
    """PASCAL VOC 2012 segmentation dataset yielding (datum, label) pairs.

    File paths are collected eagerly in ``__init__``; images and labels are
    read from disk lazily in ``get_example``.
    """

    # 21 class names (background + 20 object classes); the array's length
    # sizes the label colormap used by label_rgb_to_32sc1.
    label_names = np.array([
        'background',
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'potted plant',
        'sheep',
        'sofa',
        'train',
        'tv/monitor',
    ])
    # Per-channel mean in BGR order, subtracted from every image.
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])

    def __init__(self, data_type):
        """Collect image/label file paths for one image-set split.

        Args:
            data_type: stem of the list file under
                ``ImageSets/Segmentation`` (e.g. ``'train'`` or ``'val'``
                in the standard VOC layout).

        Raises:
            IOError: if the image-set list file cannot be opened.
        """
        # get ids for the data_type
        dataset_dir = chainer.dataset.get_dataset_directory(
            'pascal/VOCdevkit/VOC2012')
        imgsets_file = os.path.join(
            dataset_dir,
            'ImageSets/Segmentation/{}.txt'.format(data_type))
        self.files = []
        # Context manager guarantees the list file is closed, even on error.
        with open(imgsets_file) as f:
            for line in f:
                data_id = line.strip()
                if not data_id:
                    # Blank (e.g. trailing) lines would otherwise produce
                    # bogus paths such as 'JPEGImages/.jpg'.
                    continue
                img_file = os.path.join(
                    dataset_dir, 'JPEGImages/{}.jpg'.format(data_id))
                label_rgb_file = os.path.join(
                    dataset_dir, 'SegmentationClass/{}.png'.format(data_id))
                self.files.append({
                    'img': img_file,
                    'label_rgb': label_rgb_file,
                })

    def __len__(self):
        """Return the number of examples in this split."""
        return len(self.files)

    def get_example(self, i):
        """Load the i-th example from disk.

        Args:
            i: index into ``self.files``.

        Returns:
            Tuple ``(datum, label)``: the mean-subtracted image produced by
            ``img_to_datum`` and the int32 class-index map produced by
            ``label_rgb_to_32sc1``.
        """
        data_file = self.files[i]
        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this
        # needs an older SciPy or a port to another image reader.
        # load image
        img = scipy.misc.imread(data_file['img'], mode='RGB')
        datum = self.img_to_datum(img)
        # load label
        label_rgb = scipy.misc.imread(data_file['label_rgb'], mode='RGB')
        label = self.label_rgb_to_32sc1(label_rgb)
        return datum, label