add dice evaluation metric #225

Merged
merged 10 commits on Nov 24, 2020
Changes from 3 commits
5 changes: 3 additions & 2 deletions mmseg/core/evaluation/__init__.py
@@ -1,7 +1,8 @@
from .class_names import get_classes, get_palette
from .eval_hooks import DistEvalHook, EvalHook
-from .mean_iou import mean_iou
+from .metrics import eval_metrics, mean_dice, mean_iou

__all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_iou', 'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
+    'get_classes', 'get_palette'
]
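A minimal import sketch of the updated public API (illustrative, not part of this PR; it assumes mmseg is installed at a revision that contains these changes):

from mmseg.core import eval_metrics          # path used by mmseg/datasets/custom.py below
from mmseg.core.evaluation import mean_dice  # the new helper added in metrics.py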
74 changes: 0 additions & 74 deletions mmseg/core/evaluation/mean_iou.py

This file was deleted.

166 changes: 166 additions & 0 deletions mmseg/core/evaluation/metrics.py
@@ -0,0 +1,166 @@
import numpy as np


def intersect_and_union(pred_label, label, num_classes, ignore_index):
    """Calculate Intersection and Union.

    Args:
        pred_label (ndarray): Prediction segmentation map.
        label (ndarray): Ground truth segmentation map.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_area_intersect = np.zeros((num_classes, ), dtype=float)
    total_area_union = np.zeros((num_classes, ), dtype=float)
    total_area_pred_label = np.zeros((num_classes, ), dtype=float)
    total_area_label = np.zeros((num_classes, ), dtype=float)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
                                ignore_index=ignore_index)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, \
        total_area_pred_label, total_area_label


def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
    """Calculate Mean Intersection and Union (mIoU).

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category IoU, shape (num_classes, ).
    """

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(results, gt_seg_maps,
                                                     num_classes,
                                                     ignore_index=ignore_index)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    iou = total_area_intersect / total_area_union
    if nan_to_num is not None:
        return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
            np.nan_to_num(iou, nan=nan_to_num)
    return all_acc, acc, iou


def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None):
    """Calculate Mean Dice (mDice).

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category dice, shape (num_classes, ).
    """

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(results, gt_seg_maps,
                                                     num_classes,
                                                     ignore_index=ignore_index)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    dice = 2 * total_area_intersect / (
        total_area_pred_label + total_area_label)
    if nan_to_num is not None:
        return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
            np.nan_to_num(dice, nan=nan_to_num)
    return all_acc, acc, dice


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metric='mIoU',
                 nan_to_num=None):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metric (str): Metrics to be evaluated, 'mIoU' or 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """

    allowed_metrics = {'mIoU': mean_iou, 'mDice': mean_dice}
    if (not isinstance(metric, str)) or (metric not in allowed_metrics):
        raise KeyError('metric {} is not supported'.format(metric))
    all_acc, acc, eval_metric = allowed_metrics[metric](results, gt_seg_maps,
                                                        num_classes,
                                                        ignore_index,
                                                        nan_to_num)
    return all_acc, acc, eval_metric
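A quick usage sketch of the new module (illustrative only, not part of this PR). The toy maps, the two-class setup, and 255 as ignore_index are assumptions made up for this example:

import numpy as np

from mmseg.core.evaluation import eval_metrics, mean_dice

# Two tiny 2x2 prediction / ground-truth maps with classes {0, 1};
# the single 255 pixel in the ground truth is ignored.
results = [np.array([[0, 1], [1, 1]]), np.array([[0, 0], [1, 1]])]
gt_seg_maps = [np.array([[0, 1], [0, 1]]), np.array([[0, 1], [1, 255]])]

all_acc, acc, iou = eval_metrics(
    results, gt_seg_maps, num_classes=2, ignore_index=255, metric='mIoU')
_, _, dice = mean_dice(results, gt_seg_maps, num_classes=2, ignore_index=255)

# Hand-checked for these inputs: iou == [0.5, 0.6] and dice == [2/3, 0.75],
# which also satisfies the per-class identity dice = 2 * iou / (1 + iou).
print(all_acc, acc, iou, dice)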
38 changes: 24 additions & 14 deletions mmseg/datasets/custom.py
@@ -6,17 +6,16 @@
from mmcv.utils import print_log
from torch.utils.data import Dataset

-from mmseg.core import mean_iou
+from mmseg.core import eval_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class CustomDataset(Dataset):
-    """Custom dataset for semantic segmentation.
-
-    An example of file structure is as followed.
+    """Custom dataset for semantic segmentation. An example of file structure
+    is as follows.

    .. code-block:: none

@@ -326,7 +325,7 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
-        allowed_metrics = ['mIoU']
+        allowed_metrics = ['mIoU', 'mDice']
        if metric not in allowed_metrics:
            raise KeyError('metric {} is not supported'.format(metric))

@@ -338,33 +337,44 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
        else:
            num_classes = len(self.CLASSES)

-        all_acc, acc, iou = mean_iou(
-            results, gt_seg_maps, num_classes, ignore_index=self.ignore_index)
+        all_acc, acc, eval_metric = eval_metrics(
+            results,
+            gt_seg_maps,
+            num_classes,
+            ignore_index=self.ignore_index,
+            metric=metric)
        summary_str = ''
        summary_str += 'per class results:\n'

        line_format = '{:<15} {:>10} {:>10}\n'
-        summary_str += line_format.format('Class', 'IoU', 'Acc')
+        if metric == 'mIoU':
+            summary_str += line_format.format('Class', 'IoU', 'Acc')
+        else:
+            summary_str += line_format.format('Class', 'Dice', 'Acc')
        if self.CLASSES is None:
            class_names = tuple(range(num_classes))
        else:
            class_names = self.CLASSES
        for i in range(num_classes):
-            iou_str = '{:.2f}'.format(iou[i] * 100)
+            eval_metric_str = '{:.2f}'.format(eval_metric[i] * 100)
            acc_str = '{:.2f}'.format(acc[i] * 100)
-            summary_str += line_format.format(class_names[i], iou_str, acc_str)
+            summary_str += line_format.format(class_names[i], eval_metric_str,
+                                              acc_str)
        summary_str += 'Summary:\n'
        line_format = '{:<15} {:>10} {:>10} {:>10}\n'
-        summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc')
+        if metric == 'mIoU':
+            summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc')
+        else:
+            summary_str += line_format.format('Scope', 'mDice', 'mAcc', 'aAcc')

-        iou_str = '{:.2f}'.format(np.nanmean(iou) * 100)
+        eval_metric_str = '{:.2f}'.format(np.nanmean(eval_metric) * 100)
        acc_str = '{:.2f}'.format(np.nanmean(acc) * 100)
        all_acc_str = '{:.2f}'.format(all_acc * 100)
-        summary_str += line_format.format('global', iou_str, acc_str,
+        summary_str += line_format.format('global', eval_metric_str, acc_str,
                                          all_acc_str)
        print_log(summary_str, logger)

-        eval_results['mIoU'] = np.nanmean(iou)
+        eval_results[metric] = np.nanmean(eval_metric)
        eval_results['mAcc'] = np.nanmean(acc)
        eval_results['aAcc'] = all_acc

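For reference, a small standalone sketch of how the summary block above is assembled when metric='mDice' (illustrative only, not part of this PR; the class names and scores are invented):

import numpy as np

class_names = ('road', 'car')          # assumed classes, for illustration
eval_metric = np.array([0.67, 0.75])   # pretend per-class Dice
acc = np.array([0.70, 0.80])           # pretend per-class accuracy
all_acc = 0.74                         # pretend overall accuracy

summary_str = 'per class results:\n'
line_format = '{:<15} {:>10} {:>10}\n'
summary_str += line_format.format('Class', 'Dice', 'Acc')
for name, m, a in zip(class_names, eval_metric, acc):
    summary_str += line_format.format(name, '{:.2f}'.format(m * 100),
                                      '{:.2f}'.format(a * 100))

summary_str += 'Summary:\n'
line_format = '{:<15} {:>10} {:>10} {:>10}\n'
summary_str += line_format.format('Scope', 'mDice', 'mAcc', 'aAcc')
summary_str += line_format.format('global',
                                  '{:.2f}'.format(np.nanmean(eval_metric) * 100),
                                  '{:.2f}'.format(np.nanmean(acc) * 100),
                                  '{:.2f}'.format(all_acc * 100))
print(summary_str)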
16 changes: 14 additions & 2 deletions tests/test_data/test_dataset.py
@@ -159,20 +159,32 @@ def test_custom_dataset():
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
-    eval_results = train_dataset.evaluate(pseudo_results)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
    # evaluation with CLASSES
    train_dataset.CLASSES = tuple(['a'] * 7)
-    eval_results = train_dataset.evaluate(pseudo_results)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+

@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',