pytorch metrics implementation #430

Merged 7 commits on Mar 29, 2021
96 changes: 55 additions & 41 deletions mmseg/core/evaluation/metrics.py
@@ -1,5 +1,6 @@
 import mmcv
 import numpy as np
+import torch


 def intersect_and_union(pred_label,
@@ -11,8 +12,10 @@ def intersect_and_union(pred_label,
     """Calculate intersection and Union.

     Args:
-        pred_label (ndarray): Prediction segmentation map.
-        label (ndarray): Ground truth segmentation map.
+        pred_label (ndarray | str): Prediction segmentation map
+            or prediction result filename.
+        label (ndarray | str): Ground truth segmentation map
+            or label filename.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         label_map (dict): Mapping old labels to new labels. The parameter will
@@ -21,25 +24,29 @@ def intersect_and_union(pred_label,
             work only when label is str. Default: False.

     Returns:
-        ndarray: The intersection of prediction and ground truth histogram
-            on all classes.
-        ndarray: The union of prediction and ground truth histogram on all
-            classes.
-        ndarray: The prediction histogram on all classes.
-        ndarray: The ground truth histogram on all classes.
+        torch.Tensor: The intersection of prediction and ground truth
+            histogram on all classes.
+        torch.Tensor: The union of prediction and ground truth histogram on
+            all classes.
+        torch.Tensor: The prediction histogram on all classes.
+        torch.Tensor: The ground truth histogram on all classes.
     """
-
     if isinstance(pred_label, str):
-        pred_label = np.load(pred_label)
+        pred_label = torch.from_numpy(np.load(pred_label))
+    else:
+        pred_label = torch.from_numpy(pred_label)

     if isinstance(label, str):
-        label = mmcv.imread(label, flag='unchanged', backend='pillow')
-    # modify if custom classes
+        label = torch.from_numpy(
+            mmcv.imread(label, flag='unchanged', backend='pillow'))
+    else:
+        label = torch.from_numpy(label)
+
     if label_map is not None:
         for old_id, new_id in label_map.items():
             label[label == old_id] = new_id
     if reduce_zero_label:
         # avoid using underflow conversion
         label[label == 0] = 255
         label = label - 1
         label[label == 254] = 255
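Editor's note: the `reduce_zero_label` branch above is easy to misread, so here is a small illustration (not part of the diff) of what the three remapping lines do to a uint8 label map, which is what `mmcv.imread` returns for PNG annotations:

```python
import numpy as np

# reduce_zero_label shifts every class id down by one and sends the
# original 0 ("background") to the ignore value 255.
label = np.array([0, 1, 2, 255], dtype=np.uint8)
label[label == 0] = 255    # remap 0 first instead of relying on uint8 underflow
label = label - 1          # 255 wraps to 254 here; real classes shift down
label[label == 254] = 255  # everything that was 255 (or 0) ends up at 255
print(label)               # [255, 0, 1, 255]
```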
@@ -49,13 +56,13 @@ def intersect_and_union(pred_label,
     label = label[mask]

     intersect = pred_label[pred_label == label]
-    area_intersect, _ = np.histogram(
-        intersect, bins=np.arange(num_classes + 1))
-    area_pred_label, _ = np.histogram(
-        pred_label, bins=np.arange(num_classes + 1))
-    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
+    area_intersect = torch.histc(
+        intersect.float(), bins=num_classes, min=0, max=num_classes)
+    area_pred_label = torch.histc(
+        pred_label.float(), bins=num_classes, min=0, max=num_classes)
+    area_label = torch.histc(
+        label.float(), bins=num_classes, min=0, max=num_classes)
     area_union = area_pred_label + area_label - area_intersect

     return area_intersect, area_union, area_pred_label, area_label
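Editor's note: a quick sanity check (not in the PR) that `torch.histc` with `bins=num_classes, min=0, max=num_classes` reproduces `np.histogram(..., bins=np.arange(num_classes + 1))` for integer class ids:

```python
import numpy as np
import torch

num_classes = 19
x = np.random.randint(0, num_classes, size=10000)  # fake integer class ids

hist_np, _ = np.histogram(x, bins=np.arange(num_classes + 1))
hist_pt = torch.histc(
    torch.from_numpy(x).float(), bins=num_classes, min=0, max=num_classes)

# Both count occurrences per class; histc just returns float instead of int.
assert np.array_equal(hist_np, hist_pt.numpy().astype(np.int64))
```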


@@ -68,8 +75,10 @@ def total_intersect_and_union(results,
     """Calculate Total Intersection and Union.

     Args:
-        results (list[ndarray]): List of prediction segmentation maps.
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
+            segmentation maps or list of label filenames.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         label_map (dict): Mapping old labels to new labels. Default: dict().
@@ -83,23 +92,23 @@ def total_intersect_and_union(results,
         ndarray: The prediction histogram on all classes.
         ndarray: The ground truth histogram on all classes.
     """
-
     num_imgs = len(results)
     assert len(gt_seg_maps) == num_imgs
-    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
-    total_area_union = np.zeros((num_classes, ), dtype=np.float)
-    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
-    total_area_label = np.zeros((num_classes, ), dtype=np.float)
+    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
+    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
+    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
+    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
     for i in range(num_imgs):
         area_intersect, area_union, area_pred_label, area_label = \
-            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
-                                ignore_index, label_map, reduce_zero_label)
+            intersect_and_union(
+                results[i], gt_seg_maps[i], num_classes, ignore_index,
+                label_map, reduce_zero_label)
         total_area_intersect += area_intersect
         total_area_union += area_union
         total_area_pred_label += area_pred_label
         total_area_label += area_label
-    return total_area_intersect, total_area_union, \
-        total_area_pred_label, total_area_label
+    return total_area_intersect, total_area_union, total_area_pred_label, \
+        total_area_label
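Editor's note: a minimal smoke test of the aggregated variant, assuming mmseg is installed and importable as below (a sketch, not part of the PR). After this PR all four returned histograms are float64 `torch.Tensor`s:

```python
import numpy as np
from mmseg.core.evaluation.metrics import total_intersect_and_union

# Four fake 19-class images.
preds = [np.random.randint(0, 19, (64, 64)) for _ in range(4)]
gts = [np.random.randint(0, 19, (64, 64)) for _ in range(4)]
inter, union, pred_area, label_area = total_intersect_and_union(
    preds, gts, num_classes=19, ignore_index=255)
iou = inter / union  # per-class IoU over all four images (NaN for absent classes)
```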


def mean_iou(results,
@@ -112,8 +121,10 @@ def mean_iou(results,
     """Calculate Mean Intersection and Union (mIoU)

     Args:
-        results (list[ndarray]): List of prediction segmentation maps.
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
+            segmentation maps or list of label filenames.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -126,7 +137,6 @@ def mean_iou(results,
         ndarray: Per category accuracy, shape (num_classes, ).
         ndarray: Per category IoU, shape (num_classes, ).
     """
-
     all_acc, acc, iou = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
Expand All @@ -149,8 +159,10 @@ def mean_dice(results,
"""Calculate Mean Dice (mDice)

Args:
results (list[ndarray]): List of prediction segmentation maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -186,8 +198,10 @@ def eval_metrics(results,
                  reduce_zero_label=False):
     """Calculate evaluation metrics
     Args:
-        results (list[ndarray]): List of prediction segmentation maps.
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
+            segmentation maps or list of label filenames.
         num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
@@ -200,17 +214,16 @@ def eval_metrics(results,
         ndarray: Per category accuracy, shape (num_classes, ).
         ndarray: Per category evaluation metrics, shape (num_classes, ).
     """
-
     if isinstance(metrics, str):
         metrics = [metrics]
     allowed_metrics = ['mIoU', 'mDice']
     if not set(metrics).issubset(set(allowed_metrics)):
         raise KeyError('metrics {} is not supported'.format(metrics))

     total_area_intersect, total_area_union, total_area_pred_label, \
-        total_area_label = total_intersect_and_union(results, gt_seg_maps,
-                                                     num_classes, ignore_index,
-                                                     label_map,
-                                                     reduce_zero_label)
+        total_area_label = total_intersect_and_union(
+            results, gt_seg_maps, num_classes, ignore_index, label_map,
+            reduce_zero_label)
     all_acc = total_area_intersect.sum() / total_area_label.sum()
     acc = total_area_intersect / total_area_label
     ret_metrics = [all_acc, acc]
@@ -222,6 +235,7 @@ def eval_metrics(results,
         dice = 2 * total_area_intersect / (
             total_area_pred_label + total_area_label)
         ret_metrics.append(dice)
+    ret_metrics = [metric.numpy() for metric in ret_metrics]
     if nan_to_num is not None:
         ret_metrics = [
             np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
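Editor's note: end to end, the public contract is unchanged — callers still pass and receive numpy, and torch is only used internally, thanks to the final `.numpy()` conversion above. A usage sketch (not from the PR; assumes mmseg is installed and importable as below):

```python
import numpy as np
from mmseg.core.evaluation.metrics import eval_metrics

num_classes = 19
pred = np.random.randint(0, num_classes, (512, 512))
gt = np.random.randint(0, num_classes, (512, 512))

all_acc, acc, iou = eval_metrics(
    [pred], [gt], num_classes, ignore_index=255,
    metrics='mIoU', nan_to_num=-1)
# all_acc is a 0-d ndarray; acc and iou have shape (num_classes,).
```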
42 changes: 42 additions & 0 deletions tests/test_metrics.py
@@ -164,3 +164,45 @@ def test_mean_dice():
         results, label, num_classes, ignore_index=255, nan_to_num=-1)
     assert acc[-1] == -1
     assert iou[-1] == -1
+
+
+def test_filename_inputs():
+    import cv2
+    import tempfile
+
+    def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
+        filenames = []
+        SUFFIX = '.png' if is_image else '.npy'
+        for idx, arr in enumerate(input_arrays):
+            filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
+            if is_image:
+                cv2.imwrite(filename, arr)
+            else:
+                np.save(filename, arr)
+            filenames.append(filename)
+        return filenames
+
+    pred_size = (10, 512, 1024)
+    num_classes = 19
+    ignore_index = 255
+    results = np.random.randint(0, num_classes, size=pred_size)
+    labels = np.random.randint(0, num_classes, size=pred_size)
+    labels[:, 2, 5:10] = ignore_index
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+
+        result_files = save_arr(results, 'pred', False, temp_dir)
+        label_files = save_arr(labels, 'label', True, temp_dir)
+
+        all_acc, acc, iou = eval_metrics(
+            result_files,
+            label_files,
+            num_classes,
+            ignore_index,
+            metrics='mIoU')
+
+        all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
+                                                  ignore_index)
+        assert all_acc == all_acc_l
+        assert np.allclose(acc, acc_l)
+        assert np.allclose(iou, iou_l)
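Editor's note: `legacy_mean_iou` is the pure-numpy reference defined earlier in tests/test_metrics.py and not shown in this diff. For readers without the file at hand, a hedged sketch of what such a reference computes (names and layout assumed here, not taken from the PR):

```python
import numpy as np

def numpy_mean_iou_reference(results, gt_seg_maps, num_classes, ignore_index):
    """Per-class histogram accounting, as the old numpy code path did."""

    def hist(x):
        return np.histogram(x, bins=np.arange(num_classes + 1))[0]

    total_intersect = np.zeros((num_classes, ))
    total_pred = np.zeros((num_classes, ))
    total_label = np.zeros((num_classes, ))
    for pred, label in zip(results, gt_seg_maps):
        mask = label != ignore_index
        pred, label = pred[mask], label[mask]
        total_intersect += hist(pred[pred == label])
        total_pred += hist(pred)
        total_label += hist(label)
    total_union = total_pred + total_label - total_intersect
    all_acc = total_intersect.sum() / total_label.sum()
    return all_acc, total_intersect / total_label, total_intersect / total_union
```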