-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support eval concate dataset and add tool to show dataset #833
Changes from all commits
6fdd5ed
8e75af9
bc781f1
17cf91a
5bc4e58
e6501a2
bf48690
980e3bf
26570de
51ac8af
5784eb0
9352d60
28e3bd2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ | |
import os.path as osp | ||
import warnings | ||
from collections import OrderedDict | ||
from functools import reduce | ||
|
||
import mmcv | ||
import numpy as np | ||
|
@@ -99,6 +98,9 @@ def __init__(self, | |
self.label_map = None | ||
self.CLASSES, self.PALETTE = self.get_classes_and_palette( | ||
classes, palette) | ||
if test_mode: | ||
assert self.CLASSES is not None, \ | ||
'`cls.CLASSES` or `classes` should be specified when testing' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that this modify leads to failed github CI (checked). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I will find time to fix the issue above 😂 |
||
|
||
# join paths if data_root is specified | ||
if self.data_root is not None: | ||
|
@@ -339,7 +341,12 @@ def get_palette_for_custom_classes(self, class_names, palette=None): | |
|
||
return palette | ||
|
||
def evaluate(self, results, metric='mIoU', logger=None, **kwargs): | ||
def evaluate(self, | ||
results, | ||
metric='mIoU', | ||
logger=None, | ||
gt_seg_maps=None, | ||
**kwargs): | ||
"""Evaluate the dataset. | ||
|
||
Args: | ||
|
@@ -350,6 +357,8 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): | |
'mDice' and 'mFscore' are supported. | ||
logger (logging.Logger | None | str): Logger used for printing | ||
related information during evaluation. Default: None. | ||
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input, | ||
used in ConcatDataset | ||
|
||
Returns: | ||
dict[str, float]: Default metrics. | ||
|
@@ -364,14 +373,9 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): | |
# test a list of files | ||
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( | ||
results, str): | ||
gt_seg_maps = self.get_gt_seg_maps() | ||
if self.CLASSES is None: | ||
num_classes = len( | ||
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) | ||
else: | ||
num_classes = len(self.CLASSES) | ||
# reset generator | ||
gt_seg_maps = self.get_gt_seg_maps() | ||
if gt_seg_maps is None: | ||
gt_seg_maps = self.get_gt_seg_maps() | ||
num_classes = len(self.CLASSES) | ||
ret_metrics = eval_metrics( | ||
results, | ||
gt_seg_maps, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,163 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import bisect | ||
from itertools import chain | ||
|
||
import mmcv | ||
import numpy as np | ||
from mmcv.utils import print_log | ||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset | ||
|
||
from .builder import DATASETS | ||
from .cityscapes import CityscapesDataset | ||
|
||
|
||
@DATASETS.register_module() | ||
class ConcatDataset(_ConcatDataset): | ||
"""A wrapper of concatenated dataset. | ||
|
||
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but | ||
concat the group flag for image aspect ratio. | ||
support evaluation and formatting results | ||
|
||
Args: | ||
datasets (list[:obj:`Dataset`]): A list of datasets. | ||
separate_eval (bool): Whether to evaluate the concatenated | ||
dataset results separately, Defaults to True. | ||
""" | ||
|
||
def __init__(self, datasets): | ||
def __init__(self, datasets, separate_eval=True): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docstring for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hi, I hava fix the issue and add unittest for it, does I need to submit a new PR or not? |
||
super(ConcatDataset, self).__init__(datasets) | ||
self.CLASSES = datasets[0].CLASSES | ||
self.PALETTE = datasets[0].PALETTE | ||
self.separate_eval = separate_eval | ||
assert separate_eval in [True, False], \ | ||
f'separate_eval can only be True or False,' \ | ||
f'but get {separate_eval}' | ||
if any([isinstance(ds, CityscapesDataset) for ds in datasets]): | ||
raise NotImplementedError( | ||
'Evaluating ConcatDataset containing CityscapesDataset' | ||
'is not supported!') | ||
|
||
def evaluate(self, results, logger=None, **kwargs): | ||
"""Evaluate the results. | ||
|
||
Args: | ||
results (list[tuple[torch.Tensor]] | list[str]]): per image | ||
pre_eval results or predict segmentation map for | ||
computing evaluation metric. | ||
logger (logging.Logger | str | None): Logger used for printing | ||
related information during evaluation. Default: None. | ||
|
||
Returns: | ||
dict[str: float]: evaluate results of the total dataset | ||
or each separate | ||
dataset if `self.separate_eval=True`. | ||
""" | ||
assert len(results) == self.cumulative_sizes[-1], \ | ||
('Dataset and results have different sizes: ' | ||
f'{self.cumulative_sizes[-1]} v.s. {len(results)}') | ||
|
||
# Check whether all the datasets support evaluation | ||
for dataset in self.datasets: | ||
assert hasattr(dataset, 'evaluate'), \ | ||
f'{type(dataset)} does not implement evaluate function' | ||
|
||
if self.separate_eval: | ||
dataset_idx = -1 | ||
total_eval_results = dict() | ||
for size, dataset in zip(self.cumulative_sizes, self.datasets): | ||
start_idx = 0 if dataset_idx == -1 else \ | ||
self.cumulative_sizes[dataset_idx] | ||
end_idx = self.cumulative_sizes[dataset_idx + 1] | ||
|
||
results_per_dataset = results[start_idx:end_idx] | ||
print_log( | ||
f'\nEvaluateing {dataset.img_dir} with ' | ||
f'{len(results_per_dataset)} images now', | ||
logger=logger) | ||
|
||
eval_results_per_dataset = dataset.evaluate( | ||
results_per_dataset, logger=logger, **kwargs) | ||
dataset_idx += 1 | ||
for k, v in eval_results_per_dataset.items(): | ||
total_eval_results.update({f'{dataset_idx}_{k}': v}) | ||
|
||
return total_eval_results | ||
|
||
if len(set([type(ds) for ds in self.datasets])) != 1: | ||
raise NotImplementedError( | ||
'All the datasets should have same types when ' | ||
'self.separate_eval=False') | ||
else: | ||
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( | ||
results, str): | ||
# merge the generators of gt_seg_maps | ||
gt_seg_maps = chain( | ||
*[dataset.get_gt_seg_maps() for dataset in self.datasets]) | ||
else: | ||
# if the results are `pre_eval` results, | ||
# we do not need gt_seg_maps to evaluate | ||
gt_seg_maps = None | ||
eval_results = self.datasets[0].evaluate( | ||
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs) | ||
return eval_results | ||
|
||
def get_dataset_idx_and_sample_idx(self, indice): | ||
"""Return dataset and sample index when given an indice of | ||
ConcatDataset. | ||
|
||
Args: | ||
indice (int): indice of sample in ConcatDataset | ||
|
||
Returns: | ||
int: the index of sub dataset the sample belong to | ||
int: the index of sample in its corresponding subset | ||
""" | ||
if indice < 0: | ||
if -indice > len(self): | ||
raise ValueError( | ||
'absolute value of index should not exceed dataset length') | ||
indice = len(self) + indice | ||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice) | ||
if dataset_idx == 0: | ||
sample_idx = indice | ||
else: | ||
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1] | ||
return dataset_idx, sample_idx | ||
|
||
def format_results(self, results, imgfile_prefix, indices=None, **kwargs): | ||
"""format result for every sample of ConcatDataset.""" | ||
if indices is None: | ||
indices = list(range(len(self))) | ||
|
||
assert isinstance(results, list), 'results must be a list.' | ||
assert isinstance(indices, list), 'indices must be a list.' | ||
|
||
ret_res = [] | ||
for i, indice in enumerate(indices): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
you are right, I will fix it |
||
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( | ||
indice) | ||
res = self.datasets[dataset_idx].format_results( | ||
[results[i]], | ||
imgfile_prefix + f'/{dataset_idx}', | ||
indices=[sample_idx], | ||
**kwargs) | ||
ret_res.append(res) | ||
return sum(ret_res, []) | ||
|
||
def pre_eval(self, preds, indices): | ||
"""do pre eval for every sample of ConcatDataset.""" | ||
# In order to compat with batch inference | ||
if not isinstance(indices, list): | ||
indices = [indices] | ||
if not isinstance(preds, list): | ||
preds = [preds] | ||
ret_res = [] | ||
for i, indice in enumerate(indices): | ||
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( | ||
indice) | ||
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx) | ||
ret_res.append(res) | ||
return sum(ret_res, []) | ||
|
||
|
||
@DATASETS.register_module() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove these assert?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
list(gt_seg_maps)
will loop the generator and thengt_seg_maps
will be empty,mmsegmentation/mmseg/core/evaluation/metrics.py
Line 121 in 4981ff6
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are still missing some lines.
You can view it through
files changed
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the assert is not remove,
mmsegmentation/tests/test_data/test_dataset.py
Line 172 in 4981ff6