forked from open-mmlab/mmsegmentation
Commit 1530af6 (1 parent: d8f780c)

add dice evaluation metric (open-mmlab#225)

* add dice evaluation metric
* support 2 metrics
* fix docstring
* use np.round once for all

Showing 9 changed files with 420 additions and 179 deletions.
mmseg/core/evaluation/__init__.py
@@ -1,7 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .mean_iou import mean_iou
+from .metrics import eval_metrics, mean_dice, mean_iou

 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_iou', 'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
+    'get_classes', 'get_palette'
 ]
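For orientation, a minimal sketch of calling the re-exported API (toy arrays; assumes mmsegmentation is installed and is not part of this commit):

import numpy as np
from mmseg.core.evaluation import mean_dice, mean_iou

# Hypothetical stand-ins for real predictions and annotations.
pred = [np.array([[0, 1], [1, 1]])]  # one predicted segmentation map
gt = [np.array([[0, 1], [0, 1]])]    # its ground truth
all_acc, acc, iou = mean_iou(pred, gt, num_classes=2, ignore_index=255)
all_acc, acc, dice = mean_dice(pred, gt, num_classes=2, ignore_index=255)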
mmseg/core/evaluation/mean_iou.py (deleted; contents not shown in this view)
mmseg/core/evaluation/metrics.py
@@ -0,0 +1,176 @@
import numpy as np

def intersect_and_union(pred_label, label, num_classes, ignore_index):
    """Calculate Intersection and Union.

    Args:
        pred_label (ndarray): Prediction segmentation map.
        label (ndarray): Ground truth segmentation map.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    # Drop ignored pixels before counting.
    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    # Histograms over class indices count pixels per class.
    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label

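A quick toy check of the histogram bookkeeping above (illustrative values, not from the commit):

import numpy as np

pred = np.array([0, 0, 1, 1])
gt = np.array([0, 1, 1, 1])
inter, union, area_pred, area_gt = intersect_and_union(
    pred, gt, num_classes=2, ignore_index=255)
# inter     == [1, 2] -> one correct class-0 pixel, two correct class-1 pixels
# union     == [2, 3] -> area_pred + area_gt - inter, per class
# area_pred == [2, 2], area_gt == [1, 3]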
def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    # Builtin float dtype; the original diff used the np.float alias,
    # which was removed in NumPy 1.24.
    total_area_intersect = np.zeros((num_classes, ), dtype=float)
    total_area_union = np.zeros((num_classes, ), dtype=float)
    total_area_pred_label = np.zeros((num_classes, ), dtype=float)
    total_area_label = np.zeros((num_classes, ), dtype=float)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
                                ignore_index=ignore_index)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, \
        total_area_pred_label, total_area_label

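The same bookkeeping aggregated over a toy two-image batch (illustrative only):

preds = [np.array([0, 1]), np.array([1, 1])]
gts = [np.array([0, 1]), np.array([0, 1])]
t_inter, t_union, t_pred, t_gt = total_intersect_and_union(
    preds, gts, num_classes=2, ignore_index=255)
# t_inter == [1., 2.] -> per-class intersections summed across both images
# t_union == [2., 3.]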
def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
    """Calculate Mean Intersection and Union (mIoU).

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category IoU, shape (num_classes, ).
    """

    all_acc, acc, iou = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num)
    return all_acc, acc, iou

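Running mean_iou on the earlier toy image (illustrative numbers):

all_acc, acc, iou = mean_iou(
    [np.array([0, 0, 1, 1])], [np.array([0, 1, 1, 1])],
    num_classes=2, ignore_index=255)
# all_acc == 0.75           -> 3 of 4 pixels correct overall
# acc     == [1.0, 0.6667]  -> per-class intersect / area_label (1/1, 2/3)
# iou     == [0.5, 0.6667]  -> per-class intersect / union (1/2, 2/3)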
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None):
    """Calculate Mean Dice (mDice).

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category dice, shape (num_classes, ).
    """

    all_acc, acc, dice = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num)
    return all_acc, acc, dice

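And mean_dice on the same toy image; dice is 2 * intersect / (area_pred + area_label) per class:

all_acc, acc, dice = mean_dice(
    [np.array([0, 0, 1, 1])], [np.array([0, 1, 1, 1])],
    num_classes=2, ignore_index=255)
# dice == [0.6667, 0.8] -> 2*1 / (2 + 1) and 2*2 / (2 + 3)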
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and
            'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the number defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """

    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))
    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(results, gt_seg_maps,
                                                     num_classes,
                                                     ignore_index=ignore_index)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    ret_metrics = [all_acc, acc]
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            ret_metrics.append(iou)
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            ret_metrics.append(dice)
    if nan_to_num is not None:
        ret_metrics = [
            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
        ]
    return ret_metrics
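eval_metrics computes both metrics in one pass; the extra return values follow the order of `metrics`. A sketch with the same toy inputs:

all_acc, acc, iou, dice = eval_metrics(
    [np.array([0, 0, 1, 1])], [np.array([0, 1, 1, 1])],
    num_classes=2, ignore_index=255,
    metrics=['mIoU', 'mDice'], nan_to_num=0)
# nan_to_num=0 replaces NaN entries (e.g. a class absent from both
# prediction and ground truth yields 0 / 0) in every returned array.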
requirements/runtime.txt
@@ -1,2 +1,3 @@
 matplotlib
 numpy
+terminaltables