# mit_sceneparsing_benchmark_loader.py
# Originally from meetps/pytorch-semseg
import os

import numpy as np
import scipy.misc as m  # NOTE: scipy.misc.imread/imresize were removed in newer SciPy (>= 1.2); requires an older SciPy with Pillow installed
import torch
from torch.utils import data

from ptsemseg.utils import recursive_glob


class MITSceneParsingBenchmarkLoader(data.Dataset):
"""MITSceneParsingBenchmarkLoader
http://sceneparsing.csail.mit.edu/
Data is derived from ADE20k, and can be downloaded from here:
http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
NOTE: this loader is not designed to work with the original ADE20k dataset;
for that you will need the ADE20kLoader
This class can also be extended to load data for places challenge:
https://github.com/CSAILVision/placeschallenge/tree/master/sceneparsing
"""
    def __init__(self, root, split="training", is_transform=False, img_size=512, augmentations=None, img_norm=True):
        """__init__

        :param root: path to the dataset root (the directory containing images/ and annotations/)
        :param split: which split to load, e.g. "training" or "validation"
        :param is_transform: if True, apply self.transform to every (image, label) pair
        :param img_size: output size, either an int (square) or a (height, width) tuple
        :param augmentations: optional callable applied jointly to image and label
        :param img_norm: if True, divide pixel values by 255.0 after mean subtraction
        """
        self.root = root
        self.split = split
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.img_norm = img_norm
        self.n_classes = 151  # class 0 is reserved for "other"
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.mean = np.array([104.00699, 116.66877, 122.67892])  # BGR channel means (Caffe convention)
        self.files = {}

        self.images_base = os.path.join(self.root, 'images', self.split)
        self.annotations_base = os.path.join(self.root, 'annotations', self.split)

        self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.jpg')

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """__getitem__

        :param index: index of the (image, label) pair to load
        """
        img_path = self.files[self.split][index].rstrip()
        lbl_path = os.path.join(self.annotations_base, os.path.basename(img_path)[:-4] + '.png')

        img = m.imread(img_path)
        img = np.array(img, dtype=np.uint8)

        lbl = m.imread(lbl_path)
        lbl = np.array(lbl, dtype=np.uint8)

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    def transform(self, img, lbl):
        """transform

        :param img: RGB image as an HxWx3 uint8 array
        :param lbl: label map as an HxW uint8 array
        """
        img = m.imresize(img, (self.img_size[0], self.img_size[1]))  # uint8 with RGB mode
        img = img[:, :, ::-1]  # RGB -> BGR, to match the Caffe-style channel means
        img = img.astype(np.float64)
        img -= self.mean
        if self.img_norm:
            # m.imresize returns values in [0, 255], so divide by 255.0
            img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        # Nearest-neighbour resize so that no new (interpolated) label values appear
        lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), 'nearest', mode='F')
        lbl = lbl.astype(int)

        if not np.all(classes == np.unique(lbl)):
            print("WARN: resizing labels yielded fewer classes")

        if not np.all(np.unique(lbl) < self.n_classes):
            raise ValueError("Segmentation map contained invalid class values")

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()

        return img, lbl
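

# Minimal usage sketch (not part of the original loader): it assumes the
# ADEChallengeData2016 archive has been extracted to the hypothetical path
# below, and iterates one batch to sanity-check the loader. The path, batch
# size, and img_size are illustrative assumptions, not fixed values.
if __name__ == '__main__':
    local_path = '/path/to/ADEChallengeData2016'  # hypothetical dataset root
    dst = MITSceneParsingBenchmarkLoader(local_path, split='training', is_transform=True, img_size=512)
    trainloader = data.DataLoader(dst, batch_size=4, shuffle=True)
    for imgs, labels in trainloader:
        # imgs: float32 tensor of shape (4, 3, 512, 512); labels: int64 tensor of shape (4, 512, 512)
        print(imgs.shape, labels.shape)
        break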