Commit 872e544

[Feature] Support eval concat dataset and add tool to show dataset (#833)

* [Feature] Add tool to show original or augmented training data

* [Feature] Support evaluating concat dataset

* Add docstring and modify evaluate of concat dataset

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Format concat dataset results in a subfolder of imgfile_prefix

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Add unit tests for concat dataset

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Update unit tests for evaluating a dataset whose CLASSES is None

Signed-off-by: FreyWang <wangwxyz@qq.com>

* [Fix] Bug of generator which led metrics to NaN when pre_eval=False

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Format code

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Add more unit tests

* Add more unit tests

* Optimize concat dataset builder
FreyWang authored Sep 9, 2021
1 parent eb0baee commit 872e544
Showing 8 changed files with 646 additions and 46 deletions.
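
Taken together, the change lets a dataset config with list-valued directories be evaluated as one concatenated dataset. A minimal config sketch of how this might be used (dataset type, paths, and the empty pipeline are hypothetical, not taken from this commit):

# Hypothetical config: list-valued img_dir/ann_dir make the builder create
# one sub-dataset per entry and wrap them in a ConcatDataset.
data = dict(
    val=dict(
        type='CustomDataset',
        data_root='data/my_dataset',              # assumed path
        img_dir=['images/setA', 'images/setB'],
        ann_dir=['annotations/setA', 'annotations/setB'],
        pipeline=[],                              # placeholder test pipeline
        separate_eval=False,                      # new: evaluate jointly, not per sub-dataset
    ))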
2 changes: 0 additions & 2 deletions mmseg/core/evaluation/metrics.py
@@ -112,8 +112,6 @@ def total_intersect_and_union(results,
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""
num_imgs = len(results)
assert len(list(gt_seg_maps)) == num_imgs
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
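These two deleted lines are the generator bug named in the commit message: now that gt_seg_maps may be a generator (as ConcatDataset passes in below), materializing it for a length check exhausts it, the accumulation loop then sees no data, and the metrics come out NaN when pre_eval=False. A standalone illustration of the failure mode (illustrative sketch, not code from the diff):

import numpy as np

def gt_maps():  # stand-in for CustomDataset.get_gt_seg_maps()
    yield from (np.zeros((2, 2)), np.ones((2, 2)))

gen = gt_maps()
assert len(list(gen)) == 2   # the removed check: consumes the generator
print(sum(1 for _ in gen))   # 0 -- nothing left to accumulate -> NaN metrics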
4 changes: 3 additions & 1 deletion mmseg/datasets/builder.py
@@ -30,6 +30,8 @@ def _concat_dataset(cfg, default_args=None):
img_dir = cfg['img_dir']
ann_dir = cfg.get('ann_dir', None)
split = cfg.get('split', None)
# pop 'separate_eval' since it is not a valid key for common datasets.
separate_eval = cfg.pop('separate_eval', True)
num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
if ann_dir is not None:
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@@ -57,7 +59,7 @@ def _concat_dataset(cfg, default_args=None):
data_cfg['split'] = split[i]
datasets.append(build_dataset(data_cfg, default_args))

return ConcatDataset(datasets)
return ConcatDataset(datasets, separate_eval)


def build_dataset(cfg, default_args=None):
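The pop is needed because separate_eval only has meaning for the wrapper: each per-directory config is forwarded to the sub-dataset class, which would reject an unknown key. A hedged usage sketch (paths are assumptions and the directories would have to exist):

from mmseg.datasets import build_dataset

# With list-valued dirs, build_dataset routes through _concat_dataset and
# returns a ConcatDataset carrying the popped flag.
dataset = build_dataset(
    dict(
        type='CustomDataset',
        img_dir=['a/images', 'b/images'],   # assumed paths
        ann_dir=['a/labels', 'b/labels'],
        pipeline=[],
        separate_eval=False))
assert dataset.separate_eval is False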
24 changes: 14 additions & 10 deletions mmseg/datasets/custom.py
@@ -2,7 +2,6 @@
import os.path as osp
import warnings
from collections import OrderedDict
from functools import reduce

import mmcv
import numpy as np
@@ -99,6 +98,9 @@ def __init__(self,
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)
if test_mode:
assert self.CLASSES is not None, \
'`cls.CLASSES` or `classes` should be specified when testing'

# join paths if data_root is specified
if self.data_root is not None:
@@ -339,7 +341,12 @@ def get_palette_for_custom_classes(self, class_names, palette=None):

return palette

def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
def evaluate(self,
results,
metric='mIoU',
logger=None,
gt_seg_maps=None,
**kwargs):
"""Evaluate the dataset.
Args:
@@ -350,6 +357,8 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
used in ConcatDataset.
Returns:
dict[str, float]: Default metrics.
@@ -364,14 +373,9 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
# test a list of files
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
gt_seg_maps = self.get_gt_seg_maps()
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)
# reset generator
gt_seg_maps = self.get_gt_seg_maps()
if gt_seg_maps is None:
gt_seg_maps = self.get_gt_seg_maps()
num_classes = len(self.CLASSES)
ret_metrics = eval_metrics(
results,
gt_seg_maps,
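The new gt_seg_maps argument is what enables joint evaluation: ConcatDataset can hand one merged generator of ground truths to a single sub-dataset's evaluate call. A minimal sketch mirroring the wrapper code below (datasets and results are assumed to be a list of sub-datasets and their ordered predictions):

from itertools import chain

gt_seg_maps = chain(*[ds.get_gt_seg_maps() for ds in datasets])
metrics = datasets[0].evaluate(results, metric='mIoU', gt_seg_maps=gt_seg_maps)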
143 changes: 141 additions & 2 deletions mmseg/datasets/dataset_wrappers.py
@@ -1,24 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
from itertools import chain

import mmcv
import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
from .cityscapes import CityscapesDataset


@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
supports evaluation and formatting results.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the concatenated
dataset results separately. Defaults to True.
"""

def __init__(self, datasets):
def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval
assert separate_eval in [True, False], \
f'separate_eval can only be True or False, ' \
f'but got {separate_eval}'
if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
raise NotImplementedError(
'Evaluating ConcatDataset containing CityscapesDataset '
'is not supported!')

def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[tuple[torch.Tensor]] | list[str]): per image
pre_eval results or predicted segmentation maps for
computing evaluation metrics.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str, float]: evaluation results of the whole dataset,
or of each separate dataset if `self.separate_eval=True`.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} vs. {len(results)}')

# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'

if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]

results_per_dataset = results[start_idx:end_idx]
print_log(
f'\nEvaluating {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)

eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})

return total_eval_results

if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have the same type when '
'self.separate_eval=False')
else:
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
# merge the generators of gt_seg_maps
gt_seg_maps = chain(
*[dataset.get_gt_seg_maps() for dataset in self.datasets])
else:
# if the results are `pre_eval` results,
# we do not need gt_seg_maps to evaluate
gt_seg_maps = None
eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results

def get_dataset_idx_and_sample_idx(self, indice):
"""Return dataset and sample index when given an indice of
ConcatDataset.
Args:
indice (int): indice of sample in ConcatDataset
Returns:
int: the index of sub dataset the sample belong to
int: the index of sample in its corresponding subset
"""
if indice < 0:
if -indice > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
indice = len(self) + indice
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
if dataset_idx == 0:
sample_idx = indice
else:
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx

def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""format result for every sample of ConcatDataset."""
if indices is None:
indices = list(range(len(self)))

assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'

ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].format_results(
[results[i]],
imgfile_prefix + f'/{dataset_idx}',
indices=[sample_idx],
**kwargs)
ret_res.append(res)
return sum(ret_res, [])

def pre_eval(self, preds, indices):
"""do pre eval for every sample of ConcatDataset."""
# In order to compat with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
ret_res.append(res)
return sum(ret_res, [])


@DATASETS.register_module()
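The index arithmetic in get_dataset_idx_and_sample_idx can be checked in isolation; a small standalone demo with assumed sub-dataset lengths of 3 and 5:

import bisect

cumulative_sizes = [3, 8]   # two sub-datasets of length 3 and 5
indice = 4                  # global sample index in the ConcatDataset
dataset_idx = bisect.bisect_right(cumulative_sizes, indice)
sample_idx = indice if dataset_idx == 0 else indice - cumulative_sizes[dataset_idx - 1]
print(dataset_idx, sample_idx)   # -> 1 1: second sub-dataset, its sample #1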