Commit 598f5c7

[Feature] Support eval concat dataset and add tool to show dataset (open-mmlab#833)
* [Feature] Add tool to show origin or augmented train data
* [Feature] Support eval concat dataset
* Add docstring and modify evaluate of concat dataset
* Format concat dataset in subfolder of imgfile_prefix
* Add unittest of concat dataset
* Update unittest for eval dataset when CLASSES is None
* [Fix] Fix bug of generator, which led metrics to NaN when pre_eval=False
* Format code
* Add more unittest
* Add more unittest
* Optimize concat dataset builder

Signed-off-by: FreyWang <wangwxyz@qq.com>
Parent: b0787b8 · Commit: 598f5c7

File tree: 8 files changed, +646 −46 lines


mmseg/core/evaluation/metrics.py (−2)

@@ -112,8 +112,6 @@ def total_intersect_and_union(results,
         ndarray: The prediction histogram on all classes.
         ndarray: The ground truth histogram on all classes.
     """
-    num_imgs = len(results)
-    assert len(list(gt_seg_maps)) == num_imgs
     total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
     total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
     total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
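The two deleted lines are the root of the `pre_eval=False` NaN bug named in the commit message: `gt_seg_maps` may be a generator, and `len(list(gt_seg_maps))` consumes it before the metric computation runs. A minimal standalone sketch of the failure mode (not code from the repo):

```python
# Minimal sketch of the bug the deletion fixes: materializing a generator
# for a length check leaves nothing for the actual metric loop, so the
# intersect/union histograms stay zero and the metrics come out NaN.
def gt_maps():
    yield from ([0, 1], [1, 1])

gen = gt_maps()
assert len(list(gen)) == 2  # the removed check: exhausts the generator
print(list(gen))            # [] -- the metric loop would see no ground truth
```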

mmseg/datasets/builder.py (+3 −1)

@@ -30,6 +30,8 @@ def _concat_dataset(cfg, default_args=None):
     img_dir = cfg['img_dir']
     ann_dir = cfg.get('ann_dir', None)
     split = cfg.get('split', None)
+    # pop 'separate_eval' since it is not a valid key for common datasets.
+    separate_eval = cfg.pop('separate_eval', True)
     num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
     if ann_dir is not None:
         num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@@ -57,7 +59,7 @@ def _concat_dataset(cfg, default_args=None):
             data_cfg['split'] = split[i]
         datasets.append(build_dataset(data_cfg, default_args))
 
-    return ConcatDataset(datasets)
+    return ConcatDataset(datasets, separate_eval)
 
 
 def build_dataset(cfg, default_args=None):
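For context, a hedged sketch of a config that takes this code path: when `img_dir` (and optionally `ann_dir`/`split`) are lists, `build_dataset` dispatches to `_concat_dataset`, which now strips `separate_eval` out of the shared config before building each sub-dataset and forwards it to `ConcatDataset`. The paths and dataset type below are illustrative only:

```python
# Hypothetical config sketch (type and paths are placeholders). List-valued
# img_dir/ann_dir trigger _concat_dataset(); separate_eval is popped so it
# never reaches the sub-dataset constructors, then handed to
# ConcatDataset(datasets, separate_eval).
val = dict(
    type='CustomDataset',
    data_root='data/example',                        # placeholder path
    img_dir=['images/part1', 'images/part2'],
    ann_dir=['annotations/part1', 'annotations/part2'],
    pipeline=[],                                     # test pipeline omitted
    separate_eval=False)                             # merge results in eval
```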

mmseg/datasets/custom.py (+14 −10)

@@ -2,7 +2,6 @@
 import os.path as osp
 import warnings
 from collections import OrderedDict
-from functools import reduce
 
 import mmcv
 import numpy as np
@@ -99,6 +98,9 @@ def __init__(self,
         self.label_map = None
         self.CLASSES, self.PALETTE = self.get_classes_and_palette(
             classes, palette)
+        if test_mode:
+            assert self.CLASSES is not None, \
+                '`cls.CLASSES` or `classes` should be specified when testing'
 
         # join paths if data_root is specified
         if self.data_root is not None:
@@ -339,7 +341,12 @@ def get_palette_for_custom_classes(self, class_names, palette=None):
 
         return palette
 
-    def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
+    def evaluate(self,
+                 results,
+                 metric='mIoU',
+                 logger=None,
+                 gt_seg_maps=None,
+                 **kwargs):
         """Evaluate the dataset.
 
         Args:
@@ -350,6 +357,8 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
                 'mDice' and 'mFscore' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.
+            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
+                used in ConcatDataset.
 
         Returns:
             dict[str, float]: Default metrics.
@@ -364,14 +373,9 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
         # test a list of files
         if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
                 results, str):
-            gt_seg_maps = self.get_gt_seg_maps()
-            if self.CLASSES is None:
-                num_classes = len(
-                    reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
-            else:
-                num_classes = len(self.CLASSES)
-            # reset generator
-            gt_seg_maps = self.get_gt_seg_maps()
+            if gt_seg_maps is None:
+                gt_seg_maps = self.get_gt_seg_maps()
+            num_classes = len(self.CLASSES)
         ret_metrics = eval_metrics(
             results,
             gt_seg_maps,
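The shrink from eight lines to three is safe because of the new `test_mode` assertion above: `self.CLASSES` is now guaranteed to be set at evaluation time, so the old `reduce(np.union1d, ...)` fallback, which also consumed the generator once, is no longer needed. A hedged sketch of the new hook, with illustrative names:

```python
# Illustrative only: a caller (ConcatDataset in this commit) can now inject
# its own ground-truth generator; with gt_seg_maps=None the dataset loads
# its own maps via get_gt_seg_maps(), exactly as before.
custom_gt = (seg_map for seg_map in loaded_maps)   # hypothetical iterable
metrics = dataset.evaluate(results, metric='mIoU', gt_seg_maps=custom_gt)
```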

mmseg/datasets/dataset_wrappers.py (+141 −2)

@@ -1,24 +1,163 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import bisect
+from itertools import chain
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
 
 from .builder import DATASETS
+from .cityscapes import CityscapesDataset
 
 
 @DATASETS.register_module()
 class ConcatDataset(_ConcatDataset):
     """A wrapper of concatenated dataset.
 
     Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
-    concat the group flag for image aspect ratio.
+    supports evaluation and formatting results.
 
     Args:
         datasets (list[:obj:`Dataset`]): A list of datasets.
+        separate_eval (bool): Whether to evaluate the concatenated
+            dataset results separately. Defaults to True.
     """
 
-    def __init__(self, datasets):
+    def __init__(self, datasets, separate_eval=True):
         super(ConcatDataset, self).__init__(datasets)
         self.CLASSES = datasets[0].CLASSES
         self.PALETTE = datasets[0].PALETTE
+        self.separate_eval = separate_eval
+        assert separate_eval in [True, False], \
+            f'separate_eval can only be True or False, ' \
+            f'but got {separate_eval}'
+        if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
+            raise NotImplementedError(
+                'Evaluating ConcatDataset containing CityscapesDataset '
+                'is not supported!')
+
+    def evaluate(self, results, logger=None, **kwargs):
+        """Evaluate the results.
+
+        Args:
+            results (list[tuple[torch.Tensor]] | list[str]): per image
+                pre_eval results or predicted segmentation maps for
+                computing evaluation metrics.
+            logger (logging.Logger | str | None): Logger used for printing
+                related information during evaluation. Default: None.
+
+        Returns:
+            dict[str: float]: evaluation results of the total dataset,
+                or of each separate dataset if `self.separate_eval=True`.
+        """
+        assert len(results) == self.cumulative_sizes[-1], \
+            ('Dataset and results have different sizes: '
+             f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
+
+        # Check whether all the datasets support evaluation
+        for dataset in self.datasets:
+            assert hasattr(dataset, 'evaluate'), \
+                f'{type(dataset)} does not implement evaluate function'
+
+        if self.separate_eval:
+            dataset_idx = -1
+            total_eval_results = dict()
+            for size, dataset in zip(self.cumulative_sizes, self.datasets):
+                start_idx = 0 if dataset_idx == -1 else \
+                    self.cumulative_sizes[dataset_idx]
+                end_idx = self.cumulative_sizes[dataset_idx + 1]
+
+                results_per_dataset = results[start_idx:end_idx]
+                print_log(
+                    f'\nEvaluating {dataset.img_dir} with '
+                    f'{len(results_per_dataset)} images now',
+                    logger=logger)
+
+                eval_results_per_dataset = dataset.evaluate(
+                    results_per_dataset, logger=logger, **kwargs)
+                dataset_idx += 1
+                for k, v in eval_results_per_dataset.items():
+                    total_eval_results.update({f'{dataset_idx}_{k}': v})
+
+            return total_eval_results
+
+        if len(set([type(ds) for ds in self.datasets])) != 1:
+            raise NotImplementedError(
+                'All the datasets should have same types when '
+                'self.separate_eval=False')
+        else:
+            if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
+                    results, str):
+                # merge the generators of gt_seg_maps
+                gt_seg_maps = chain(
+                    *[dataset.get_gt_seg_maps() for dataset in self.datasets])
+            else:
+                # if the results are `pre_eval` results,
+                # we do not need gt_seg_maps to evaluate
+                gt_seg_maps = None
+            eval_results = self.datasets[0].evaluate(
+                results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
+            return eval_results
+
+    def get_dataset_idx_and_sample_idx(self, indice):
+        """Return dataset and sample index when given an indice of
+        ConcatDataset.
+
+        Args:
+            indice (int): indice of sample in ConcatDataset
+
+        Returns:
+            int: the index of the sub dataset the sample belongs to
+            int: the index of the sample in its corresponding subset
+        """
+        if indice < 0:
+            if -indice > len(self):
+                raise ValueError(
+                    'absolute value of index should not exceed dataset length')
+            indice = len(self) + indice
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
+        if dataset_idx == 0:
+            sample_idx = indice
+        else:
+            sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
+        return dataset_idx, sample_idx
+
+    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
+        """Format results for every sample of ConcatDataset."""
+        if indices is None:
+            indices = list(range(len(self)))
+
+        assert isinstance(results, list), 'results must be a list.'
+        assert isinstance(indices, list), 'indices must be a list.'
+
+        ret_res = []
+        for i, indice in enumerate(indices):
+            dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
+                indice)
+            res = self.datasets[dataset_idx].format_results(
+                [results[i]],
+                imgfile_prefix + f'/{dataset_idx}',
+                indices=[sample_idx],
+                **kwargs)
+            ret_res.append(res)
+        return sum(ret_res, [])
+
+    def pre_eval(self, preds, indices):
+        """Do pre-eval for every sample of ConcatDataset."""
+        # In order to be compatible with batch inference
+        if not isinstance(indices, list):
+            indices = [indices]
+        if not isinstance(preds, list):
+            preds = [preds]
+        ret_res = []
+        for i, indice in enumerate(indices):
+            dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
+                indice)
+            res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
+            ret_res.append(res)
+        return sum(ret_res, [])
 
 
 @DATASETS.register_module()
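A short sketch of how the wrapper behaves end to end, assuming two sub-datasets of sizes 3 and 2 (`ds_a`, `ds_b`, and `results` are placeholders, not objects from the commit):

```python
# Assumed setup: len(ds_a) == 3, len(ds_b) == 2, so cumulative_sizes == [3, 5].
concat = ConcatDataset([ds_a, ds_b], separate_eval=True)

# bisect_right([3, 5], 4) == 1, so global index 4 is sample 1 of dataset 1.
print(concat.get_dataset_idx_and_sample_idx(4))   # -> (1, 1)

# separate_eval=True: each slice of `results` is scored by its own dataset
# and the keys are prefixed with the dataset index, e.g. '0_mIoU', '1_mIoU'.
# separate_eval=False: one merged evaluation over the chained gt generators.
print(concat.evaluate(results, logger=None))

# format_results() writes each sample under imgfile_prefix + f'/{dataset_idx}',
# the "subfolder of imgfile_prefix" mentioned in the commit message; pre_eval()
# likewise routes each prediction to the sub-dataset that owns the sample.
```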
