memory efficient test (#330)
* memory efficient test

* implement efficient test

* merge

* Add document and docstring

* fix unit test

* add memory usage report
yamengxi authored Jan 10, 2021
1 parent 8ed47ab commit ce46d70
Showing 6 changed files with 187 additions and 67 deletions.
13 changes: 13 additions & 0 deletions docs/inference.md
@@ -25,6 +25,7 @@ Optional arguments:
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `mIoU` is available for all datasets. Cityscapes can also be evaluated with the `cityscapes` metric in addition to the standard `mIoU` metric.
- `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that a GUI is available in your environment, otherwise you may encounter an error like `cannot connect to X server`.
- `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option.
- `--eval-options`: Optional parameters passed to the evaluation. When `efficient_test=True`, intermediate results are saved to local files to reduce CPU memory usage. Make sure that you have enough local storage space (more than 20 GB).

Examples:

@@ -86,3 +87,15 @@ Assume that you have already downloaded the checkpoints to the directory `checkp

You will get png files under `./pspnet_test_results` directory.
You may run `zip -r results.zip pspnet_test_results/` and submit the zip file to [evaluation server](https://www.cityscapes-dataset.com/submit/).

6. Test DeepLabV3+ on Cityscapes in a CPU-memory-efficient way (without saving the test results) and evaluate the mIoU.

```shell
python tools/test.py \
configs/deeplabv3plus/deeplabv3plus_r18-d8_512x1024_80k_cityscapes.py \
deeplabv3plus_r18-d8_512x1024_80k_cityscapes_20201226_080942-cff257fe.pth \
--eval-options efficient_test=True \
--eval mIoU
```

Inspecting the CPU memory footprint with `pmap` shows 2.25 GB of CPU memory used with `efficient_test=True` and 11.06 GB with `efficient_test=False`. This optional parameter saves a large amount of memory.
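
As a rough way to reproduce such a measurement yourself (not part of this commit), you can ask `pmap` for the total resident memory of the running test process. The `pgrep` pattern below is only an assumption about how the process was launched; adjust it to your command line.

```shell
# Find the PID of the running test process (the pattern is an assumption).
pid=$(pgrep -f "tools/test.py" | head -n 1)
# The last line of `pmap -x` reports the total mapped/resident memory in KB.
pmap -x "$pid" | tail -n 1
```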
63 changes: 53 additions & 10 deletions mmseg/apis/test.py
@@ -4,21 +4,48 @@
import tempfile

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info


def single_gpu_test(model, data_loader, show=False, out_dir=None):
def np2tmp(array, temp_file_name=None):
"""Save ndarray to local numpy file.
Args:
array (ndarray): Ndarray to save.
temp_file_name (str): Numpy file name. If temp_file_name is None, this
function will generate a file name with tempfile.NamedTemporaryFile
to save the ndarray. Default: None.
Returns:
str: The numpy file name.
"""

if temp_file_name is None:
temp_file_name = tempfile.NamedTemporaryFile(
suffix='.npy', delete=False).name
np.save(temp_file_name, array)
return temp_file_name


def single_gpu_test(model,
data_loader,
show=False,
out_dir=None,
efficient_test=False):
"""Test with single GPU.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
data_loader (utils.data.Dataloader): Pytorch data loader.
show (bool): Whether to show results during inference. Default: False.
out_dir (str, optional): If specified, the results will be dumped
into the directory to save output results.
out_dir (str, optional): If specified, the results will be dumped into
the directory to save output results.
efficient_test (bool): Whether to save the results as local numpy files to
save CPU memory during evaluation. Default: False.
Returns:
list: The prediction results.
@@ -31,10 +58,6 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, **data)
if isinstance(result, list):
results.extend(result)
else:
results.append(result)

if show or out_dir:
img_tensor = data['img'][0]
@@ -61,13 +84,26 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
show=show,
out_file=out_file)

if isinstance(result, list):
if efficient_test:
result = [np2tmp(_) for _ in result]
results.extend(result)
else:
if efficient_test:
result = np2tmp(result)
results.append(result)

batch_size = data['img'][0].size(0)
for _ in range(batch_size):
prog_bar.update()
return results


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
def multi_gpu_test(model,
data_loader,
tmpdir=None,
gpu_collect=False,
efficient_test=False):
"""Test model with multiple gpus.
This method tests model with multiple gpus and collects the results
@@ -78,10 +114,12 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
data_loader (utils.data.Dataloader): Pytorch data loader.
tmpdir (str): Path of directory to save the temporary results from
different gpus under cpu mode.
gpu_collect (bool): Option to use either gpu or cpu to collect results.
efficient_test (bool): Whether to save the results as local numpy files to
save CPU memory during evaluation. Default: False.
Returns:
list: The prediction results.
@@ -96,9 +134,14 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)

if isinstance(result, list):
if efficient_test:
result = [np2tmp(_) for _ in result]
results.extend(result)
else:
if efficient_test:
result = np2tmp(result)
results.append(result)

if rank == 0:
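For orientation, here is a minimal standalone sketch (not part of the diff) of the round trip that `efficient_test=True` enables: each per-image result is written to a temporary `.npy` file by `np2tmp` and only the file path is kept in `results`; the evaluation code later reloads it with `np.load`. The array shape and number of classes below are arbitrary assumptions for illustration.

```python
import tempfile

import numpy as np


def np2tmp(array, temp_file_name=None):
    """Save an ndarray to a temporary .npy file and return the file path."""
    if temp_file_name is None:
        temp_file_name = tempfile.NamedTemporaryFile(
            suffix='.npy', delete=False).name
    np.save(temp_file_name, array)
    return temp_file_name


# Simulate one per-image segmentation result (shape/classes chosen arbitrarily).
pred = np.random.randint(0, 19, size=(512, 1024), dtype=np.uint8)
path = np2tmp(pred)       # only this short string is kept in memory
restored = np.load(path)  # evaluation reloads the full array on demand
assert np.array_equal(pred, restored)
```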
123 changes: 88 additions & 35 deletions mmseg/core/evaluation/metrics.py
@@ -1,24 +1,49 @@
import mmcv
import numpy as np


def intersect_and_union(pred_label, label, num_classes, ignore_index):
def intersect_and_union(pred_label,
label,
num_classes,
ignore_index,
label_map=dict(),
reduce_zero_label=False):
"""Calculate intersection and Union.
Args:
pred_label (ndarray): Prediction segmentation map
label (ndarray): Ground truth segmentation map
num_classes (int): Number of categories
pred_label (ndarray): Prediction segmentation map.
label (ndarray): Ground truth segmentation map.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. The parameter will
work only when label is str. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. The parameter will
work only when label is str. Default: False.
Returns:
ndarray: The intersection of prediction and ground truth histogram
on all classes
on all classes.
ndarray: The union of prediction and ground truth histogram on all
classes
classes.
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""

if isinstance(pred_label, str):
pred_label = np.load(pred_label)

if isinstance(label, str):
label = mmcv.imread(label, flag='unchanged', backend='pillow')
# modify if custom classes
if label_map is not None:
for old_id, new_id in label_map.items():
label[label == old_id] = new_id
if reduce_zero_label:
# avoid using underflow conversion
label[label == 0] = 255
label = label - 1
label[label == 254] = 255

mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]
@@ -34,20 +59,27 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
def total_intersect_and_union(results,
gt_seg_maps,
num_classes,
ignore_index,
label_map=dict(),
reduce_zero_label=False):
"""Calculate Total Intersection and Union.
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
results (list[ndarray]): List of prediction segmentation maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
ndarray: The intersection of prediction and ground truth histogram
on all classes
on all classes.
ndarray: The union of prediction and ground truth histogram on all
classes
classes.
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""
@@ -61,7 +93,7 @@ def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
for i in range(num_imgs):
area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(results[i], gt_seg_maps[i], num_classes,
ignore_index=ignore_index)
ignore_index, label_map, reduce_zero_label)
total_area_intersect += area_intersect
total_area_union += area_union
total_area_pred_label += area_pred_label
@@ -70,21 +102,29 @@ def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
total_area_pred_label, total_area_label


def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
def mean_iou(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False):
"""Calculate Mean Intersection and Union (mIoU)
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
results (list[ndarray]): List of prediction segmentation maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, )
ndarray: Per category IoU, shape (num_classes, )
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category IoU, shape (num_classes, ).
"""

all_acc, acc, iou = eval_metrics(
@@ -93,29 +133,35 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mIoU'],
nan_to_num=nan_to_num)
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return all_acc, acc, iou


def mean_dice(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None):
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False):
"""Calculate Mean Dice (mDice)
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
results (list[ndarray]): List of prediction segmentation maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, )
ndarray: Per category dice, shape (num_classes, )
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category dice, shape (num_classes, ).
"""

all_acc, acc, dice = eval_metrics(
@@ -124,7 +170,9 @@ def mean_dice(results,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mDice'],
nan_to_num=nan_to_num)
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return all_acc, acc, dice


@@ -133,20 +181,24 @@ def eval_metrics(results,
num_classes,
ignore_index,
metrics=['mIoU'],
nan_to_num=None):
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False):
"""Calculate evaluation metrics
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
results (list[ndarray]): List of prediction segmentation maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, )
ndarray: Per category evalution metrics, shape (num_classes, )
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
"""

if isinstance(metrics, str):
@@ -156,8 +208,9 @@ def eval_metrics(results,
raise KeyError('metrics {} is not supported'.format(metrics))
total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = total_intersect_and_union(results, gt_seg_maps,
num_classes,
ignore_index=ignore_index)
num_classes, ignore_index,
label_map,
reduce_zero_label)
all_acc = total_area_intersect.sum() / total_area_label.sum()
acc = total_area_intersect / total_area_label
ret_metrics = [all_acc, acc]
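To make the new `label_map` / `reduce_zero_label` handling concrete, here is a small standalone sketch (not part of the diff) of what `intersect_and_union` now does to a ground-truth map loaded from disk. The class ids and the mapping are made-up assumptions, not values from the repository.

```python
import numpy as np

# A toy ground-truth map; ids are invented: 0 = background, 1-3 = real classes.
label = np.array([[0, 1, 2],
                  [3, 1, 0]], dtype=np.uint8)

# label_map remaps old ids to new ids (here: merge class 3 into class 2),
# mirroring the loop added to intersect_and_union.
label_map = {3: 2}
for old_id, new_id in label_map.items():
    label[label == old_id] = new_id

# reduce_zero_label: send the zero label to the ignore index (255) explicitly,
# shift the remaining ids down by one, and repair the shifted 255s, rather
# than relying on unsigned wrap-around (the "underflow conversion" the source
# comment warns about).
label[label == 0] = 255
label = label - 1
label[label == 254] = 255

print(label)
# [[255   0   1]
#  [  1   0 255]]
```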