[Feature] Support eval concat dataset and add tool to show dataset #781

Closed · wants to merge 2 commits
6 changes: 5 additions & 1 deletion mmseg/datasets/builder.py
@@ -29,6 +29,7 @@ def _concat_dataset(cfg, default_args=None):
img_dir = cfg['img_dir']
ann_dir = cfg.get('ann_dir', None)
split = cfg.get('split', None)
separate_eval = cfg.get('separate_eval', True)
num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
if ann_dir is not None:
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@@ -48,6 +49,9 @@ def _concat_dataset(cfg, default_args=None):
datasets = []
for i in range(num_dset):
data_cfg = copy.deepcopy(cfg)
# pop 'separate_eval' since it is not a valid key for common datasets.
if 'separate_eval' in data_cfg:
data_cfg.pop('separate_eval')
if isinstance(img_dir, (list, tuple)):
data_cfg['img_dir'] = img_dir[i]
if isinstance(ann_dir, (list, tuple)):
@@ -56,7 +60,7 @@ def _concat_dataset(cfg, default_args=None):
data_cfg['split'] = split[i]
datasets.append(build_dataset(data_cfg, default_args))

return ConcatDataset(datasets)
return ConcatDataset(datasets, separate_eval)


def build_dataset(cfg, default_args=None):
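For context, a minimal sketch of a validation config that exercises this code path: list-valued `img_dir`/`ann_dir` make the builder call `_concat_dataset()`, and the new `separate_eval` flag is forwarded to `ConcatDataset`. The dataset type and paths below are illustrative assumptions, not taken from this PR.

# Hypothetical config snippet: two annotation splits evaluated as one set.
val = dict(
    type='CityscapesDataset',                        # assumed dataset type
    data_root='data/cityscapes',                     # assumed path
    img_dir=['leftImg8bit/val', 'leftImg8bit/extra'],
    ann_dir=['gtFine/val', 'gtFine/extra'],
    pipeline=test_pipeline,
    separate_eval=False)  # popped by _concat_dataset() before each sub-dataset is built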
5 changes: 4 additions & 1 deletion mmseg/datasets/custom.py
@@ -307,6 +307,7 @@ def evaluate(self,
results,
metric='mIoU',
logger=None,
gt_seg_maps=None,
efficient_test=False,
**kwargs):
"""Evaluate the dataset.
@@ -315,6 +316,7 @@ def evaluate(self,
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
'mDice' and 'mFscore' are supported.
gt_seg_maps (list): Custom ground truth segmentation maps, used
when evaluating a concatenated dataset. Default: None.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.

@@ -328,7 +330,8 @@ def evaluate(self,
if not set(metric).issubset(set(allowed_metrics)):
raise KeyError('metric {} is not supported'.format(metric))
eval_results = {}
gt_seg_maps = self.get_gt_seg_maps(efficient_test)
if gt_seg_maps is None:
gt_seg_maps = self.get_gt_seg_maps(efficient_test)
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
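A rough usage sketch of the new argument (the variable names are illustrative): when `gt_seg_maps` is omitted the dataset loads its own ground truth as before, while a wrapper such as `ConcatDataset` can pass merged maps explicitly.

# Default behaviour: ground truth is loaded by the dataset itself.
metrics = dataset.evaluate(results, metric='mIoU')
# Override: evaluate the same results against externally supplied GT maps.
metrics = dataset.evaluate(results, metric='mIoU', gt_seg_maps=merged_gt_seg_maps)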
58 changes: 57 additions & 1 deletion mmseg/datasets/dataset_wrappers.py
@@ -1,3 +1,4 @@
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
@@ -14,12 +15,67 @@ class ConcatDataset(_ConcatDataset):
datasets (list[:obj:`Dataset`]): A list of datasets.
"""

def __init__(self, datasets):
def __init__(self, datasets, separate_eval=True):
[Collaborator review comment] Docstring for separate_eval.

super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval

def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.

Args:
results (list[list | tuple]): Testing results of the dataset.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.

Returns:
dict[str, float]: Evaluation results of the total dataset, or of each
separate dataset if `self.separate_eval=True`.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} v.s. {len(results)}')

# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'

if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]

results_per_dataset = results[start_idx:end_idx]
print_log(
f'\nEvaluating {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)

eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})

return total_eval_results

if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have the same type')
else:
efficient_test = kwargs.get('efficient_test', False)
gt_seg_maps = sum([dataset.get_gt_seg_maps(efficient_test)
for dataset in self.datasets], [])
eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results


@DATASETS.register_module()
class RepeatDataset(object):
"""A wrapper of repeated dataset.
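A rough sketch of what the two evaluation modes return, assuming two sub-datasets and the mIoU metric (the exact metric names come from each sub-dataset's own evaluate()):

# separate_eval=True: per-dataset results keyed by sub-dataset index, e.g.
#   {'0_mIoU': 0.72, '0_mAcc': 0.81, '1_mIoU': 0.68, '1_mAcc': 0.79}
# separate_eval=False: the sub-datasets' GT maps are merged and the first
# dataset's evaluate() runs once over all results, e.g.
#   {'mIoU': 0.70, 'mAcc': 0.80}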
172 changes: 172 additions & 0 deletions tools/browse_dataset.py
@@ -0,0 +1,172 @@
import argparse
import os
import warnings
from pathlib import Path

import mmcv
from mmcv import Config
import numpy as np

from mmseg.datasets.builder import build_dataset


def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--show-origin',
default=False,
action='store_true',
help='if set, omit all augmentations in the pipeline and '
'show the original image and seg map'
)
parser.add_argument(
'--skip-type',
type=str,
nargs='+',
default=['DefaultFormatBundle', 'Normalize', 'Collect'],
help='pipeline types to skip; if `--show-origin` is set, '
'all pipeline steps except the `Load` ones will be skipped')
parser.add_argument(
'--output-dir',
default='./output',
type=str,
help='directory to save visualizations when there is no display interface')
parser.add_argument(
'--show',
default=False,
action='store_true')
parser.add_argument(
'--show-interval',
type=int,
default=999,
help='the display interval between images (ms)')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='the opacity of semantic map'
)
args = parser.parse_args()
return args


def imshow_semantic(img,
seg,
class_names,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.

Args:
img (str or Tensor): The image to be displayed.
seg (Tensor): The semantic segmentation results to draw over
`img`.
class_names (list[str]): Names of each class.
palette (list[list[int]] | np.ndarray | None): The palette of the
segmentation map. If None is given, a random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity (float): Opacity of the painted segmentation map.
Must be in the (0, 1] range. Default: 0.5.
Returns:
img (np.ndarray): The image with the drawn segmentation map. Only
returned when neither `show` nor `out_file` is set.
"""
img = mmcv.imread(img)
img = img.copy()
if palette is None:
palette = np.random.randint(
0, 255, size=(len(class_names), 3))
palette = np.array(palette)
assert palette.shape[0] == len(class_names)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]

img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False

if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
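
# Usage sketch for imshow_semantic (the image path, class names and the
# dummy prediction below are illustrative assumptions, not part of this tool):
#   seg = np.zeros((512, 1024), dtype=np.int64)
#   imshow_semantic('demo/demo.png', seg, class_names=['road', 'car'],
#                   out_file='vis/demo.png', opacity=0.5)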


def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
if show_origin is True:
# only keep pipeline of Loading data and ann
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if 'Load' in x['type']
]
else:
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if x['type'] not in skip_type
]


def retrieve_data_cfg(config_path, skip_type, show_origin=False):
cfg = Config.fromfile(config_path)
train_data_cfg = cfg.data.train
if isinstance(train_data_cfg, list):
for _data_cfg in train_data_cfg:
if 'pipeline' in _data_cfg:
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
elif 'dataset' in _data_cfg:
_retrieve_data_cfg(_data_cfg['dataset'], skip_type, show_origin)
else:
raise ValueError('each train data config must contain `pipeline` or `dataset`')
elif 'dataset' in train_data_cfg:
_retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
else:
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
return cfg


def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
dataset = build_dataset(cfg.data.train)
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
filename = os.path.join(args.output_dir,
Path(item['filename']).name
) if args.output_dir is not None else None
imshow_semantic(item['img'],
item['gt_semantic_seg'],
dataset.CLASSES,
dataset.PALETTE,
show=args.show,
wait_time=args.show_interval,
out_file=filename,
opacity=args.opacity,
)
progress_bar.update()


if __name__ == '__main__':
main()
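
As a usage sketch (the config path is illustrative), the tool would be run as `python tools/browse_dataset.py configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py --show-origin --output-dir ./vis`, which iterates over the training set and writes one visualization per image into the output directory.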

12 changes: 10 additions & 2 deletions tools/test.py
@@ -70,7 +70,7 @@ def parse_args():

def main():
args = parse_args()

assert args.out or args.eval or args.format_only or args.show \
or args.show_dir, \
('Please specify at least one operation (save/eval/format/show the '
@@ -158,7 +158,15 @@ def main():
if args.format_only:
dataset.format_results(outputs, **kwargs)
if args.eval:
dataset.evaluate(outputs, args.eval, **kwargs)
eval_kwargs = cfg.get('evaluation', {}).copy()
# hard-coded way to remove EvalHook-only args
for key in [
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
'rule'
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=args.eval, **kwargs))
dataset.evaluate(outputs, **eval_kwargs)


if __name__ == '__main__':
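For reference, a rough sketch of the kind of `evaluation` config section this block reads; the keys shown are the common ones and the values are illustrative assumptions. EvalHook-only keys are popped so that only arguments understood by `dataset.evaluate()` remain.

# Illustrative config snippet (values are assumptions, not from this PR):
evaluation = dict(
    interval=4000,      # EvalHook-only: popped before dataset.evaluate()
    save_best='mIoU',   # EvalHook-only: popped
    metric='mIoU')      # kept, then overridden by the --eval command-line value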