Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom classes #71

Merged
merged 16 commits into from
Sep 16, 2020
59 changes: 58 additions & 1 deletion mmseg/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class CustomDataset(Dataset):
ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False
classes (str | Sequence[str], optional): Specify classes to load.
If it is None, ``cls.CLASSES`` will be used. Default: None.
palette (Sequence[str], optional): Specify palette to load.
If it is None, ``cls.PALETTE`` will be used. Default: None.
"""

CLASSES = None
Expand All @@ -74,7 +78,9 @@ def __init__(self,
data_root=None,
test_mode=False,
ignore_index=255,
reduce_zero_label=False):
reduce_zero_label=False,
classes=None,
palette=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

palette arg is not needed.

self.pipeline = Compose(pipeline)
self.img_dir = img_dir
self.img_suffix = img_suffix
Expand All @@ -85,6 +91,8 @@ def __init__(self,
self.test_mode = test_mode
self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)

# join paths if data_root is specified
if self.data_root is not None:
Expand Down Expand Up @@ -160,6 +168,8 @@ def get_ann_info(self, idx):
def pre_pipeline(self, results):
    """Prepare the ``results`` dict before it enters the pipeline.

    Initializes ``seg_fields`` and, when a custom class subset is active,
    attaches the old-id -> new-id mapping so ``LoadAnnotations`` can remap
    ground-truth pixel labels.
    """
    results['seg_fields'] = []
    if not self.custom_classes:
        return
    results['label_map'] = self.label_map

def __getitem__(self, idx):
"""Get training/test data after pipeline.
Expand Down Expand Up @@ -220,6 +230,10 @@ def get_gt_seg_maps(self):
for img_info in self.img_infos:
gt_seg_map = mmcv.imread(
img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
# modify if custom classes
if hasattr(self, 'label_map'):
for old_id, new_id in self.label_map.items():
gt_seg_map[gt_seg_map == old_id] = new_id
if self.reduce_zero_label:
# avoid using underflow conversion
gt_seg_map[gt_seg_map == 0] = 255
Expand All @@ -230,6 +244,49 @@ def get_gt_seg_maps(self):

return gt_seg_maps

def get_classes_and_palette(self, classes=None, palette=None):
    """Get class names (and the matching palette) of the current dataset.

    Args:
        classes (Sequence[str] | str | None): If classes is None, use
            default CLASSES defined by builtin dataset. If classes is a
            string, take it as a file name. The file contains the name of
            classes where each line contains one class name. If classes is
            a tuple or list, override the CLASSES defined by the dataset.
        palette (Sequence[str] | None): If palette is None, use
            default PALETTE defined by builtin dataset. If palette
            is a tuple or list, override the PALETTE defined by the
            dataset.

    Returns:
        tuple: ``(class_names, palette)`` actually used by the dataset.

    Raises:
        ValueError: If ``classes`` is neither a file path nor a sequence,
            or is not a subset of the dataset's default ``CLASSES``.
    """
    if classes is None:
        self.custom_classes = False
        return self.CLASSES, self.PALETTE

    self.custom_classes = True
    if isinstance(classes, str):
        # take it as a file path
        class_names = mmcv.list_from_file(classes)
    elif isinstance(classes, (tuple, list)):
        class_names = classes
    else:
        raise ValueError(f'Unsupported type {type(classes)} of classes.')

    if self.CLASSES:
        # Compare against the loaded ``class_names`` (NOT ``classes``,
        # which may still be a file-path string) so the subset check and
        # the new label ids are computed from the actual class names.
        if not set(class_names).issubset(self.CLASSES):
            raise ValueError('classes is not a subset of CLASSES.')

        # dictionary, its keys are the old label ids and its values
        # are the new label ids.
        # used for changing pixel labels in load_annotations.
        self.label_map = {}
        for i, c in enumerate(self.CLASSES):
            if c not in class_names:
                self.label_map[i] = -1
            else:
                self.label_map[i] = class_names.index(c)

        # Return the matching subset of the default palette (ordered by
        # the new label ids) when no explicit palette override is given.
        if palette is None and self.PALETTE is not None:
            palette = [
                self.PALETTE[self.CLASSES.index(c)] for c in class_names
            ]

    return class_names, palette
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to return the subset of PALETTE since we are using a subset of CLASSES.


def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
"""Evaluate the dataset.

Expand Down
4 changes: 4 additions & 0 deletions mmseg/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def __call__(self, results):
gt_semantic_seg = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze().astype(np.uint8)
# modify if custom classes
if results.get('label_map', None) is not None:
for old_id, new_id in results['label_map'].items():
gt_semantic_seg[gt_semantic_seg == old_id] = new_id
# reduce zero_label
if self.reduce_zero_label:
# avoid using underflow conversion
Expand Down
64 changes: 62 additions & 2 deletions tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pytest

from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset,
CustomDataset, PascalVOCDataset, RepeatDataset)
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
ConcatDataset, CustomDataset, PascalVOCDataset,
RepeatDataset)


def test_classes():
Expand Down Expand Up @@ -171,3 +172,62 @@ def test_custom_dataset():
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    """``classes`` passed to the constructor must override ``cls.CLASSES``."""
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.CLASSES

    def build(custom_classes):
        # Construct a dataset with mocked paths; only the ``classes``
        # handling is under test here.
        return dataset_class(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=custom_classes,
            test_mode=True)

    # A tuple, a list, and a single-element subset all override the default.
    for custom in (classes, list(classes), [classes[0]]):
        custom_dataset = build(custom)
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == custom

    # Passing None keeps the builtin default.
    assert build(None).CLASSES == original_classes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may add a test case for LoadAnnotation pipeline when there is custom classes.
To achieve this, we may generate a random ground truth segmentation map.

Copy link
Contributor Author

@igonro igonro Sep 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I add it in test_loading.py or in this same file? Also I'm not very sure what type of test I should add.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may add it in test_loading.py.
We should check if the loaded annotation is of custom classes as desired.

98 changes: 98 additions & 0 deletions tests/test_data/test_loading.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
import os
import os.path as osp

import numpy as np
from mmcv.image.io import imwrite

from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile

Expand Down Expand Up @@ -98,3 +100,99 @@ def test_load_seg(self):
# this image is saved by PIL
assert results['gt_semantic_seg'].shape == (288, 512)
assert results['gt_semantic_seg'].dtype == np.uint8

def test_load_seg_custom_classes(self):
    """LoadAnnotations must remap ground-truth ids via ``label_map``."""

    test_img = np.random.rand(10, 10)
    test_gt = np.zeros_like(test_img)
    test_gt[2:4, 2:4] = 1
    test_gt[2:4, 6:8] = 2
    test_gt[6:8, 2:4] = 3
    test_gt[6:8, 6:8] = 4

    # NOTE(review): pytest's tmpdir/tmp_path reportedly breaks mmcv's
    # imwrite here, so fixtures are written next to this file and always
    # removed in the ``finally`` block below (the original cleanup leaked
    # the files whenever an assertion failed).
    img_path = osp.join(osp.dirname(__file__), 'img.jpg')
    gt_path = osp.join(osp.dirname(__file__), 'gt.png')

    def load(label_map=None):
        # Run the loading pipeline, optionally with a custom label mapping.
        results = dict(
            img_info=dict(filename=img_path),
            ann_info=dict(seg_map=gt_path),
            seg_fields=[])
        if label_map is not None:
            results['label_map'] = label_map
        results = LoadImageFromFile()(copy.deepcopy(results))
        return LoadAnnotations()(copy.deepcopy(results))

    try:
        imwrite(test_img, img_path)
        imwrite(test_gt, gt_path)

        # test only train with label with id 3
        results = load({0: 0, 1: 0, 2: 0, 3: 1, 4: 0})
        gt_array = results['gt_semantic_seg']
        true_mask = np.zeros_like(gt_array)
        true_mask[6:8, 2:4] = 1
        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, true_mask)

        # test only train with label with id 4 and 3
        results = load({0: 0, 1: 0, 2: 0, 3: 2, 4: 1})
        gt_array = results['gt_semantic_seg']
        true_mask = np.zeros_like(gt_array)
        true_mask[6:8, 2:4] = 2
        true_mask[6:8, 6:8] = 1
        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, true_mask)

        # test no custom classes
        results = load()
        gt_array = results['gt_semantic_seg']
        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, test_gt)
    finally:
        # Always clean up generated fixtures, even on assertion failure.
        for path in (img_path, gt_path):
            if osp.exists(path):
                os.remove(path)