
[Feature] Support evaluating concatenated datasets and add a tool to show datasets #833

Merged: 13 commits, Sep 9, 2021
2 changes: 0 additions & 2 deletions mmseg/core/evaluation/metrics.py
@@ -112,8 +112,6 @@ def total_intersect_and_union(results,
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""
num_imgs = len(results)
assert len(list(gt_seg_maps)) == num_imgs
Collaborator:
Why remove these asserts?

Contributor Author:

list(gt_seg_maps) will exhaust the generator, leaving gt_seg_maps empty, so

for result, gt_seg_map in zip(results, gt_seg_maps):

will silently produce no pairs, leading the metric to be NaN. I have added a unit test in bf48690.
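
A minimal, self-contained demonstration of that failure mode (toy arrays, not the real pipeline):

import numpy as np

results = [np.zeros((2, 2)), np.ones((2, 2))]
gt_seg_maps = (m for m in results)             # a generator, like get_gt_seg_maps()
assert len(list(gt_seg_maps)) == len(results)  # consumes the generator...
print(list(zip(results, gt_seg_maps)))         # ...prints []: no pairs are
                                               # accumulated, so the metric is NaN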

Collaborator:

Some lines are still missing; you can view them in the Files changed tab.


Contributor Author:

If the assert is not removed,

eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])

will return NaN results.

total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
4 changes: 3 additions & 1 deletion mmseg/datasets/builder.py
@@ -30,6 +30,8 @@ def _concat_dataset(cfg, default_args=None):
img_dir = cfg['img_dir']
ann_dir = cfg.get('ann_dir', None)
split = cfg.get('split', None)
# pop 'separate_eval' since it is not a valid key for common datasets.
separate_eval = cfg.pop('separate_eval', True)
num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
if ann_dir is not None:
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@@ -57,7 +59,7 @@ def _concat_dataset(cfg, default_args=None):
data_cfg['split'] = split[i]
datasets.append(build_dataset(data_cfg, default_args))

return ConcatDataset(datasets)
return ConcatDataset(datasets, separate_eval)


def build_dataset(cfg, default_args=None):
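For context, a hypothetical config sketch (dataset type and paths are assumptions) showing where the flag enters: with list-valued img_dir/ann_dir, _concat_dataset builds one sub-dataset per directory pair and forwards the popped separate_eval to ConcatDataset:

# All names and paths below are illustrative assumptions.
val = dict(
    type='CustomDataset',
    data_root='data/example',
    img_dir=['img/part1', 'img/part2'],
    ann_dir=['ann/part1', 'ann/part2'],
    separate_eval=False,  # popped here and passed through to ConcatDataset
    pipeline=[])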
24 changes: 14 additions & 10 deletions mmseg/datasets/custom.py
@@ -2,7 +2,6 @@
import os.path as osp
import warnings
from collections import OrderedDict
from functools import reduce

import mmcv
import numpy as np
@@ -99,6 +98,9 @@ def __init__(self,
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)
if test_mode:
assert self.CLASSES is not None, \
'`cls.CLASSES` or `classes` should be specified when testing'
Collaborator:

It seems that this modification leads to a failed GitHub CI check.
Could you please add some unit tests and fix the failing ones?

Contributor Author:

OK, I will find time to fix the issue above 😂


# join paths if data_root is specified
if self.data_root is not None:
@@ -339,7 +341,12 @@ def get_palette_for_custom_classes(self, class_names, palette=None):

return palette

def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
def evaluate(self,
results,
metric='mIoU',
logger=None,
gt_seg_maps=None,
**kwargs):
"""Evaluate the dataset.

Args:
@@ -350,6 +357,8 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
used in ConcatDataset. Default: None.

Returns:
dict[str, float]: Default metrics.
@@ -364,14 +373,9 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
# test a list of files
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
gt_seg_maps = self.get_gt_seg_maps()
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)
# reset generator
gt_seg_maps = self.get_gt_seg_maps()
if gt_seg_maps is None:
gt_seg_maps = self.get_gt_seg_maps()
num_classes = len(self.CLASSES)
ret_metrics = eval_metrics(
results,
gt_seg_maps,
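A hedged usage sketch of the new parameter (dataset and results are assumed to exist): wrappers such as ConcatDataset can pass a merged generator, while plain callers omit it and the maps are regenerated internally:

# Both calls are sketches; `dataset` and `results` are assumptions.
eval_default = dataset.evaluate(results, metric='mIoU')  # gt maps fetched internally
merged_maps = dataset.get_gt_seg_maps()                  # any generator of ndarrays
eval_custom = dataset.evaluate(results, metric='mIoU',
                               gt_seg_maps=merged_maps)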
143 changes: 141 additions & 2 deletions mmseg/datasets/dataset_wrappers.py
@@ -1,24 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
from itertools import chain

import mmcv
import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
from .cityscapes import CityscapesDataset


@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.

Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
supports evaluation and formatting of results

Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the concatenated
dataset results separately. Defaults to True.
"""

def __init__(self, datasets):
def __init__(self, datasets, separate_eval=True):
Collaborator:

Please add a docstring for separate_eval.

Contributor Author:

Hi, I have fixed the issue and added a unit test for it. Do I need to submit a new PR?

super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval
assert separate_eval in [True, False], \
f'separate_eval can only be True or False, ' \
f'but got {separate_eval}'
if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
raise NotImplementedError(
'Evaluating ConcatDataset containing CityscapesDataset '
'is not supported!')

def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.

Args:
results (list[tuple[torch.Tensor]] | list[str]): per image
pre_eval results or predicted segmentation maps for
computing the evaluation metric.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.

Returns:
dict[str, float]: evaluation results of each separate dataset if
`self.separate_eval=True`, otherwise of the total
concatenated dataset.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} vs. {len(results)}')

# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'

if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]

results_per_dataset = results[start_idx:end_idx]
print_log(
f'\nEvaluating {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)

eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})

return total_eval_results

if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have same types when '
'self.separate_eval=False')
else:
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
# merge the generators of gt_seg_maps
gt_seg_maps = chain(
*[dataset.get_gt_seg_maps() for dataset in self.datasets])
else:
# if the results are `pre_eval` results,
# we do not need gt_seg_maps to evaluate
gt_seg_maps = None
eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results
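
For illustration, a sketch of the two result shapes (objects and values invented):

# `concat_dataset` wraps two sub-datasets; `results` covers all samples.
eval_results = concat_dataset.evaluate(results, metric='mIoU')
# separate_eval=True  -> {'0_mIoU': 0.71, '0_mAcc': 0.80, '1_mIoU': 0.65, ...}
# separate_eval=False -> {'mIoU': 0.68}, computed over the merged gt generators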

def get_dataset_idx_and_sample_idx(self, indice):
"""Return dataset and sample index when given an indice of
ConcatDataset.

Args:
indice (int): indice of sample in ConcatDataset

Returns:
int: the index of sub dataset the sample belong to
int: the index of sample in its corresponding subset
"""
if indice < 0:
if -indice > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
indice = len(self) + indice
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
if dataset_idx == 0:
sample_idx = indice
else:
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
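
A self-contained check of the index arithmetic above (sizes invented):

import bisect

cumulative_sizes = [3, 5]  # two sub-datasets holding 3 and 2 samples
for idx in range(5):
    ds = bisect.bisect_right(cumulative_sizes, idx)
    sample = idx if ds == 0 else idx - cumulative_sizes[ds - 1]
    print(f'global {idx} -> dataset {ds}, sample {sample}')
# global 0-2 map to dataset 0, samples 0-2; global 3-4 to dataset 1, samples 0-1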

def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""format result for every sample of ConcatDataset."""
if indices is None:
indices = list(range(len(self)))

assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'

ret_res = []
for i, indice in enumerate(indices):
Collaborator:

What about indices=None? Maybe we need to handle this case.

Contributor Author:

> What about indices=None? Maybe we need to handle this case.

You are right, I will fix it.

dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].format_results(
[results[i]],
imgfile_prefix + f'/{dataset_idx}',
indices=[sample_idx],
**kwargs)
ret_res.append(res)
return sum(ret_res, [])
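
A hedged usage sketch (all names assumed): each sample is routed to its sub-dataset's format_results, and samples from sub-dataset k are written under '<imgfile_prefix>/k':

formatted = concat_dataset.format_results(results,
                                          imgfile_prefix='work_dirs/out')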

def pre_eval(self, preds, indices):
"""do pre eval for every sample of ConcatDataset."""
# In order to be compatible with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
ret_res.append(res)
return sum(ret_res, [])
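
Similarly, a sketch of the batch-compatible call forms (pred objects and dataset assumed):

single = concat_dataset.pre_eval(pred, indices=4)                    # one sample
batched = concat_dataset.pre_eval([pred_a, pred_b], indices=[3, 4])  # a batch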


@DATASETS.register_module()