Skip to content

Commit 7c6fa48

Browse files
igonroxvjiarui
andauthoredSep 16, 2020
Add support for custom classes (open-mmlab#71)
* Support for custom classes * Fix test * Fix pre-commit * Add pipeline logic for custom classes * Fix minor issues, fix test * Fix issues from PR review * Fix tests * Remove palette as str * Rename old_to_new_ids to label_map * Test for load_anns * Remove get_palette function * fixed temp * Add subset of palette, remove palette as arg * minor update Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
1 parent e2371a1 commit 7c6fa48

File tree

4 files changed

+233
-3
lines changed

4 files changed

+233
-3
lines changed
 

‎mmseg/datasets/custom.py

+69-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class CustomDataset(Dataset):
5858
ignore_index (int): The label index to be ignored. Default: 255
5959
reduce_zero_label (bool): Whether to mark label zero as ignored.
6060
Default: False
61+
classes (str | Sequence[str], optional): Specify classes to load.
62+
If is None, ``cls.CLASSES`` will be used. Default: None.
6163
"""
6264

6365
CLASSES = None
@@ -74,7 +76,8 @@ def __init__(self,
7476
data_root=None,
7577
test_mode=False,
7678
ignore_index=255,
77-
reduce_zero_label=False):
79+
reduce_zero_label=False,
80+
classes=None):
7881
self.pipeline = Compose(pipeline)
7982
self.img_dir = img_dir
8083
self.img_suffix = img_suffix
@@ -85,6 +88,8 @@ def __init__(self,
8588
self.test_mode = test_mode
8689
self.ignore_index = ignore_index
8790
self.reduce_zero_label = reduce_zero_label
91+
self.label_map = None
92+
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
8893

8994
# join paths if data_root is specified
9095
if self.data_root is not None:
@@ -160,6 +165,8 @@ def get_ann_info(self, idx):
160165
def pre_pipeline(self, results):
161166
"""Prepare results dict for pipeline."""
162167
results['seg_fields'] = []
168+
if self.custom_classes:
169+
results['label_map'] = self.label_map
163170

164171
def __getitem__(self, idx):
165172
"""Get training/test data after pipeline.
@@ -220,6 +227,10 @@ def get_gt_seg_maps(self):
220227
for img_info in self.img_infos:
221228
gt_seg_map = mmcv.imread(
222229
img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
230+
# modify if custom classes
231+
if self.label_map is not None:
232+
for old_id, new_id in self.label_map.items():
233+
gt_seg_map[gt_seg_map == old_id] = new_id
223234
if self.reduce_zero_label:
224235
# avoid using underflow conversion
225236
gt_seg_map[gt_seg_map == 0] = 255
@@ -230,6 +241,63 @@ def get_gt_seg_maps(self):
230241

231242
return gt_seg_maps
232243

244+
def get_classes_and_palette(self, classes=None):
245+
"""Get class names of current dataset.
246+
247+
Args:
248+
classes (Sequence[str] | str | None): If classes is None, use
249+
default CLASSES defined by builtin dataset. If classes is a
250+
string, take it as a file name. The file contains the name of
251+
classes where each line contains one class name. If classes is
252+
a tuple or list, override the CLASSES defined by the dataset.
253+
"""
254+
if classes is None:
255+
self.custom_classes = False
256+
return self.CLASSES, self.PALETTE
257+
258+
self.custom_classes = True
259+
if isinstance(classes, str):
260+
# take it as a file path
261+
class_names = mmcv.list_from_file(classes)
262+
elif isinstance(classes, (tuple, list)):
263+
class_names = classes
264+
else:
265+
raise ValueError(f'Unsupported type {type(classes)} of classes.')
266+
267+
if self.CLASSES:
268+
if not set(classes).issubset(self.CLASSES):
269+
raise ValueError('classes is not a subset of CLASSES.')
270+
271+
# dictionary, its keys are the old label ids and its values
272+
# are the new label ids.
273+
# used for changing pixel labels in load_annotations.
274+
self.label_map = {}
275+
for i, c in enumerate(self.CLASSES):
276+
if c not in class_names:
277+
self.label_map[i] = -1
278+
else:
279+
self.label_map[i] = classes.index(c)
280+
281+
palette = self.get_palette_for_custom_classes()
282+
283+
return class_names, palette
284+
285+
def get_palette_for_custom_classes(self):
286+
287+
if self.label_map is not None:
288+
# return subset of palette
289+
palette = []
290+
for old_id, new_id in sorted(
291+
self.label_map.items(), key=lambda x: x[1]):
292+
if new_id != -1:
293+
palette.append(self.PALETTE[old_id])
294+
palette = type(self.PALETTE)(palette)
295+
296+
else:
297+
palette = self.PALETTE
298+
299+
return palette
300+
233301
def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
234302
"""Evaluate the dataset.
235303

‎mmseg/datasets/pipelines/loading.py

+4
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ def __call__(self, results):
132132
gt_semantic_seg = mmcv.imfrombytes(
133133
img_bytes, flag='unchanged',
134134
backend=self.imdecode_backend).squeeze().astype(np.uint8)
135+
# modify if custom classes
136+
if results.get('label_map', None) is not None:
137+
for old_id, new_id in results['label_map'].items():
138+
gt_semantic_seg[gt_semantic_seg == old_id] = new_id
135139
# reduce zero_label
136140
if self.reduce_zero_label:
137141
# avoid using underflow conversion

‎tests/test_data/test_dataset.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import pytest
66

77
from mmseg.core.evaluation import get_classes, get_palette
8-
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset,
9-
CustomDataset, PascalVOCDataset, RepeatDataset)
8+
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
9+
ConcatDataset, CustomDataset, PascalVOCDataset,
10+
RepeatDataset)
1011

1112

1213
def test_classes():
@@ -171,3 +172,62 @@ def test_custom_dataset():
171172
assert 'mIoU' in eval_results
172173
assert 'mAcc' in eval_results
173174
assert 'aAcc' in eval_results
175+
176+
177+
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
178+
@patch('mmseg.datasets.CustomDataset.__getitem__',
179+
MagicMock(side_effect=lambda idx: idx))
180+
@pytest.mark.parametrize('dataset, classes', [
181+
('ADE20KDataset', ('wall', 'building')),
182+
('CityscapesDataset', ('road', 'sidewalk')),
183+
('CustomDataset', ('bus', 'car')),
184+
('PascalVOCDataset', ('aeroplane', 'bicycle')),
185+
])
186+
def test_custom_classes_override_default(dataset, classes):
187+
188+
dataset_class = DATASETS.get(dataset)
189+
190+
original_classes = dataset_class.CLASSES
191+
192+
# Test setting classes as a tuple
193+
custom_dataset = dataset_class(
194+
pipeline=[],
195+
img_dir=MagicMock(),
196+
split=MagicMock(),
197+
classes=classes,
198+
test_mode=True)
199+
200+
assert custom_dataset.CLASSES != original_classes
201+
assert custom_dataset.CLASSES == classes
202+
203+
# Test setting classes as a list
204+
custom_dataset = dataset_class(
205+
pipeline=[],
206+
img_dir=MagicMock(),
207+
split=MagicMock(),
208+
classes=list(classes),
209+
test_mode=True)
210+
211+
assert custom_dataset.CLASSES != original_classes
212+
assert custom_dataset.CLASSES == list(classes)
213+
214+
# Test overriding not a subset
215+
custom_dataset = dataset_class(
216+
pipeline=[],
217+
img_dir=MagicMock(),
218+
split=MagicMock(),
219+
classes=[classes[0]],
220+
test_mode=True)
221+
222+
assert custom_dataset.CLASSES != original_classes
223+
assert custom_dataset.CLASSES == [classes[0]]
224+
225+
# Test default behavior
226+
custom_dataset = dataset_class(
227+
pipeline=[],
228+
img_dir=MagicMock(),
229+
split=MagicMock(),
230+
classes=None,
231+
test_mode=True)
232+
233+
assert custom_dataset.CLASSES == original_classes

‎tests/test_data/test_loading.py

+98
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import copy
22
import os.path as osp
3+
import tempfile
34

5+
import mmcv
46
import numpy as np
57

68
from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile
@@ -98,3 +100,99 @@ def test_load_seg(self):
98100
# this image is saved by PIL
99101
assert results['gt_semantic_seg'].shape == (288, 512)
100102
assert results['gt_semantic_seg'].dtype == np.uint8
103+
104+
def test_load_seg_custom_classes(self):
105+
106+
test_img = np.random.rand(10, 10)
107+
test_gt = np.zeros_like(test_img)
108+
test_gt[2:4, 2:4] = 1
109+
test_gt[2:4, 6:8] = 2
110+
test_gt[6:8, 2:4] = 3
111+
test_gt[6:8, 6:8] = 4
112+
113+
tmp_dir = tempfile.TemporaryDirectory()
114+
img_path = osp.join(tmp_dir.name, 'img.jpg')
115+
gt_path = osp.join(tmp_dir.name, 'gt.png')
116+
117+
mmcv.imwrite(test_img, img_path)
118+
mmcv.imwrite(test_gt, gt_path)
119+
120+
# test only train with label with id 3
121+
results = dict(
122+
img_info=dict(filename=img_path),
123+
ann_info=dict(seg_map=gt_path),
124+
label_map={
125+
0: 0,
126+
1: 0,
127+
2: 0,
128+
3: 1,
129+
4: 0
130+
},
131+
seg_fields=[])
132+
133+
load_imgs = LoadImageFromFile()
134+
results = load_imgs(copy.deepcopy(results))
135+
136+
load_anns = LoadAnnotations()
137+
results = load_anns(copy.deepcopy(results))
138+
139+
gt_array = results['gt_semantic_seg']
140+
141+
true_mask = np.zeros_like(gt_array)
142+
true_mask[6:8, 2:4] = 1
143+
144+
assert results['seg_fields'] == ['gt_semantic_seg']
145+
assert gt_array.shape == (10, 10)
146+
assert gt_array.dtype == np.uint8
147+
np.testing.assert_array_equal(gt_array, true_mask)
148+
149+
# test only train with label with id 4 and 3
150+
results = dict(
151+
img_info=dict(filename=img_path),
152+
ann_info=dict(seg_map=gt_path),
153+
label_map={
154+
0: 0,
155+
1: 0,
156+
2: 0,
157+
3: 2,
158+
4: 1
159+
},
160+
seg_fields=[])
161+
162+
load_imgs = LoadImageFromFile()
163+
results = load_imgs(copy.deepcopy(results))
164+
165+
load_anns = LoadAnnotations()
166+
results = load_anns(copy.deepcopy(results))
167+
168+
gt_array = results['gt_semantic_seg']
169+
170+
true_mask = np.zeros_like(gt_array)
171+
true_mask[6:8, 2:4] = 2
172+
true_mask[6:8, 6:8] = 1
173+
174+
assert results['seg_fields'] == ['gt_semantic_seg']
175+
assert gt_array.shape == (10, 10)
176+
assert gt_array.dtype == np.uint8
177+
np.testing.assert_array_equal(gt_array, true_mask)
178+
179+
# test no custom classes
180+
results = dict(
181+
img_info=dict(filename=img_path),
182+
ann_info=dict(seg_map=gt_path),
183+
seg_fields=[])
184+
185+
load_imgs = LoadImageFromFile()
186+
results = load_imgs(copy.deepcopy(results))
187+
188+
load_anns = LoadAnnotations()
189+
results = load_anns(copy.deepcopy(results))
190+
191+
gt_array = results['gt_semantic_seg']
192+
193+
assert results['seg_fields'] == ['gt_semantic_seg']
194+
assert gt_array.shape == (10, 10)
195+
assert gt_array.dtype == np.uint8
196+
np.testing.assert_array_equal(gt_array, test_gt)
197+
198+
tmp_dir.cleanup()

0 commit comments

Comments
 (0)
Please sign in to comment.