-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for custom classes #71
Changes from 11 commits
06d5562
38c617c
538819e
f0ed77f
993e290
cda80d2
5b1af68
f8ccb85
d88f8f9
ea29d50
7478a42
d0a2763
04c5547
36a309a
b13aec1
5d0fda4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,10 @@ class CustomDataset(Dataset): | |
ignore_index (int): The label index to be ignored. Default: 255 | ||
reduce_zero_label (bool): Whether to mark label zero as ignored. | ||
Default: False | ||
classes (str | Sequence[str], optional): Specify classes to load. | ||
If is None, ``cls.CLASSES`` will be used. Default: None. | ||
palette (Sequence[str], optional): Specify palette to load. | ||
If is None, ``cls.PALETTE`` will be used. Default: None. | ||
""" | ||
|
||
CLASSES = None | ||
|
@@ -74,7 +78,9 @@ def __init__(self, | |
data_root=None, | ||
test_mode=False, | ||
ignore_index=255, | ||
reduce_zero_label=False): | ||
reduce_zero_label=False, | ||
classes=None, | ||
palette=None): | ||
self.pipeline = Compose(pipeline) | ||
self.img_dir = img_dir | ||
self.img_suffix = img_suffix | ||
|
@@ -85,6 +91,8 @@ def __init__(self, | |
self.test_mode = test_mode | ||
self.ignore_index = ignore_index | ||
self.reduce_zero_label = reduce_zero_label | ||
self.CLASSES, self.PALETTE = self.get_classes_and_palette( | ||
classes, palette) | ||
|
||
# join paths if data_root is specified | ||
if self.data_root is not None: | ||
|
@@ -160,6 +168,8 @@ def get_ann_info(self, idx): | |
def pre_pipeline(self, results): | ||
"""Prepare results dict for pipeline.""" | ||
results['seg_fields'] = [] | ||
if self.custom_classes: | ||
results['label_map'] = self.label_map | ||
|
||
def __getitem__(self, idx): | ||
"""Get training/test data after pipeline. | ||
|
@@ -220,6 +230,10 @@ def get_gt_seg_maps(self): | |
for img_info in self.img_infos: | ||
gt_seg_map = mmcv.imread( | ||
img_info['ann']['seg_map'], flag='unchanged', backend='pillow') | ||
# modify if custom classes | ||
if hasattr(self, 'label_map'): | ||
for old_id, new_id in self.label_map.items(): | ||
gt_seg_map[gt_seg_map == old_id] = new_id | ||
if self.reduce_zero_label: | ||
# avoid using underflow conversion | ||
gt_seg_map[gt_seg_map == 0] = 255 | ||
|
@@ -230,6 +244,49 @@ def get_gt_seg_maps(self): | |
|
||
return gt_seg_maps | ||
|
||
def get_classes_and_palette(self, classes=None, palette=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
"""Get class names of current dataset. | ||
|
||
Args: | ||
classes (Sequence[str] | str | None): If classes is None, use | ||
default CLASSES defined by builtin dataset. If classes is a | ||
string, take it as a file name. The file contains the name of | ||
classes where each line contains one class name. If classes is | ||
a tuple or list, override the CLASSES defined by the dataset. | ||
palette (Sequence[str] | None): If palette is None, use | ||
default PALETTE defined by builtin dataset. If palette | ||
is a tuple or list, override the PALETTE defined by the | ||
dataset. | ||
""" | ||
if classes is None: | ||
self.custom_classes = False | ||
return self.CLASSES, self.PALETTE | ||
|
||
self.custom_classes = True | ||
if isinstance(classes, str): | ||
# take it as a file path | ||
class_names = mmcv.list_from_file(classes) | ||
elif isinstance(classes, (tuple, list)): | ||
class_names = classes | ||
else: | ||
raise ValueError(f'Unsupported type {type(classes)} of classes.') | ||
|
||
if self.CLASSES: | ||
if not set(classes).issubset(self.CLASSES): | ||
raise ValueError('classes is not a subset of CLASSES.') | ||
|
||
# dictionary, its keys are the old label ids and its values | ||
# are the new label ids. | ||
# used for changing pixel labels in load_annotations. | ||
self.label_map = {} | ||
for i, c in enumerate(self.CLASSES): | ||
if c not in class_names: | ||
self.label_map[i] = -1 | ||
else: | ||
self.label_map[i] = classes.index(c) | ||
|
||
return class_names, palette | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to return the subset of |
||
|
||
def evaluate(self, results, metric='mIoU', logger=None, **kwargs): | ||
"""Evaluate the dataset. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,9 @@ | |
import pytest | ||
|
||
from mmseg.core.evaluation import get_classes, get_palette | ||
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset, | ||
CustomDataset, PascalVOCDataset, RepeatDataset) | ||
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, | ||
ConcatDataset, CustomDataset, PascalVOCDataset, | ||
RepeatDataset) | ||
|
||
|
||
def test_classes(): | ||
|
@@ -171,3 +172,62 @@ def test_custom_dataset(): | |
assert 'mIoU' in eval_results | ||
assert 'mAcc' in eval_results | ||
assert 'aAcc' in eval_results | ||
|
||
|
||
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) | ||
@patch('mmseg.datasets.CustomDataset.__getitem__', | ||
MagicMock(side_effect=lambda idx: idx)) | ||
@pytest.mark.parametrize('dataset, classes', [ | ||
('ADE20KDataset', ('wall', 'building')), | ||
('CityscapesDataset', ('road', 'sidewalk')), | ||
('CustomDataset', ('bus', 'car')), | ||
('PascalVOCDataset', ('aeroplane', 'bicycle')), | ||
]) | ||
def test_custom_classes_override_default(dataset, classes): | ||
|
||
dataset_class = DATASETS.get(dataset) | ||
|
||
original_classes = dataset_class.CLASSES | ||
|
||
# Test setting classes as a tuple | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=classes, | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES != original_classes | ||
assert custom_dataset.CLASSES == classes | ||
|
||
# Test setting classes as a list | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=list(classes), | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES != original_classes | ||
assert custom_dataset.CLASSES == list(classes) | ||
|
||
# Test overriding not a subset | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=[classes[0]], | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES != original_classes | ||
assert custom_dataset.CLASSES == [classes[0]] | ||
|
||
# Test default behavior | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=None, | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES == original_classes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may add a test case for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should I add it in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may add it in |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
import copy | ||
import os | ||
import os.path as osp | ||
|
||
import numpy as np | ||
from mmcv.image.io import imwrite | ||
|
||
from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile | ||
|
||
|
@@ -98,3 +100,99 @@ def test_load_seg(self): | |
# this image is saved by PIL | ||
assert results['gt_semantic_seg'].shape == (288, 512) | ||
assert results['gt_semantic_seg'].dtype == np.uint8 | ||
|
||
def test_load_seg_custom_classes(self): | ||
|
||
test_img = np.random.rand(10, 10) | ||
test_gt = np.zeros_like(test_img) | ||
test_gt[2:4, 2:4] = 1 | ||
test_gt[2:4, 6:8] = 2 | ||
test_gt[6:8, 2:4] = 3 | ||
test_gt[6:8, 6:8] = 4 | ||
|
||
img_path = osp.join(osp.dirname(__file__), 'img.jpg') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've tried using |
||
gt_path = osp.join(osp.dirname(__file__), 'gt.png') | ||
|
||
imwrite(test_img, img_path) | ||
imwrite(test_gt, gt_path) | ||
|
||
# test only train with label with id 3 | ||
results = dict( | ||
img_info=dict(filename=img_path), | ||
ann_info=dict(seg_map=gt_path), | ||
label_map={ | ||
0: 0, | ||
1: 0, | ||
2: 0, | ||
3: 1, | ||
4: 0 | ||
}, | ||
seg_fields=[]) | ||
|
||
load_imgs = LoadImageFromFile() | ||
results = load_imgs(copy.deepcopy(results)) | ||
|
||
load_anns = LoadAnnotations() | ||
results = load_anns(copy.deepcopy(results)) | ||
|
||
gt_array = results['gt_semantic_seg'] | ||
|
||
true_mask = np.zeros_like(gt_array) | ||
true_mask[6:8, 2:4] = 1 | ||
|
||
assert results['seg_fields'] == ['gt_semantic_seg'] | ||
assert gt_array.shape == (10, 10) | ||
assert gt_array.dtype == np.uint8 | ||
np.testing.assert_array_equal(gt_array, true_mask) | ||
|
||
# test only train with label with id 4 and 3 | ||
results = dict( | ||
img_info=dict(filename=img_path), | ||
ann_info=dict(seg_map=gt_path), | ||
label_map={ | ||
0: 0, | ||
1: 0, | ||
2: 0, | ||
3: 2, | ||
4: 1 | ||
}, | ||
seg_fields=[]) | ||
|
||
load_imgs = LoadImageFromFile() | ||
results = load_imgs(copy.deepcopy(results)) | ||
|
||
load_anns = LoadAnnotations() | ||
results = load_anns(copy.deepcopy(results)) | ||
|
||
gt_array = results['gt_semantic_seg'] | ||
|
||
true_mask = np.zeros_like(gt_array) | ||
true_mask[6:8, 2:4] = 2 | ||
true_mask[6:8, 6:8] = 1 | ||
|
||
assert results['seg_fields'] == ['gt_semantic_seg'] | ||
assert gt_array.shape == (10, 10) | ||
assert gt_array.dtype == np.uint8 | ||
np.testing.assert_array_equal(gt_array, true_mask) | ||
|
||
# test no custom classes | ||
results = dict( | ||
img_info=dict(filename=img_path), | ||
ann_info=dict(seg_map=gt_path), | ||
seg_fields=[]) | ||
|
||
load_imgs = LoadImageFromFile() | ||
results = load_imgs(copy.deepcopy(results)) | ||
|
||
load_anns = LoadAnnotations() | ||
results = load_anns(copy.deepcopy(results)) | ||
|
||
gt_array = results['gt_semantic_seg'] | ||
|
||
assert results['seg_fields'] == ['gt_semantic_seg'] | ||
assert gt_array.shape == (10, 10) | ||
assert gt_array.dtype == np.uint8 | ||
np.testing.assert_array_equal(gt_array, test_gt) | ||
|
||
os.remove(img_path) | ||
os.remove(gt_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
palette
arg is not needed.