diff --git a/docs/inference.md b/docs/inference.md index a19c2258da..d7bc21b65a 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -25,6 +25,7 @@ Optional arguments: - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `mIoU` is available for all dataset. Cityscapes could be evaluated by `cityscapes` as well as standard `mIoU` metrics. - `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that GUI is available in your environment, otherwise you may encounter the error like `cannot connect to X server`. - `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option. +- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB). Examples: @@ -86,3 +87,15 @@ Assume that you have already downloaded the checkpoints to the directory `checkp You will get png files under `./pspnet_test_results` directory. You may run `zip -r results.zip pspnet_test_results/` and submit the zip file to [evaluation server](https://www.cityscapes-dataset.com/submit/). + +6. CPU memory efficient test DeeplabV3+ on Cityscapes (without saving the test results) and evaluate the mIoU. + + ```shell + python tools/test.py \ + configs/deeplabv3plus/deeplabv3plus_r18-d8_512x1024_80k_cityscapes.py \ + deeplabv3plus_r18-d8_512x1024_80k_cityscapes_20201226_080942-cff257fe.pth \ + --eval-options efficient_test=True \ + --eval mIoU + ``` + + Using ```pmap``` to view CPU memory footprint, it used 2.25GB CPU memory with ```efficient_test=True``` and 11.06GB CPU memory with ```efficient_test=False``` . This optional parameter can save a lot of memory. diff --git a/mmseg/apis/test.py b/mmseg/apis/test.py index 7f98abf297..148df7680e 100644 --- a/mmseg/apis/test.py +++ b/mmseg/apis/test.py @@ -4,21 +4,48 @@ import tempfile import mmcv +import numpy as np import torch import torch.distributed as dist from mmcv.image import tensor2imgs from mmcv.runner import get_dist_info -def single_gpu_test(model, data_loader, show=False, out_dir=None): +def np2tmp(array, temp_file_name=None): + """Save ndarray to local numpy file. + + Args: + array (ndarray): Ndarray to save. + temp_file_name (str): Numpy file name. If 'temp_file_name=None', this + function will generate a file name with tempfile.NamedTemporaryFile + to save ndarray. Default: None. + + Returns: + str: The numpy file name. + """ + + if temp_file_name is None: + temp_file_name = tempfile.NamedTemporaryFile( + suffix='.npy', delete=False).name + np.save(temp_file_name, array) + return temp_file_name + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + efficient_test=False): """Test with single GPU. Args: model (nn.Module): Model to be tested. - data_loader (nn.Dataloader): Pytorch data loader. + data_loader (utils.data.Dataloader): Pytorch data loader. show (bool): Whether show results during infernece. Default: False. - out_dir (str, optional): If specified, the results will be dumped - into the directory to save output results. + out_dir (str, optional): If specified, the results will be dumped into + the directory to save output results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. Returns: list: The prediction results. @@ -31,10 +58,6 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None): for i, data in enumerate(data_loader): with torch.no_grad(): result = model(return_loss=False, **data) - if isinstance(result, list): - results.extend(result) - else: - results.append(result) if show or out_dir: img_tensor = data['img'][0] @@ -61,13 +84,26 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None): show=show, out_file=out_file) + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] + results.extend(result) + else: + if efficient_test: + result = np2tmp(result) + results.append(result) + batch_size = data['img'][0].size(0) for _ in range(batch_size): prog_bar.update() return results -def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): +def multi_gpu_test(model, + data_loader, + tmpdir=None, + gpu_collect=False, + efficient_test=False): """Test model with multiple gpus. This method tests model with multiple gpus and collects the results @@ -78,10 +114,12 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): Args: model (nn.Module): Model to be tested. - data_loader (nn.Dataloader): Pytorch data loader. + data_loader (utils.data.Dataloader): Pytorch data loader. tmpdir (str): Path of directory to save the temporary results from different gpus under cpu mode. gpu_collect (bool): Option to use either gpu or cpu to collect results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. Returns: list: The prediction results. @@ -96,9 +134,14 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): for i, data in enumerate(data_loader): with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] results.extend(result) else: + if efficient_test: + result = np2tmp(result) results.append(result) if rank == 0: diff --git a/mmseg/core/evaluation/metrics.py b/mmseg/core/evaluation/metrics.py index 45c62b1641..86475a8983 100644 --- a/mmseg/core/evaluation/metrics.py +++ b/mmseg/core/evaluation/metrics.py @@ -1,24 +1,49 @@ +import mmcv import numpy as np -def intersect_and_union(pred_label, label, num_classes, ignore_index): +def intersect_and_union(pred_label, + label, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): """Calculate intersection and Union. Args: - pred_label (ndarray): Prediction segmentation map - label (ndarray): Ground truth segmentation map - num_classes (int): Number of categories + pred_label (ndarray): Prediction segmentation map. + label (ndarray): Ground truth segmentation map. + num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. The parameter will + work only when label is str. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. The parameter will + work only when label is str. Default: False. Returns: ndarray: The intersection of prediction and ground truth histogram - on all classes + on all classes. ndarray: The union of prediction and ground truth histogram on all - classes + classes. ndarray: The prediction histogram on all classes. ndarray: The ground truth histogram on all classes. """ + if isinstance(pred_label, str): + pred_label = np.load(pred_label) + + if isinstance(label, str): + label = mmcv.imread(label, flag='unchanged', backend='pillow') + # modify if custom classes + if label_map is not None: + for old_id, new_id in label_map.items(): + label[label == old_id] = new_id + if reduce_zero_label: + # avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + mask = (label != ignore_index) pred_label = pred_label[mask] label = label[mask] @@ -34,20 +59,27 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index): return area_intersect, area_union, area_pred_label, area_label -def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index): +def total_intersect_and_union(results, + gt_seg_maps, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): """Calculate Total Intersection and Union. Args: - results (list[ndarray]): List of prediction segmentation maps - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps - num_classes (int): Number of categories + results (list[ndarray]): List of prediction segmentation maps. + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. Returns: ndarray: The intersection of prediction and ground truth histogram - on all classes + on all classes. ndarray: The union of prediction and ground truth histogram on all - classes + classes. ndarray: The prediction histogram on all classes. ndarray: The ground truth histogram on all classes. """ @@ -61,7 +93,7 @@ def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index): for i in range(num_imgs): area_intersect, area_union, area_pred_label, area_label = \ intersect_and_union(results[i], gt_seg_maps[i], num_classes, - ignore_index=ignore_index) + ignore_index, label_map, reduce_zero_label) total_area_intersect += area_intersect total_area_union += area_union total_area_pred_label += area_pred_label @@ -70,21 +102,29 @@ def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index): total_area_pred_label, total_area_label -def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None): +def mean_iou(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): """Calculate Mean Intersection and Union (mIoU) Args: - results (list[ndarray]): List of prediction segmentation maps - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps - num_classes (int): Number of categories + results (list[ndarray]): List of prediction segmentation maps. + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. nan_to_num (int, optional): If specified, NaN values will be replaced by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. Returns: float: Overall accuracy on all images. - ndarray: Per category accuracy, shape (num_classes, ) - ndarray: Per category IoU, shape (num_classes, ) + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category IoU, shape (num_classes, ). """ all_acc, acc, iou = eval_metrics( @@ -93,7 +133,9 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None): num_classes=num_classes, ignore_index=ignore_index, metrics=['mIoU'], - nan_to_num=nan_to_num) + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) return all_acc, acc, iou @@ -101,21 +143,25 @@ def mean_dice(results, gt_seg_maps, num_classes, ignore_index, - nan_to_num=None): + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): """Calculate Mean Dice (mDice) Args: - results (list[ndarray]): List of prediction segmentation maps - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps - num_classes (int): Number of categories + results (list[ndarray]): List of prediction segmentation maps. + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. nan_to_num (int, optional): If specified, NaN values will be replaced by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. Returns: float: Overall accuracy on all images. - ndarray: Per category accuracy, shape (num_classes, ) - ndarray: Per category dice, shape (num_classes, ) + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category dice, shape (num_classes, ). """ all_acc, acc, dice = eval_metrics( @@ -124,7 +170,9 @@ def mean_dice(results, num_classes=num_classes, ignore_index=ignore_index, metrics=['mDice'], - nan_to_num=nan_to_num) + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) return all_acc, acc, dice @@ -133,20 +181,24 @@ def eval_metrics(results, num_classes, ignore_index, metrics=['mIoU'], - nan_to_num=None): + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): """Calculate evaluation metrics Args: - results (list[ndarray]): List of prediction segmentation maps - gt_seg_maps (list[ndarray]): list of ground truth segmentation maps - num_classes (int): Number of categories + results (list[ndarray]): List of prediction segmentation maps. + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. nan_to_num (int, optional): If specified, NaN values will be replaced by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. Returns: float: Overall accuracy on all images. - ndarray: Per category accuracy, shape (num_classes, ) - ndarray: Per category evalution metrics, shape (num_classes, ) + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evalution metrics, shape (num_classes, ). """ if isinstance(metrics, str): @@ -156,8 +208,9 @@ def eval_metrics(results, raise KeyError('metrics {} is not supported'.format(metrics)) total_area_intersect, total_area_union, total_area_pred_label, \ total_area_label = total_intersect_and_union(results, gt_seg_maps, - num_classes, - ignore_index=ignore_index) + num_classes, ignore_index, + label_map, + reduce_zero_label) all_acc = total_area_intersect.sum() / total_area_label.sum() acc = total_area_intersect / total_area_label ret_metrics = [all_acc, acc] diff --git a/mmseg/datasets/cityscapes.py b/mmseg/datasets/cityscapes.py index e26cd00b09..fa9958ac14 100644 --- a/mmseg/datasets/cityscapes.py +++ b/mmseg/datasets/cityscapes.py @@ -38,6 +38,8 @@ def __init__(self, **kwargs): @staticmethod def _convert_to_label_id(result): """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) import cityscapesscripts.helpers.labels as CSLabels result_copy = result.copy() for trainId, label in CSLabels.trainId2label.items(): @@ -123,7 +125,8 @@ def evaluate(self, results, metric='mIoU', logger=None, - imgfile_prefix=None): + imgfile_prefix=None, + efficient_test=False): """Evaluation in Cityscapes/default protocol. Args: @@ -154,7 +157,7 @@ def evaluate(self, if len(metrics) > 0: eval_results.update( super(CityscapesDataset, - self).evaluate(results, metrics, logger)) + self).evaluate(results, metrics, logger, efficient_test)) return eval_results diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 4e7e30e91c..dc923fb42d 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -1,3 +1,4 @@ +import os import os.path as osp from functools import reduce @@ -226,25 +227,17 @@ def format_results(self, results, **kwargs): """Place holder to format result to dataset specific output.""" pass - def get_gt_seg_maps(self): + def get_gt_seg_maps(self, efficient_test=False): """Get ground truth segmentation maps for evaluation.""" gt_seg_maps = [] for img_info in self.img_infos: seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) - gt_seg_map = mmcv.imread( - seg_map, flag='unchanged', backend='pillow') - # modify if custom classes - if self.label_map is not None: - for old_id, new_id in self.label_map.items(): - gt_seg_map[gt_seg_map == old_id] = new_id - if self.reduce_zero_label: - # avoid using underflow conversion - gt_seg_map[gt_seg_map == 0] = 255 - gt_seg_map = gt_seg_map - 1 - gt_seg_map[gt_seg_map == 254] = 255 - + if efficient_test: + gt_seg_map = seg_map + else: + gt_seg_map = mmcv.imread( + seg_map, flag='unchanged', backend='pillow') gt_seg_maps.append(gt_seg_map) - return gt_seg_maps def get_classes_and_palette(self, classes=None, palette=None): @@ -310,7 +303,12 @@ def get_palette_for_custom_classes(self, class_names, palette=None): return palette - def evaluate(self, results, metric='mIoU', logger=None, **kwargs): + def evaluate(self, + results, + metric='mIoU', + logger=None, + efficient_test=False, + **kwargs): """Evaluate the dataset. Args: @@ -330,7 +328,7 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): if not set(metric).issubset(set(allowed_metrics)): raise KeyError('metric {} is not supported'.format(metric)) eval_results = {} - gt_seg_maps = self.get_gt_seg_maps() + gt_seg_maps = self.get_gt_seg_maps(efficient_test) if self.CLASSES is None: num_classes = len( reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) @@ -340,8 +338,10 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): results, gt_seg_maps, num_classes, - ignore_index=self.ignore_index, - metrics=metric) + self.ignore_index, + metric, + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label) class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']] if self.CLASSES is None: class_names = tuple(range(num_classes)) @@ -374,4 +374,7 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): for i in range(1, len(summary_table_data[0])): eval_results[summary_table_data[0] [i]] = summary_table_data[1][i] / 100.0 + if mmcv.is_list_of(results, str): + for file_name in results: + os.remove(file_name) return eval_results diff --git a/tools/test.py b/tools/test.py index 3910f1f0bb..e47fcca68f 100644 --- a/tools/test.py +++ b/tools/test.py @@ -115,16 +115,21 @@ def main(): model.CLASSES = checkpoint['meta']['CLASSES'] model.PALETTE = checkpoint['meta']['PALETTE'] + efficient_test = False + if args.eval_options is not None: + efficient_test = args.eval_options.get('efficient_test', False) + if not distributed: model = MMDataParallel(model, device_ids=[0]) - outputs = single_gpu_test(model, data_loader, args.show, args.show_dir) + outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, + efficient_test) else: model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False) outputs = multi_gpu_test(model, data_loader, args.tmpdir, - args.gpu_collect) + args.gpu_collect, efficient_test) rank, _ = get_dist_info() if rank == 0: