add dice evaluation metric #225

Merged · 10 commits · Nov 24, 2020

5 changes: 3 additions & 2 deletions mmseg/core/evaluation/__init__.py
@@ -1,7 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .mean_iou import mean_iou
+from .metrics import eval_metrics, mean_dice, mean_iou

 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_iou', 'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
+    'get_classes', 'get_palette'
 ]
74 changes: 0 additions & 74 deletions mmseg/core/evaluation/mean_iou.py

This file was deleted.

176 changes: 176 additions & 0 deletions mmseg/core/evaluation/metrics.py
@@ -0,0 +1,176 @@
import numpy as np


def intersect_and_union(pred_label, label, num_classes, ignore_index):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray): Prediction segmentation map
        label (ndarray): Ground truth segmentation map
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes
        ndarray: The union of prediction and ground truth histogram on all
            classes
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes
        ndarray: The union of prediction and ground truth histogram on all
            classes
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
    total_area_union = np.zeros((num_classes, ), dtype=np.float)
    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
    total_area_label = np.zeros((num_classes, ), dtype=np.float)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
                                ignore_index=ignore_index)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, \
        total_area_pred_label, total_area_label


def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
    """Calculate Mean Intersection and Union (mIoU).

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, )
        ndarray: Per category IoU, shape (num_classes, )
    """

    all_acc, acc, iou = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num)
    return all_acc, acc, iou


def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None):
    """Calculate Mean Dice (mDice).

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, )
        ndarray: Per category dice, shape (num_classes, )
    """

    all_acc, acc, dice = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num)
    return all_acc, acc, dice


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
        num_classes (int): Number of categories
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, )
        ndarray: Per category evaluation metrics, shape (num_classes, )
    """

    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))
    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(results, gt_seg_maps,
                                                     num_classes,
                                                     ignore_index=ignore_index)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    ret_metrics = [all_acc, acc]
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            ret_metrics.append(iou)
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            ret_metrics.append(dice)
    if nan_to_num is not None:
        ret_metrics = [
            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
        ]
    return ret_metrics
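
To see the new metrics end to end, here is a worked sketch on made-up toy data (one 1x4 prediction/label pair, two classes, 255 as the ignore index). It is editorial, not part of the diff, and assumes the helpers above are importable from mmseg.core.evaluation with a 2020-era NumPy (the code uses np.float):

import numpy as np

from mmseg.core.evaluation import eval_metrics, mean_dice

# Toy inputs (made up for illustration): one image, 2 classes, 255 ignored.
results = [np.array([[0, 0, 1, 1]])]
gt_seg_maps = [np.array([[0, 1, 1, 1]])]

# mean_dice returns (overall accuracy, per-class accuracy, per-class Dice).
all_acc, acc, dice = mean_dice(results, gt_seg_maps, 2, ignore_index=255)
# Here intersection = [1, 2], pred area = [2, 2], label area = [1, 3], so
# Dice = 2 * I / (P + G) = [2/3, 4/5], acc = [1, 2/3], all_acc = 3/4.

# eval_metrics computes several metrics in one pass; the returned list is
# [aAcc, per-class acc, then one array per requested metric, in order].
ret = eval_metrics(results, gt_seg_maps, 2, ignore_index=255,
                   metrics=['mIoU', 'mDice'])
# IoU = I / (P + G - I) = [1/2, 2/3].
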
80 changes: 43 additions & 37 deletions mmseg/datasets/custom.py
@@ -4,19 +4,19 @@
 import mmcv
 import numpy as np
 from mmcv.utils import print_log
+from terminaltables import AsciiTable
 from torch.utils.data import Dataset

-from mmseg.core import mean_iou
+from mmseg.core import eval_metrics
 from mmseg.utils import get_root_logger
 from .builder import DATASETS
 from .pipelines import Compose


 @DATASETS.register_module()
 class CustomDataset(Dataset):
-    """Custom dataset for semantic segmentation.
-
-    An example of file structure is as followed.
+    """Custom dataset for semantic segmentation. An example of file structure
+    is as followed.

     .. code-block:: none
@@ -315,57 +315,63 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):

         Args:
             results (list): Testing results of the dataset.
-            metric (str | list[str]): Metrics to be evaluated.
+            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
+                'mDice' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.

         Returns:
             dict[str, float]: Default metrics.
         """

-        if not isinstance(metric, str):
-            assert len(metric) == 1
-            metric = metric[0]
-        allowed_metrics = ['mIoU']
-        if metric not in allowed_metrics:
+        if isinstance(metric, str):
+            metric = [metric]
+        allowed_metrics = ['mIoU', 'mDice']
+        if not set(metric).issubset(set(allowed_metrics)):
             raise KeyError('metric {} is not supported'.format(metric))

         eval_results = {}
         gt_seg_maps = self.get_gt_seg_maps()
         if self.CLASSES is None:
             num_classes = len(
                 reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
         else:
             num_classes = len(self.CLASSES)

-        all_acc, acc, iou = mean_iou(
-            results, gt_seg_maps, num_classes, ignore_index=self.ignore_index)
-        summary_str = ''
-        summary_str += 'per class results:\n'
-
-        line_format = '{:<15} {:>10} {:>10}\n'
-        summary_str += line_format.format('Class', 'IoU', 'Acc')
+        ret_metrics = eval_metrics(
+            results,
+            gt_seg_maps,
+            num_classes,
+            ignore_index=self.ignore_index,
+            metrics=metric)
+        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
         if self.CLASSES is None:
             class_names = tuple(range(num_classes))
         else:
             class_names = self.CLASSES
+        ret_metrics_round = [
+            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
+        ]
         for i in range(num_classes):
-            iou_str = '{:.2f}'.format(iou[i] * 100)
-            acc_str = '{:.2f}'.format(acc[i] * 100)
-            summary_str += line_format.format(class_names[i], iou_str, acc_str)
-        summary_str += 'Summary:\n'
-        line_format = '{:<15} {:>10} {:>10} {:>10}\n'
-        summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc')
-
-        iou_str = '{:.2f}'.format(np.nanmean(iou) * 100)
-        acc_str = '{:.2f}'.format(np.nanmean(acc) * 100)
-        all_acc_str = '{:.2f}'.format(all_acc * 100)
-        summary_str += line_format.format('global', iou_str, acc_str,
-                                          all_acc_str)
-        print_log(summary_str, logger)
-
-        eval_results['mIoU'] = np.nanmean(iou)
-        eval_results['mAcc'] = np.nanmean(acc)
-        eval_results['aAcc'] = all_acc
-
+            class_table_data.append([class_names[i]] +
+                                    [m[i] for m in ret_metrics_round[2:]] +
+                                    [ret_metrics_round[1][i]])
+        summary_table_data = [['Scope'] +
+                              ['m' + head
+                               for head in class_table_data[0][1:]] + ['aAcc']]
+        ret_metrics_mean = [
+            np.round(np.nanmean(ret_metric) * 100, 2)
+            for ret_metric in ret_metrics
+        ]
+        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
+                                  [ret_metrics_mean[1]] +
+                                  [ret_metrics_mean[0]])
+        print_log('per class results:', logger)
+        table = AsciiTable(class_table_data)
+        print_log('\n' + table.table, logger=logger)
+        print_log('Summary:', logger)
+        table = AsciiTable(summary_table_data)
+        print_log('\n' + table.table, logger=logger)
+
+        for i in range(1, len(summary_table_data[0])):
+            eval_results[summary_table_data[0]
+                         [i]] = summary_table_data[1][i] / 100.0
         return eval_results
1 change: 1 addition & 0 deletions requirements/runtime.txt
@@ -1,2 +1,3 @@
 matplotlib
 numpy
+terminaltables
2 changes: 1 addition & 1 deletion setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY