Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pytorch metrics implementation #430

Merged
merged 7 commits into from
Mar 29, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 47 additions & 42 deletions mmseg/core/evaluation/metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import mmcv
import numpy as np
import torch


def intersect_and_union(pred_label,
label,
num_classes,
ignore_index,
num_classes: int,
ignore_index: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may update the docstring since pred_label and label could be str now.

label_map=dict(),
reduce_zero_label=False):
"""Calculate intersection and Union.
Expand All @@ -21,25 +22,29 @@ def intersect_and_union(pred_label,
work only when label is str. Default: False.

Returns:
ndarray: The intersection of prediction and ground truth histogram
on all classes.
ndarray: The union of prediction and ground truth histogram on all
classes.
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
torch.Tensor: The intersection of prediction and ground truth
histogram on all classes.
torch.Tensor: The union of prediction and ground truth histogram on
all classes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing indent.

torch.Tensor: The prediction histogram on all classes.
torch.Tensor: The ground truth histogram on all classes.
"""

if isinstance(pred_label, str):
pred_label = np.load(pred_label)
pred_label = torch.from_numpy(np.load(pred_label))
else:
pred_label = torch.from_numpy((pred_label))

if isinstance(label, str):
label = mmcv.imread(label, flag='unchanged', backend='pillow')
# modify if custom classes
label = torch.from_numpy(
mmcv.imread(label, flag='unchanged', backend='pillow'))
else:
label = torch.from_numpy(label)

if label_map is not None:
for old_id, new_id in label_map.items():
label[label == old_id] = new_id
if reduce_zero_label:
# avoid using underflow conversion
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
Expand All @@ -49,27 +54,28 @@ def intersect_and_union(pred_label,
label = label[mask]

intersect = pred_label[pred_label == label]
area_intersect, _ = np.histogram(
intersect, bins=np.arange(num_classes + 1))
area_pred_label, _ = np.histogram(
pred_label, bins=np.arange(num_classes + 1))
area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=0, max=num_classes)
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
area_label = torch.histc(
label.float(), bins=(num_classes), min=0, max=num_classes)
area_union = area_pred_label + area_label - area_intersect

return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results,
gt_seg_maps,
num_classes,
ignore_index,
def total_intersect_and_union(results: list,
gt_seg_maps: list,
num_classes: int,
ignore_index: int,
label_map=dict(),
reduce_zero_label=False):
"""Calculate Total Intersection and Union.

Args:
results (list[ndarray]): List of prediction segmentation maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation
maps.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
gt_seg_maps (list[ndarray]): list of ground truth segmentation
maps.
gt_seg_maps (list[ndarray]): list of ground truth segmentation
maps.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may update the docstring since results and gt_seg_maps could be lists of str now.

num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. Default: dict().
Expand All @@ -83,23 +89,23 @@ def total_intersect_and_union(results,
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""

num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
total_area_union = np.zeros((num_classes, ), dtype=np.float)
total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
total_area_label = np.zeros((num_classes, ), dtype=np.float)
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
for i in range(num_imgs):
area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(results[i], gt_seg_maps[i], num_classes,
ignore_index, label_map, reduce_zero_label)
intersect_and_union(
results[i], gt_seg_maps[i], num_classes, ignore_index,
label_map, reduce_zero_label)
total_area_intersect += area_intersect
total_area_union += area_union
total_area_pred_label += area_pred_label
total_area_label += area_label
return total_area_intersect, total_area_union, \
total_area_pred_label, total_area_label
return total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label


def mean_iou(results,
Expand All @@ -126,7 +132,6 @@ def mean_iou(results,
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category IoU, shape (num_classes, ).
"""

all_acc, acc, iou = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
Expand Down Expand Up @@ -176,10 +181,10 @@ def mean_dice(results,
return all_acc, acc, dice


def eval_metrics(results,
gt_seg_maps,
num_classes,
ignore_index,
def eval_metrics(results: list,
gt_seg_maps: list,
num_classes: int,
ignore_index: int,
metrics=['mIoU'],
nan_to_num=None,
label_map=dict(),
Expand All @@ -200,17 +205,16 @@ def eval_metrics(results,
ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
"""

if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics))

total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = total_intersect_and_union(results, gt_seg_maps,
num_classes, ignore_index,
label_map,
reduce_zero_label)
total_area_label = total_intersect_and_union(
results, gt_seg_maps, num_classes, ignore_index, label_map,
reduce_zero_label)
all_acc = total_area_intersect.sum() / total_area_label.sum()
acc = total_area_intersect / total_area_label
ret_metrics = [all_acc, acc]
Expand All @@ -222,6 +226,7 @@ def eval_metrics(results,
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
ret_metrics.append(dice)
ret_metrics = [metric.numpy() for metric in ret_metrics]
if nan_to_num is not None:
ret_metrics = [
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
Expand Down