From 872e54497e4038f165a1fd4a37d3d4a18c61978e Mon Sep 17 00:00:00 2001
From: FreyWang
Date: Thu, 9 Sep 2021 13:00:23 +0800
Subject: [PATCH] [Feature] Support eval concat dataset and add tool to show dataset (#833)

* [Feature] Add tool to show original or augmented train data

* [Feature] Support eval concat dataset

* Add docstring and modify evaluate of concat dataset

Signed-off-by: FreyWang

* format concat dataset in subfolder of imgfile_prefix

Signed-off-by: FreyWang

* add unittest of concat dataset

Signed-off-by: FreyWang

* update unittest for eval dataset whose CLASSES is None

Signed-off-by: FreyWang

* [FIX] bug of generator, which led metric to NaN when pre_eval=False

Signed-off-by: FreyWang

* format code

Signed-off-by: FreyWang

* add more unittests

* add more unittests

* optimize concat dataset builder
---
 mmseg/core/evaluation/metrics.py        |   2 -
 mmseg/datasets/builder.py               |   4 +-
 mmseg/datasets/custom.py                |  24 +-
 mmseg/datasets/dataset_wrappers.py      | 143 +++++++++-
 tests/test_data/test_dataset.py         | 343 ++++++++++++++++++++++--
 tests/test_data/test_dataset_builder.py |   6 +-
 tools/browse_dataset.py                 | 167 ++++++++++++
 tools/test.py                           |   3 +-
 8 files changed, 646 insertions(+), 46 deletions(-)
 create mode 100644 tools/browse_dataset.py

diff --git a/mmseg/core/evaluation/metrics.py b/mmseg/core/evaluation/metrics.py
index f64967c6c2..b83a798ea9 100644
--- a/mmseg/core/evaluation/metrics.py
+++ b/mmseg/core/evaluation/metrics.py
@@ -112,8 +112,6 @@ def total_intersect_and_union(results,
         ndarray: The prediction histogram on all classes.
         ndarray: The ground truth histogram on all classes.
     """
-    num_imgs = len(results)
-    assert len(list(gt_seg_maps)) == num_imgs
     total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
     total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
     total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py
index bfb54ef002..7ab645958d 100644
--- a/mmseg/datasets/builder.py
+++ b/mmseg/datasets/builder.py
@@ -30,6 +30,8 @@ def _concat_dataset(cfg, default_args=None):
     img_dir = cfg['img_dir']
     ann_dir = cfg.get('ann_dir', None)
     split = cfg.get('split', None)
+    # pop 'separate_eval' since it is not a valid key for common datasets.
+ separate_eval = cfg.pop('separate_eval', True) num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 if ann_dir is not None: num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 @@ -57,7 +59,7 @@ def _concat_dataset(cfg, default_args=None): data_cfg['split'] = split[i] datasets.append(build_dataset(data_cfg, default_args)) - return ConcatDataset(datasets) + return ConcatDataset(datasets, separate_eval) def build_dataset(cfg, default_args=None): diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index e366c0da2d..9b0efc6f05 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -2,7 +2,6 @@ import os.path as osp import warnings from collections import OrderedDict -from functools import reduce import mmcv import numpy as np @@ -99,6 +98,9 @@ def __init__(self, self.label_map = None self.CLASSES, self.PALETTE = self.get_classes_and_palette( classes, palette) + if test_mode: + assert self.CLASSES is not None, \ + '`cls.CLASSES` or `classes` should be specified when testing' # join paths if data_root is specified if self.data_root is not None: @@ -339,7 +341,12 @@ def get_palette_for_custom_classes(self, class_names, palette=None): return palette - def evaluate(self, results, metric='mIoU', logger=None, **kwargs): + def evaluate(self, + results, + metric='mIoU', + logger=None, + gt_seg_maps=None, + **kwargs): """Evaluate the dataset. Args: @@ -350,6 +357,8 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): 'mDice' and 'mFscore' are supported. logger (logging.Logger | None | str): Logger used for printing related information during evaluation. Default: None. + gt_seg_maps (generator[ndarray]): Custom gt seg maps as input, + used in ConcatDataset Returns: dict[str, float]: Default metrics. @@ -364,14 +373,9 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): # test a list of files if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( results, str): - gt_seg_maps = self.get_gt_seg_maps() - if self.CLASSES is None: - num_classes = len( - reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) - else: - num_classes = len(self.CLASSES) - # reset generator - gt_seg_maps = self.get_gt_seg_maps() + if gt_seg_maps is None: + gt_seg_maps = self.get_gt_seg_maps() + num_classes = len(self.CLASSES) ret_metrics = eval_metrics( results, gt_seg_maps, diff --git a/mmseg/datasets/dataset_wrappers.py b/mmseg/datasets/dataset_wrappers.py index f161f71469..0349332eeb 100644 --- a/mmseg/datasets/dataset_wrappers.py +++ b/mmseg/datasets/dataset_wrappers.py @@ -1,7 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. +import bisect +from itertools import chain + +import mmcv +import numpy as np +from mmcv.utils import print_log from torch.utils.data.dataset import ConcatDataset as _ConcatDataset from .builder import DATASETS +from .cityscapes import CityscapesDataset @DATASETS.register_module() @@ -9,16 +16,148 @@ class ConcatDataset(_ConcatDataset): """A wrapper of concatenated dataset. Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but - concat the group flag for image aspect ratio. + support evaluation and formatting results Args: datasets (list[:obj:`Dataset`]): A list of datasets. + separate_eval (bool): Whether to evaluate the concatenated + dataset results separately, Defaults to True. 
""" - def __init__(self, datasets): + def __init__(self, datasets, separate_eval=True): super(ConcatDataset, self).__init__(datasets) self.CLASSES = datasets[0].CLASSES self.PALETTE = datasets[0].PALETTE + self.separate_eval = separate_eval + assert separate_eval in [True, False], \ + f'separate_eval can only be True or False,' \ + f'but get {separate_eval}' + if any([isinstance(ds, CityscapesDataset) for ds in datasets]): + raise NotImplementedError( + 'Evaluating ConcatDataset containing CityscapesDataset' + 'is not supported!') + + def evaluate(self, results, logger=None, **kwargs): + """Evaluate the results. + + Args: + results (list[tuple[torch.Tensor]] | list[str]]): per image + pre_eval results or predict segmentation map for + computing evaluation metric. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str: float]: evaluate results of the total dataset + or each separate + dataset if `self.separate_eval=True`. + """ + assert len(results) == self.cumulative_sizes[-1], \ + ('Dataset and results have different sizes: ' + f'{self.cumulative_sizes[-1]} v.s. {len(results)}') + + # Check whether all the datasets support evaluation + for dataset in self.datasets: + assert hasattr(dataset, 'evaluate'), \ + f'{type(dataset)} does not implement evaluate function' + + if self.separate_eval: + dataset_idx = -1 + total_eval_results = dict() + for size, dataset in zip(self.cumulative_sizes, self.datasets): + start_idx = 0 if dataset_idx == -1 else \ + self.cumulative_sizes[dataset_idx] + end_idx = self.cumulative_sizes[dataset_idx + 1] + + results_per_dataset = results[start_idx:end_idx] + print_log( + f'\nEvaluateing {dataset.img_dir} with ' + f'{len(results_per_dataset)} images now', + logger=logger) + + eval_results_per_dataset = dataset.evaluate( + results_per_dataset, logger=logger, **kwargs) + dataset_idx += 1 + for k, v in eval_results_per_dataset.items(): + total_eval_results.update({f'{dataset_idx}_{k}': v}) + + return total_eval_results + + if len(set([type(ds) for ds in self.datasets])) != 1: + raise NotImplementedError( + 'All the datasets should have same types when ' + 'self.separate_eval=False') + else: + if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( + results, str): + # merge the generators of gt_seg_maps + gt_seg_maps = chain( + *[dataset.get_gt_seg_maps() for dataset in self.datasets]) + else: + # if the results are `pre_eval` results, + # we do not need gt_seg_maps to evaluate + gt_seg_maps = None + eval_results = self.datasets[0].evaluate( + results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs) + return eval_results + + def get_dataset_idx_and_sample_idx(self, indice): + """Return dataset and sample index when given an indice of + ConcatDataset. 
+ + Args: + indice (int): indice of sample in ConcatDataset + + Returns: + int: the index of sub dataset the sample belong to + int: the index of sample in its corresponding subset + """ + if indice < 0: + if -indice > len(self): + raise ValueError( + 'absolute value of index should not exceed dataset length') + indice = len(self) + indice + dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice) + if dataset_idx == 0: + sample_idx = indice + else: + sample_idx = indice - self.cumulative_sizes[dataset_idx - 1] + return dataset_idx, sample_idx + + def format_results(self, results, imgfile_prefix, indices=None, **kwargs): + """format result for every sample of ConcatDataset.""" + if indices is None: + indices = list(range(len(self))) + + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' + + ret_res = [] + for i, indice in enumerate(indices): + dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( + indice) + res = self.datasets[dataset_idx].format_results( + [results[i]], + imgfile_prefix + f'/{dataset_idx}', + indices=[sample_idx], + **kwargs) + ret_res.append(res) + return sum(ret_res, []) + + def pre_eval(self, preds, indices): + """do pre eval for every sample of ConcatDataset.""" + # In order to compat with batch inference + if not isinstance(indices, list): + indices = [indices] + if not isinstance(preds, list): + preds = [preds] + ret_res = [] + for i, indice in enumerate(indices): + dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( + indice) + res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx) + ret_res.append(res) + return sum(ret_res, []) @DATASETS.register_module() diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index ebc173669d..f1ce7bb880 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -6,12 +6,13 @@ import numpy as np import pytest +import torch from PIL import Image from mmseg.core.evaluation import get_classes, get_palette from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, ConcatDataset, CustomDataset, PascalVOCDataset, - RepeatDataset) + RepeatDataset, build_dataset) def test_classes(): @@ -143,7 +144,8 @@ def test_custom_dataset(): test_pipeline, img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'), img_suffix='img.jpg', - test_mode=True) + test_mode=True, + classes=('pseudo_class', )) assert len(test_dataset) == 5 # training data get @@ -164,30 +166,21 @@ def test_custom_dataset(): with pytest.raises(NotImplementedError): test_dataset.format_results([], '') - # test past evaluation pseudo_results = [] for gt_seg_map in gt_seg_maps: h, w = gt_seg_map.shape pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) - eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU']) - assert isinstance(eval_results, dict) - assert 'mIoU' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results - eval_results = train_dataset.evaluate(pseudo_results, metric='mDice') - assert isinstance(eval_results, dict) - assert 'mDice' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results + # test past evaluation without CLASSES + with pytest.raises(TypeError): + eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU']) - eval_results = train_dataset.evaluate( - pseudo_results, metric=['mDice', 'mIoU']) - assert isinstance(eval_results, dict) - assert 'mIoU' in eval_results - assert 'mDice' in 
eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results + with pytest.raises(TypeError): + eval_results = train_dataset.evaluate(pseudo_results, metric='mDice') + + with pytest.raises(TypeError): + eval_results = train_dataset.evaluate( + pseudo_results, metric=['mDice', 'mIoU']) # test past evaluation with CLASSES train_dataset.CLASSES = tuple(['a'] * 7) @@ -221,6 +214,14 @@ def test_custom_dataset(): assert 'mPrecision' in eval_results assert 'mRecall' in eval_results + assert not np.isnan(eval_results['mIoU']) + assert not np.isnan(eval_results['mDice']) + assert not np.isnan(eval_results['mAcc']) + assert not np.isnan(eval_results['aAcc']) + assert not np.isnan(eval_results['mFscore']) + assert not np.isnan(eval_results['mPrecision']) + assert not np.isnan(eval_results['mRecall']) + # test evaluation with pre-eval and the dataset.CLASSES is necessary train_dataset.CLASSES = tuple(['a'] * 7) pseudo_results = [] @@ -258,6 +259,223 @@ def test_custom_dataset(): assert 'mPrecision' in eval_results assert 'mRecall' in eval_results + assert not np.isnan(eval_results['mIoU']) + assert not np.isnan(eval_results['mDice']) + assert not np.isnan(eval_results['mAcc']) + assert not np.isnan(eval_results['aAcc']) + assert not np.isnan(eval_results['mFscore']) + assert not np.isnan(eval_results['mPrecision']) + assert not np.isnan(eval_results['mRecall']) + + +@pytest.mark.parametrize('separate_eval', [True, False]) +def test_eval_concat_custom_dataset(separate_eval): + img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True) + test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(128, 256), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) + ] + data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset') + img_dir = 'imgs/' + ann_dir = 'gts/' + + cfg1 = dict( + type='CustomDataset', + pipeline=test_pipeline, + data_root=data_root, + img_dir=img_dir, + ann_dir=ann_dir, + img_suffix='img.jpg', + seg_map_suffix='gt.png', + classes=tuple(['a'] * 7)) + dataset1 = build_dataset(cfg1) + assert len(dataset1) == 5 + # get gt seg map + gt_seg_maps = dataset1.get_gt_seg_maps(efficient_test=True) + assert isinstance(gt_seg_maps, Generator) + gt_seg_maps = list(gt_seg_maps) + assert len(gt_seg_maps) == 5 + + # test past evaluation + pseudo_results = [] + for gt_seg_map in gt_seg_maps: + h, w = gt_seg_map.shape + pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) + eval_results1 = dataset1.evaluate( + pseudo_results, metric=['mIoU', 'mDice', 'mFscore']) + + # We use same dir twice for simplicity + # with ann_dir + cfg2 = dict( + type='CustomDataset', + pipeline=test_pipeline, + data_root=data_root, + img_dir=[img_dir, img_dir], + ann_dir=[ann_dir, ann_dir], + img_suffix='img.jpg', + seg_map_suffix='gt.png', + classes=tuple(['a'] * 7), + separate_eval=separate_eval) + dataset2 = build_dataset(cfg2) + assert isinstance(dataset2, ConcatDataset) + assert len(dataset2) == 10 + + eval_results2 = dataset2.evaluate( + pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore']) + + if separate_eval: + assert eval_results1['mIoU'] == eval_results2[ + '0_mIoU'] == eval_results2['1_mIoU'] + assert eval_results1['mDice'] == eval_results2[ + '0_mDice'] == 
eval_results2['1_mDice'] + assert eval_results1['mAcc'] == eval_results2[ + '0_mAcc'] == eval_results2['1_mAcc'] + assert eval_results1['aAcc'] == eval_results2[ + '0_aAcc'] == eval_results2['1_aAcc'] + assert eval_results1['mFscore'] == eval_results2[ + '0_mFscore'] == eval_results2['1_mFscore'] + assert eval_results1['mPrecision'] == eval_results2[ + '0_mPrecision'] == eval_results2['1_mPrecision'] + assert eval_results1['mRecall'] == eval_results2[ + '0_mRecall'] == eval_results2['1_mRecall'] + else: + assert eval_results1['mIoU'] == eval_results2['mIoU'] + assert eval_results1['mDice'] == eval_results2['mDice'] + assert eval_results1['mAcc'] == eval_results2['mAcc'] + assert eval_results1['aAcc'] == eval_results2['aAcc'] + assert eval_results1['mFscore'] == eval_results2['mFscore'] + assert eval_results1['mPrecision'] == eval_results2['mPrecision'] + assert eval_results1['mRecall'] == eval_results2['mRecall'] + + # test get dataset_idx and sample_idx from ConcateDataset + dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3) + assert dataset_idx == 0 + assert sample_idx == 3 + + dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(7) + assert dataset_idx == 1 + assert sample_idx == 2 + + dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-7) + assert dataset_idx == 0 + assert sample_idx == 3 + + # test negative indice exceed length of dataset + with pytest.raises(ValueError): + dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-11) + + # test negative indice value + indice = -6 + dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(indice) + dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx( + len(dataset2) + indice) + assert dataset_idx1 == dataset_idx2 + assert sample_idx1 == sample_idx2 + + # test evaluation with pre-eval and the dataset.CLASSES is necessary + pseudo_results = [] + eval_results1 = [] + for idx in range(len(dataset1)): + h, w = gt_seg_maps[idx].shape + pseudo_result = np.random.randint(low=0, high=7, size=(h, w)) + pseudo_results.append(pseudo_result) + eval_results1.extend(dataset1.pre_eval(pseudo_result, idx)) + + assert len(eval_results1) == len(dataset1) + assert isinstance(eval_results1[0], tuple) + assert len(eval_results1[0]) == 4 + assert isinstance(eval_results1[0][0], torch.Tensor) + + eval_results1 = dataset1.evaluate( + eval_results1, metric=['mIoU', 'mDice', 'mFscore']) + + pseudo_results = pseudo_results * 2 + eval_results2 = [] + for idx in range(len(dataset2)): + eval_results2.extend(dataset2.pre_eval(pseudo_results[idx], idx)) + + assert len(eval_results2) == len(dataset2) + assert isinstance(eval_results2[0], tuple) + assert len(eval_results2[0]) == 4 + assert isinstance(eval_results2[0][0], torch.Tensor) + + eval_results2 = dataset2.evaluate( + eval_results2, metric=['mIoU', 'mDice', 'mFscore']) + + if separate_eval: + assert eval_results1['mIoU'] == eval_results2[ + '0_mIoU'] == eval_results2['1_mIoU'] + assert eval_results1['mDice'] == eval_results2[ + '0_mDice'] == eval_results2['1_mDice'] + assert eval_results1['mAcc'] == eval_results2[ + '0_mAcc'] == eval_results2['1_mAcc'] + assert eval_results1['aAcc'] == eval_results2[ + '0_aAcc'] == eval_results2['1_aAcc'] + assert eval_results1['mFscore'] == eval_results2[ + '0_mFscore'] == eval_results2['1_mFscore'] + assert eval_results1['mPrecision'] == eval_results2[ + '0_mPrecision'] == eval_results2['1_mPrecision'] + assert eval_results1['mRecall'] == eval_results2[ + '0_mRecall'] == 
eval_results2['1_mRecall'] + else: + assert eval_results1['mIoU'] == eval_results2['mIoU'] + assert eval_results1['mDice'] == eval_results2['mDice'] + assert eval_results1['mAcc'] == eval_results2['mAcc'] + assert eval_results1['aAcc'] == eval_results2['aAcc'] + assert eval_results1['mFscore'] == eval_results2['mFscore'] + assert eval_results1['mPrecision'] == eval_results2['mPrecision'] + assert eval_results1['mRecall'] == eval_results2['mRecall'] + + # test batch_indices for pre eval + eval_results2 = dataset2.pre_eval(pseudo_results, + list(range(len(pseudo_results)))) + + assert len(eval_results2) == len(dataset2) + assert isinstance(eval_results2[0], tuple) + assert len(eval_results2[0]) == 4 + assert isinstance(eval_results2[0][0], torch.Tensor) + + eval_results2 = dataset2.evaluate( + eval_results2, metric=['mIoU', 'mDice', 'mFscore']) + + if separate_eval: + assert eval_results1['mIoU'] == eval_results2[ + '0_mIoU'] == eval_results2['1_mIoU'] + assert eval_results1['mDice'] == eval_results2[ + '0_mDice'] == eval_results2['1_mDice'] + assert eval_results1['mAcc'] == eval_results2[ + '0_mAcc'] == eval_results2['1_mAcc'] + assert eval_results1['aAcc'] == eval_results2[ + '0_aAcc'] == eval_results2['1_aAcc'] + assert eval_results1['mFscore'] == eval_results2[ + '0_mFscore'] == eval_results2['1_mFscore'] + assert eval_results1['mPrecision'] == eval_results2[ + '0_mPrecision'] == eval_results2['1_mPrecision'] + assert eval_results1['mRecall'] == eval_results2[ + '0_mRecall'] == eval_results2['1_mRecall'] + else: + assert eval_results1['mIoU'] == eval_results2['mIoU'] + assert eval_results1['mDice'] == eval_results2['mDice'] + assert eval_results1['mAcc'] == eval_results2['mAcc'] + assert eval_results1['aAcc'] == eval_results2['aAcc'] + assert eval_results1['mFscore'] == eval_results2['mFscore'] + assert eval_results1['mPrecision'] == eval_results2['mPrecision'] + assert eval_results1['mRecall'] == eval_results2['mRecall'] + def test_ade(): test_dataset = ADE20KDataset( @@ -279,6 +497,44 @@ def test_ade(): shutil.rmtree('.format_ade') +@pytest.mark.parametrize('separate_eval', [True, False]) +def test_concat_ade(separate_eval): + test_dataset = ADE20KDataset( + pipeline=[], + img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')) + assert len(test_dataset) == 5 + + concat_dataset = ConcatDataset([test_dataset, test_dataset], + separate_eval=separate_eval) + assert len(concat_dataset) == 10 + # Test format_results + pseudo_results = [] + for _ in range(len(concat_dataset)): + h, w = (2, 2) + pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) + + # test format per image + file_paths = [] + for i in range(len(pseudo_results)): + file_paths.extend( + concat_dataset.format_results([pseudo_results[i]], + '.format_ade', + indices=[i])) + assert len(file_paths) == len(concat_dataset) + temp = np.array(Image.open(file_paths[0])) + assert np.allclose(temp, pseudo_results[0] + 1) + + shutil.rmtree('.format_ade') + + # test default argument + file_paths = concat_dataset.format_results(pseudo_results, '.format_ade') + assert len(file_paths) == len(concat_dataset) + temp = np.array(Image.open(file_paths[0])) + assert np.allclose(temp, pseudo_results[0] + 1) + + shutil.rmtree('.format_ade') + + def test_cityscapes(): test_dataset = CityscapesDataset( pipeline=[], @@ -311,6 +567,28 @@ def test_cityscapes(): shutil.rmtree('.format_city') +@pytest.mark.parametrize('separate_eval', [True, False]) +def test_concat_cityscapes(separate_eval): + cityscape_dataset = 
CityscapesDataset( + pipeline=[], + img_dir=osp.join( + osp.dirname(__file__), + '../data/pseudo_cityscapes_dataset/leftImg8bit'), + ann_dir=osp.join( + osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine')) + assert len(cityscape_dataset) == 1 + with pytest.raises(NotImplementedError): + _ = ConcatDataset([cityscape_dataset, cityscape_dataset], + separate_eval=separate_eval) + ade_dataset = ADE20KDataset( + pipeline=[], + img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')) + assert len(ade_dataset) == 5 + with pytest.raises(NotImplementedError): + _ = ConcatDataset([cityscape_dataset, ade_dataset], + separate_eval=separate_eval) + + @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) @patch('mmseg.datasets.CustomDataset.__getitem__', MagicMock(side_effect=lambda idx: idx)) @@ -360,14 +638,23 @@ def test_custom_classes_override_default(dataset, classes): assert custom_dataset.CLASSES == [classes[0]] # Test default behavior - custom_dataset = dataset_class( - pipeline=[], - img_dir=MagicMock(), - split=MagicMock(), - classes=None, - test_mode=True) - - assert custom_dataset.CLASSES == original_classes + if dataset_class is CustomDataset: + with pytest.raises(AssertionError): + custom_dataset = dataset_class( + pipeline=[], + img_dir=MagicMock(), + split=MagicMock(), + classes=None, + test_mode=True) + else: + custom_dataset = dataset_class( + pipeline=[], + img_dir=MagicMock(), + split=MagicMock(), + classes=None, + test_mode=True) + + assert custom_dataset.CLASSES == original_classes @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) diff --git a/tests/test_data/test_dataset_builder.py b/tests/test_data/test_dataset_builder.py index c945fe5527..edb82efb93 100644 --- a/tests/test_data/test_dataset_builder.py +++ b/tests/test_data/test_dataset_builder.py @@ -78,7 +78,8 @@ def test_build_dataset(): pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], - test_mode=True) + test_mode=True, + classes=('pseudo_class', )) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 10 @@ -90,7 +91,8 @@ def test_build_dataset(): data_root=data_root, img_dir=[img_dir, img_dir], split=['splits/val.txt', 'splits/val.txt'], - test_mode=True) + test_mode=True, + classes=('pseudo_class', )) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 2 diff --git a/tools/browse_dataset.py b/tools/browse_dataset.py new file mode 100644 index 0000000000..2ec414280a --- /dev/null +++ b/tools/browse_dataset.py @@ -0,0 +1,167 @@ +import argparse +import os +import warnings +from pathlib import Path + +import mmcv +import numpy as np +from mmcv import Config + +from mmseg.datasets.builder import build_dataset + + +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--show-origin', + default=False, + action='store_true', + help='if True, omit all augmentation in pipeline,' + ' show origin image and seg map') + parser.add_argument( + '--skip-type', + type=str, + nargs='+', + default=['DefaultFormatBundle', 'Normalize', 'Collect'], + help='skip some useless pipeline,if `show-origin` is true, ' + 'all pipeline except `Load` will be skipped') + parser.add_argument( + '--output-dir', + default='./output', + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--show', default=False, action='store_true') + 
parser.add_argument( + '--show-interval', + type=int, + default=999, + help='the interval of show (ms)') + parser.add_argument( + '--opacity', + type=float, + default=0.5, + help='the opacity of semantic map') + args = parser.parse_args() + return args + + +def imshow_semantic(img, + seg, + class_names, + palette=None, + win_name='', + show=False, + wait_time=0, + out_file=None, + opacity=0.5): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + seg (Tensor): The semantic segmentation results to draw over + `img`. + class_names (list[str]): Names of each classes. + palette (list[list[int]]] | np.ndarray | None): The palette of + segmentation map. If None is given, random palette will be + generated. Default: None + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The filename to write the image. + Default: None. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + Returns: + img (Tensor): Only if not `show` or `out_file` + """ + img = mmcv.imread(img) + img = img.copy() + if palette is None: + palette = np.random.randint(0, 255, size=(len(class_names), 3)) + palette = np.array(palette) + assert palette.shape[0] == len(class_names) + assert palette.shape[1] == 3 + assert len(palette.shape) == 2 + assert 0 < opacity <= 1.0 + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * (1 - opacity) + color_seg * opacity + img = img.astype(np.uint8) + # if out_file specified, do not show image in window + if out_file is not None: + show = False + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img + + +def _retrieve_data_cfg(_data_cfg, skip_type, show_origin): + if show_origin is True: + # only keep pipeline of Loading data and ann + _data_cfg['pipeline'] = [ + x for x in _data_cfg.pipeline if 'Load' in x['type'] + ] + else: + _data_cfg['pipeline'] = [ + x for x in _data_cfg.pipeline if x['type'] not in skip_type + ] + + +def retrieve_data_cfg(config_path, skip_type, show_origin=False): + cfg = Config.fromfile(config_path) + train_data_cfg = cfg.data.train + if isinstance(train_data_cfg, list): + for _data_cfg in train_data_cfg: + if 'pipeline' in _data_cfg: + _retrieve_data_cfg(_data_cfg, skip_type, show_origin) + elif 'dataset' in _data_cfg: + _retrieve_data_cfg(_data_cfg['dataset'], skip_type, + show_origin) + else: + raise ValueError + elif 'dataset' in train_data_cfg: + _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin) + else: + _retrieve_data_cfg(train_data_cfg, skip_type, show_origin) + return cfg + + +def main(): + args = parse_args() + cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin) + dataset = build_dataset(cfg.data.train) + progress_bar = mmcv.ProgressBar(len(dataset)) + for item in dataset: + filename = os.path.join(args.output_dir, + Path(item['filename']).name + ) if args.output_dir is not None else None + imshow_semantic( + item['img'], + item['gt_semantic_seg'], + dataset.CLASSES, + dataset.PALETTE, + show=args.show, + wait_time=args.show_interval, + out_file=filename, + 
opacity=args.opacity, + ) + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/tools/test.py b/tools/test.py index d99c153728..3923b77f40 100644 --- a/tools/test.py +++ b/tools/test.py @@ -215,7 +215,8 @@ def main(): print(f'\nwriting results to {args.out}') mmcv.dump(results, args.out) if args.eval: - metric = dataset.evaluate(results, args.eval, **eval_kwargs) + eval_kwargs.update(metric=args.eval) + metric = dataset.evaluate(results, **eval_kwargs) metric_dict = dict(config=args.config, metric=metric) if args.work_dir is not None and rank == 0: mmcv.dump(metric_dict, json_file, indent=4)
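# ---------------------------------------------------------------------------
# Illustrative usage notes appended by the editor; not part of the patch.
#
# A minimal sketch of a config fragment that exercises the new concatenation
# path, assuming a generic CustomDataset. The dataset type, paths and class
# names below are placeholders, not values taken from the patch. When
# `img_dir`/`ann_dir` are lists, `_concat_dataset()` builds one sub-dataset
# per folder and wraps them in `ConcatDataset`, popping `separate_eval` first
# so that it only reaches the wrapper.
data = dict(
    val=dict(
        type='CustomDataset',                  # placeholder dataset type
        data_root='data/my_dataset',           # placeholder path
        img_dir=['images/setA', 'images/setB'],
        ann_dir=['labels/setA', 'labels/setB'],
        classes=('background', 'foreground'),  # placeholder classes
        pipeline=[],                           # test pipeline omitted here
        # True  -> evaluate() returns per-sub-dataset keys such as '0_mIoU',
        #          '1_mIoU'; False -> results are pooled and evaluated once
        #          (only allowed when all sub-datasets share the same type).
        separate_eval=True))

# Standalone sketch of the index mapping used by `pre_eval()` and
# `format_results()`: a global sample index is translated into
# (dataset index, local sample index) by bisecting the cumulative sizes.
# The sizes below mirror the unit test above (two sub-datasets of 5 images);
# the real method additionally rejects negative indices whose absolute value
# exceeds the dataset length.
import bisect

cumulative_sizes = [5, 10]


def dataset_and_sample_idx(idx, cumulative_sizes):
    if idx < 0:
        idx += cumulative_sizes[-1]
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    if dataset_idx == 0:
        return dataset_idx, idx
    return dataset_idx, idx - cumulative_sizes[dataset_idx - 1]


assert dataset_and_sample_idx(3, cumulative_sizes) == (0, 3)
assert dataset_and_sample_idx(7, cumulative_sizes) == (1, 2)
assert dataset_and_sample_idx(-7, cumulative_sizes) == (0, 3)

# To eyeball the training data produced by a config, the new tool can be run
# as, e.g.:
#   python tools/browse_dataset.py configs/<your_config>.py --show-origin \
#       --output-dir ./output
# which loads `cfg.data.train`, drops the transforms listed in `--skip-type`
# (or everything except the `Load*` steps when `--show-origin` is set) and
# writes the overlays rendered by `imshow_semantic`.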