add metric mFscore (#509)
* add mFscore and refactor the metrics return value

* fix linting

* some docstring and name fix
sshuair authored Apr 30, 2021
1 parent cf2cb54 commit e16e0e3
Showing 7 changed files with 318 additions and 85 deletions.
6 changes: 3 additions & 3 deletions mmseg/core/evaluation/__init__.py
@@ -1,8 +1,8 @@
from .class_names import get_classes, get_palette
from .eval_hooks import DistEvalHook, EvalHook
from .metrics import eval_metrics, mean_dice, mean_iou
from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou

__all__ = [
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
'get_classes', 'get_palette'
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
'eval_metrics', 'get_classes', 'get_palette'
]
129 changes: 106 additions & 23 deletions mmseg/core/evaluation/metrics.py
@@ -1,8 +1,27 @@
from collections import OrderedDict

import mmcv
import numpy as np
import torch


def f_score(precision, recall, beta=1):
"""calcuate the f-score value.
Args:
precision (float | torch.Tensor): The precision value.
recall (float | torch.Tensor): The recall value.
beta (int): Determines the weight of recall in the combined score.
Default: 1.
Returns:
torch.Tensor: The f-score value.
"""
score = (1 + beta**2) * (precision * recall) / (
(beta**2 * precision) + recall)
return score
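For reference, the new helper is the standard F-beta score, where beta > 1 weights recall more heavily than precision. A minimal standalone sketch with illustrative inputs (the numbers below are made up):

```python
import torch

def f_score(precision, recall, beta=1):
    """F-beta score; beta controls how much recall is weighted vs. precision."""
    return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)

# Illustrative per-class precision/recall values only.
precision = torch.tensor([0.80, 0.60])
recall = torch.tensor([0.50, 0.90])
print(f_score(precision, recall, beta=1))  # tensor([0.6154, 0.7200]) -- harmonic mean when beta=1
```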


def intersect_and_union(pred_label,
label,
num_classes,
@@ -133,11 +152,12 @@ def mean_iou(results,
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category IoU, shape (num_classes, ).
dict[str, float | ndarray]:
<aAcc> float: Overall accuracy on all images.
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
<IoU> ndarray: Per category IoU, shape (num_classes, ).
"""
all_acc, acc, iou = eval_metrics(
iou_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
@@ -146,7 +166,7 @@ def mean_iou(results,
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return all_acc, acc, iou
return iou_result


def mean_dice(results,
@@ -171,12 +191,13 @@ def mean_dice(results,
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category dice, shape (num_classes, ).
dict[str, float | ndarray]: Default metrics.
<aAcc> float: Overall accuracy on all images.
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
<Dice> ndarray: Per category dice, shape (num_classes, ).
"""

all_acc, acc, dice = eval_metrics(
dice_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
@@ -185,7 +206,52 @@ def mean_dice(results,
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return all_acc, acc, dice
return dice_result


def mean_fscore(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False,
beta=1):
"""Calculate Mean Intersection and Union (mIoU)
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
beta (int): Determines the weight of recall in the combined score.
Default: 1.
Returns:
dict[str, float | ndarray]: Default metrics.
<aAcc> float: Overall accuracy on all images.
<Fscore> ndarray: Per category f-score, shape (num_classes, ).
<Precision> ndarray: Per category precision, shape (num_classes, ).
<Recall> ndarray: Per category recall, shape (num_classes, ).
"""
fscore_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mFscore'],
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label,
beta=beta)
return fscore_result
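With this refactor the three wrappers return an ordered dict keyed by metric name instead of a positional tuple. A hedged usage sketch, assuming mmseg is installed and using tiny made-up inputs:

```python
import numpy as np
from mmseg.core.evaluation import mean_fscore

np.random.seed(0)
# Toy data: two 4x4 predictions and ground truths over 3 classes (illustrative only).
preds = [np.random.randint(0, 3, (4, 4)) for _ in range(2)]
gts = [np.random.randint(0, 3, (4, 4)) for _ in range(2)]

res = mean_fscore(preds, gts, num_classes=3, ignore_index=255, beta=1)
print(res['aAcc'])                                       # overall accuracy (scalar)
print(res['Fscore'], res['Precision'], res['Recall'])    # per-class arrays, shape (3,)
```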


def eval_metrics(results,
@@ -195,7 +261,8 @@ def eval_metrics(results,
metrics=['mIoU'],
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False):
reduce_zero_label=False,
beta=1):
"""Calculate evaluation metrics
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
@@ -210,13 +277,13 @@ def eval_metrics(results,
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
dict[str, float | ndarray]: Default metrics.
<aAcc> float: Overall accuracy on all images.
<metric name> ndarray: Per category evaluation metrics, shape (num_classes, ).
"""
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice']
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics))

@@ -225,19 +292,35 @@
results, gt_seg_maps, num_classes, ignore_index, label_map,
reduce_zero_label)
all_acc = total_area_intersect.sum() / total_area_label.sum()
acc = total_area_intersect / total_area_label
ret_metrics = [all_acc, acc]
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
ret_metrics.append(iou)
acc = total_area_intersect / total_area_label
ret_metrics['IoU'] = iou
ret_metrics['Acc'] = acc
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
ret_metrics.append(dice)
ret_metrics = [metric.numpy() for metric in ret_metrics]
acc = total_area_intersect / total_area_label
ret_metrics['Dice'] = dice
ret_metrics['Acc'] = acc
elif metric == 'mFscore':
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor(
[f_score(x[0], x[1], beta) for x in zip(precision, recall)])
ret_metrics['Fscore'] = f_value
ret_metrics['Precision'] = precision
ret_metrics['Recall'] = recall

ret_metrics = {
metric: value.numpy()
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = [
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
]
ret_metrics = OrderedDict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
})
return ret_metrics
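To make the new mFscore branch concrete: precision and recall are computed per class directly from the accumulated area tensors, then f_score is applied elementwise. A small worked example with invented pixel counts:

```python
import torch

# Hypothetical accumulated pixel counts for two classes.
total_area_intersect = torch.tensor([30., 45.])    # true positives per class
total_area_pred_label = torch.tensor([50., 60.])   # all predicted pixels per class
total_area_label = torch.tensor([40., 90.])        # all ground-truth pixels per class

precision = total_area_intersect / total_area_pred_label  # tensor([0.6000, 0.7500])
recall = total_area_intersect / total_area_label           # tensor([0.7500, 0.5000])
beta = 1
fscore = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
print(fscore)  # tensor([0.6667, 0.6000])
```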
84 changes: 50 additions & 34 deletions mmseg/datasets/custom.py
@@ -1,11 +1,12 @@
import os
import os.path as osp
from collections import OrderedDict
from functools import reduce

import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from prettytable import PrettyTable
from torch.utils.data import Dataset

from mmseg.core import eval_metrics
@@ -312,8 +313,8 @@ def evaluate(self,
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
'mDice' are supported.
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
@@ -323,7 +324,7 @@ def evaluate(self,

if isinstance(metric, str):
metric = [metric]
allowed_metrics = ['mIoU', 'mDice']
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metric).issubset(set(allowed_metrics)):
raise KeyError('metric {} is not supported'.format(metric))
eval_results = {}
@@ -341,42 +342,57 @@
metric,
label_map=self.label_map,
reduce_zero_label=self.reduce_zero_label)
class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]

if self.CLASSES is None:
class_names = tuple(range(num_classes))
else:
class_names = self.CLASSES
ret_metrics_round = [
np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
]
for i in range(num_classes):
class_table_data.append([class_names[i]] +
[m[i] for m in ret_metrics_round[2:]] +
[ret_metrics_round[1][i]])
summary_table_data = [['Scope'] +
['m' + head
for head in class_table_data[0][1:]] + ['aAcc']]
ret_metrics_mean = [
np.round(np.nanmean(ret_metric) * 100, 2)
for ret_metric in ret_metrics
]
summary_table_data.append(['global'] + ret_metrics_mean[2:] +
[ret_metrics_mean[1]] +
[ret_metrics_mean[0]])

# summary table
ret_metrics_summary = OrderedDict({
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})

# each class table
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)

# for logger
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)

summary_table_data = PrettyTable()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
summary_table_data.add_column(key, [val])
else:
summary_table_data.add_column('m' + key, [val])

print_log('per class results:', logger)
table = AsciiTable(class_table_data)
print_log('\n' + table.table, logger=logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
print_log('Summary:', logger)
table = AsciiTable(summary_table_data)
print_log('\n' + table.table, logger=logger)

for i in range(1, len(summary_table_data[0])):
eval_results[summary_table_data[0]
[i]] = summary_table_data[1][i] / 100.0
for idx, sub_metric in enumerate(class_table_data[0][1:], 1):
for item in class_table_data[1:]:
eval_results[str(sub_metric) + '.' +
str(item[0])] = item[idx] / 100.0
print_log('\n' + summary_table_data.get_string(), logger=logger)

# each metric dict
for key, value in ret_metrics_summary.items():
if key == 'aAcc':
eval_results[key] = value / 100.0
else:
eval_results['m' + key] = value / 100.0

ret_metrics_class.pop('Class', None)
for key, value in ret_metrics_class.items():
eval_results.update({
key + '.' + str(name): value[idx] / 100.0
for idx, name in enumerate(class_names)
})

if mmcv.is_list_of(results, str):
for file_name in results:
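Since logging now uses prettytable instead of terminaltables, the tables are built column by column from the metric dicts. A minimal standalone sketch of that pattern (class names and values are invented):

```python
from prettytable import PrettyTable

# Illustrative per-class results, already rounded to percentages.
ret_metrics_class = {
    'Class': ['road', 'car'],
    'Fscore': [66.67, 60.0],
    'Precision': [60.0, 75.0],
    'Recall': [75.0, 50.0],
}

table = PrettyTable()
for key, val in ret_metrics_class.items():
    table.add_column(key, val)
print(table.get_string())
```

add_column takes a column header and a list of equal-length values, which is why the real code injects the class names as a 'Class' entry and moves it to the front before building the table.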
2 changes: 1 addition & 1 deletion requirements/runtime.txt
@@ -1,3 +1,3 @@
matplotlib
numpy
terminaltables
prettytable
2 changes: 1 addition & 1 deletion setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,seaborn,terminaltables,torch
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
14 changes: 12 additions & 2 deletions tests/test_data/test_dataset.py
@@ -159,7 +159,7 @@ def test_custom_dataset():
for gt_seg_map in gt_seg_maps:
h, w = gt_seg_map.shape
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
@@ -193,13 +193,23 @@ def test_custom_dataset():
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results

eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
assert isinstance(eval_results, dict)
assert 'mRecall' in eval_results
assert 'mPrecision' in eval_results
assert 'mFscore' in eval_results
assert 'aAcc' in eval_results

eval_results = train_dataset.evaluate(
pseudo_results, metric=['mIoU', 'mDice'])
pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mDice' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
assert 'mFscore' in eval_results
assert 'mPrecision' in eval_results
assert 'mRecall' in eval_results
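For context on what these assertions check: after this change, evaluate() is expected to return summary keys prefixed with 'm' (plus 'aAcc') and per-class keys of the form '<Metric>.<class name>'. A hedged sketch of that shape with invented numbers:

```python
# Illustrative only -- the keys mirror custom.py's evaluate(); the values are made up.
eval_results = {
    'aAcc': 0.72,
    'mFscore': 0.63, 'mPrecision': 0.675, 'mRecall': 0.625,
    'Fscore.road': 0.6667, 'Precision.road': 0.60, 'Recall.road': 0.75,
}
assert 'mFscore' in eval_results and 'aAcc' in eval_results
```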


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)