Support progressive test with lower memory cost.
sennnnn committed Jul 16, 2021
1 parent 5097d55 commit f3aaecc
Showing 6 changed files with 335 additions and 12 deletions.
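
In short, this commit swaps "run inference, keep every prediction, then evaluate" for "reduce each prediction to per-class pixel counts on the fly". A rough usage sketch of the new path (illustrative only; the model and dataloader are assumed to be built elsewhere and are not part of this commit):

from mmseg.apis import progressive_single_gpu_test


def run_progressive_eval(model, data_loader):
    # `model` is assumed to be an MMDataParallel-wrapped segmentor and
    # `data_loader` a test dataloader built elsewhere; neither is part of
    # this diff.
    counts = progressive_single_gpu_test(model, data_loader, show=False)
    # counts = (total_area_intersect, total_area_union,
    #           total_area_pred_label, total_area_label), each (num_classes, )
    return data_loader.dataset.progressive_evaluate(counts, metric='mIoU')

Because only these four (num_classes, ) tensors are accumulated, memory use stays constant with respect to the number of test images.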
6 changes: 4 additions & 2 deletions mmseg/apis/__init__.py
@@ -1,9 +1,11 @@
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test
from .test import (multi_gpu_test, progressive_multi_gpu_test,
progressive_single_gpu_test, single_gpu_test)
from .train import get_root_logger, set_random_seed, train_segmentor

__all__ = [
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
'show_result_pyplot'
'show_result_pyplot', 'progressive_single_gpu_test',
'progressive_multi_gpu_test'
]
152 changes: 150 additions & 2 deletions mmseg/apis/test.py
@@ -8,6 +8,8 @@
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info

from mmseg.core.evaluation.metrics import intersect_and_union


def np2tmp(array, temp_file_name=None, tmpdir=None):
"""Save ndarray to local numpy file.
@@ -44,8 +46,8 @@ def single_gpu_test(model,
show (bool): Whether show results during inference. Default: False.
out_dir (str, optional): If specified, the results will be dumped into
the directory to save output results.
efficient_test (bool): Whether save the results as local numpy files to
save CPU memory during evaluation. Default: False.
efficient_test (bool, optional): Whether to save the results as local
numpy files to save CPU memory during evaluation. Default: False.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
@@ -163,3 +165,149 @@ def multi_gpu_test(model,
else:
results = collect_results_cpu(results, len(dataset), tmpdir)
return results


def progressive_single_gpu_test(model,
data_loader,
show=False,
out_dir=None,
opacity=0.5):
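    """Test with a single GPU, progressively accumulating pixel counts.

    Each prediction is reduced to per-class intersection/union/prediction/
    label pixel counts right after inference instead of being stored, so
    memory use no longer grows with the dataset size. The four accumulated
    (num_classes, ) tensors are returned for later metric calculation.
    """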
model.eval()
dataset = data_loader.dataset
num_classes = len(dataset.CLASSES)
prog_bar = mmcv.ProgressBar(len(dataset))

total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)

cur = 0
for _, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, **data)

if show or out_dir:
img_tensor = data['img'][0]
img_metas = data['img_metas'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
assert len(imgs) == len(img_metas)

for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]

ori_h, ori_w = img_meta['ori_shape'][:-1]
img_show = mmcv.imresize(img_show, (ori_w, ori_h))

if out_dir:
out_file = osp.join(out_dir, img_meta['ori_filename'])
else:
out_file = None

model.module.show_result(
img_show,
result,
palette=dataset.PALETTE,
show=show,
out_file=out_file,
opacity=opacity)

for i in range(len(result)):
gt_semantic_map = dataset.get_gt_seg_map(cur + i)

area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(
result[i], gt_semantic_map, num_classes,
dataset.ignore_index, dataset.label_map,
dataset.reduce_zero_label)

total_area_intersect += area_intersect
total_area_union += area_union
total_area_pred_label += area_pred_label
total_area_label += area_label

print(total_area_intersect / total_area_union)

prog_bar.update()

cur += len(result)

return total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label


# TODO: Support distributed test API
def progressive_multi_gpu_test(model,
data_loader,
tmpdir=None,
gpu_collect=False):
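    """Test with multiple GPUs, progressively accumulating pixel counts.

    Each rank reduces its own predictions to per-class pixel counts on the
    fly; the four (num_classes, ) count tensors are then passed to
    collect_count_results_gpu or collect_count_results_cpu (still TODO
    stubs in this commit) for cross-rank collection.
    """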

model.eval()
dataset = data_loader.dataset
num_classes = len(dataset.CLASSES)
rank, world_size = get_dist_info()
if rank == 0:
prog_bar = mmcv.ProgressBar(len(dataset))

total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)

cur = 0
for _, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)

for i in range(len(result)):
gt_semantic_map = dataset.get_gt_seg_map(cur + i * world_size)

area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(
result[i], gt_semantic_map, num_classes,
dataset.ignore_index, dataset.label_map,
dataset.reduce_zero_label)

total_area_intersect += area_intersect
total_area_union += area_union
total_area_pred_label += area_pred_label
total_area_label += area_label

if rank == 0:
for _ in range(len(result) * world_size):
prog_bar.update()

cur += len(result) * world_size

pixel_count_matrix = [
total_area_intersect, total_area_union, total_area_pred_label,
total_area_label
]
# collect results from all ranks
if gpu_collect:
results = collect_count_results_gpu(pixel_count_matrix, 4 * world_size)
else:
results = collect_count_results_cpu(pixel_count_matrix, 4 * world_size,
tmpdir)
return results


def collect_count_results_gpu(result_part, size):
"""Collect pixel count matrix result under gpu mode.
On gpu mode, this function will encode results to gpu tensors and use gpu
communication for results collection.
Args:
result_part (list[Tensor]): four type of pixel count matrix --
{area_intersect, area_union, area_pred_label, area_label}, These
four tensor shape of (num_classes, ).
size (int): Size of the results, commonly equal to length of
the results.
"""
pass


def collect_count_results_cpu(result_part, size, tmpdir=None):
pass
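
Both collectors above are left as TODO stubs in this commit. Since the payload is only four fixed-size (num_classes, ) tensors, one possible direction -- a sketch assuming torch.distributed is already initialised, not the commit's actual implementation -- is a plain all_reduce:

import torch.distributed as dist


def _allreduce_count_results(result_part):
    # Illustrative sketch only: sum the per-rank pixel count tensors
    # (area_intersect, area_union, area_pred_label, area_label) over all
    # ranks so every process ends up with the global counts.
    reduced = []
    for tensor in result_part:
        buf = tensor.clone().cuda()  # NCCL all_reduce needs CUDA tensors
        dist.all_reduce(buf, op=dist.ReduceOp.SUM)
        reduced.append(buf.cpu())
    return reduced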
5 changes: 3 additions & 2 deletions mmseg/core/evaluation/__init__.py
@@ -1,8 +1,9 @@
from .class_names import get_classes, get_palette
from .eval_hooks import DistEvalHook, EvalHook
from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
from .metrics import (calculate_metrics, eval_metrics, mean_dice, mean_fscore,
mean_iou)

__all__ = [
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
'eval_metrics', 'get_classes', 'get_palette'
'eval_metrics', 'get_classes', 'get_palette', 'calculate_metrics'
]
66 changes: 66 additions & 0 deletions mmseg/core/evaluation/metrics.py
@@ -324,3 +324,69 @@ def eval_metrics(results,
for metric, metric_value in ret_metrics.items()
})
return ret_metrics


def calculate_metrics(total_area_intersect,
total_area_union,
total_area_pred_label,
total_area_label,
metrics=['mIoU'],
nan_to_num=None,
beta=1):
"""Calculate evaluation metrics
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Wether ignore zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
"""
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics))

all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
acc = total_area_intersect / total_area_label
ret_metrics['IoU'] = iou
ret_metrics['Acc'] = acc
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
acc = total_area_intersect / total_area_label
ret_metrics['Dice'] = dice
ret_metrics['Acc'] = acc
elif metric == 'mFscore':
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor(
[f_score(x[0], x[1], beta) for x in zip(precision, recall)])
ret_metrics['Fscore'] = f_value
ret_metrics['Precision'] = precision
ret_metrics['Recall'] = recall

ret_metrics = {
metric: value.numpy()
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = OrderedDict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
})
return ret_metrics
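
As a quick illustration of how the accumulated counts feed calculate_metrics, a toy two-class example (hypothetical numbers, not taken from this commit):

import torch

from mmseg.core.evaluation import calculate_metrics

# Hypothetical accumulated pixel counts for two classes.
total_area_intersect = torch.tensor([80., 30.], dtype=torch.float64)
total_area_union = torch.tensor([100., 60.], dtype=torch.float64)
total_area_pred_label = torch.tensor([90., 40.], dtype=torch.float64)
total_area_label = torch.tensor([90., 50.], dtype=torch.float64)

ret = calculate_metrics(total_area_intersect, total_area_union,
                        total_area_pred_label, total_area_label,
                        metrics=['mIoU'])
print(ret['IoU'])   # [0.8 0.5]  -> intersect / union per class
print(ret['aAcc'])  # ~0.786     -> 110 / 140 total pixels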
84 changes: 84 additions & 0 deletions mmseg/datasets/custom.py
@@ -10,6 +10,7 @@
from torch.utils.data import Dataset

from mmseg.core import eval_metrics
from mmseg.core.evaluation.metrics import calculate_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose
@@ -240,6 +241,13 @@ def get_gt_seg_maps(self, efficient_test=False):
gt_seg_maps.append(gt_seg_map)
return gt_seg_maps

def get_gt_seg_map(self, idx):
"""Get ground truth segmentation maps for evaluation."""
seg_map = osp.join(self.ann_dir, self.img_infos[idx]['ann']['seg_map'])
gt_seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow')

return gt_seg_map

def get_classes_and_palette(self, classes=None, palette=None):
"""Get class names of current dataset.
@@ -303,6 +311,82 @@ def get_palette_for_custom_classes(self, class_names, palette=None):

return palette

def progressive_evaluate(self,
results,
metric='mIoU',
logger=None,
**kwargs):
if isinstance(metric, str):
metric = [metric]
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metric).issubset(set(allowed_metrics)):
raise KeyError('metric {} is not supported'.format(metric))

eval_results = {}

total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = results

ret_metrics = calculate_metrics(total_area_intersect, total_area_union,
total_area_pred_label,
total_area_label, metric)

# dataset.CLASSES is required by progressive_single_gpu_test and
# progressive_multi_gpu_test, so it has to be kept on the dataset.
class_names = self.CLASSES

# summary table
ret_metrics_summary = OrderedDict({
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})

# each class table
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)

# for logger
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)

summary_table_data = PrettyTable()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
summary_table_data.add_column(key, [val])
else:
summary_table_data.add_column('m' + key, [val])

print_log('per class results:', logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
print_log('Summary:', logger)
print_log('\n' + summary_table_data.get_string(), logger=logger)

# each metric dict
for key, value in ret_metrics_summary.items():
if key == 'aAcc':
eval_results[key] = value / 100.0
else:
eval_results['m' + key] = value / 100.0

ret_metrics_class.pop('Class', None)
for key, value in ret_metrics_class.items():
eval_results.update({
key + '.' + str(name): value[idx] / 100.0
for idx, name in enumerate(class_names)
})

if mmcv.is_list_of(results, str):
for file_name in results:
os.remove(file_name)
return eval_results

def evaluate(self,
results,
metric='mIoU',