add metric mFscore #509

Merged: 3 commits, Apr 30, 2021
6 changes: 3 additions & 3 deletions mmseg/core/evaluation/__init__.py
@@ -1,8 +1,8 @@
from .class_names import get_classes, get_palette
from .eval_hooks import DistEvalHook, EvalHook
from .metrics import eval_metrics, mean_dice, mean_iou
from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou

__all__ = [
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
'get_classes', 'get_palette'
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
'eval_metrics', 'get_classes', 'get_palette'
]
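As a quick orientation note (not part of the diff), the newly exported symbol sits next to the existing metric helpers; a minimal import sketch, assuming an environment with this branch of mmseg installed:

```python
# Minimal sketch, assuming this branch of mmseg is installed.
from mmseg.core.evaluation import eval_metrics, mean_dice, mean_fscore, mean_iou
```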
129 changes: 106 additions & 23 deletions mmseg/core/evaluation/metrics.py
@@ -1,8 +1,27 @@
from collections import OrderedDict

import mmcv
import numpy as np
import torch


def f_score(precision, recall, beta=1):
"""calcuate the f-score value.

Args:
precision (float | torch.Tensor): The precision value.
recall (float | torch.Tensor): The recall value.
beta (int): Determines the weight of recall in the combined score.
Default: 1.

Returns:
torch.Tensor: The f-score value.
"""
score = (1 + beta**2) * (precision * recall) / (
(beta**2 * precision) + recall)
return score
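A small hedged check of the formula above (the import path follows this file's location; the numbers are illustrative, not from the PR). With beta=1 the score reduces to the harmonic mean of precision and recall, while a larger beta weights recall more heavily.

```python
# Hedged sketch exercising the f_score helper defined above.
# Assumes this branch of mmseg is importable; values are illustrative.
import torch

from mmseg.core.evaluation.metrics import f_score

precision = torch.tensor(0.5)
recall = torch.tensor(1.0)

# beta=1 is the harmonic mean: 2 * 0.5 * 1.0 / (0.5 + 1.0) = 0.666...
print(f_score(precision, recall, beta=1))  # tensor(0.6667)

# beta=2 weights recall higher: 5 * 0.5 / (4 * 0.5 + 1.0) = 0.8333...
print(f_score(precision, recall, beta=2))  # tensor(0.8333)
```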


def intersect_and_union(pred_label,
label,
num_classes,
@@ -133,11 +152,12 @@ def mean_iou(results,
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.

Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category IoU, shape (num_classes, ).
dict[str, float | ndarray]:
<aAcc> float: Overall accuracy on all images.
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
<IoU> ndarray: Per category IoU, shape (num_classes, ).
"""
all_acc, acc, iou = eval_metrics(
iou_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
@@ -146,7 +166,7 @@ def mean_iou(results,
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return all_acc, acc, iou
return iou_result


def mean_dice(results,
@@ -171,12 +191,13 @@ def mean_dice(results,
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.

Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category dice, shape (num_classes, ).
dict[str, float | ndarray]: Default metrics.
<aAcc> float: Overall accuracy on all images.
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
<Dice> ndarray: Per category dice, shape (num_classes, ).
"""

all_acc, acc, dice = eval_metrics(
dice_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
@@ -185,7 +206,52 @@ def mean_dice(results,
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label)
return all_acc, acc, dice
return dice_result


def mean_fscore(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False,
beta=1):
"""Calculate Mean Intersection and Union (mIoU)

Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
beta (int): Determines the weight of recall in the combined score.
Default: 1.


Returns:
dict[str, float | ndarray]: Default metrics.
<aAcc> float: Overall accuracy on all images.
<Fscore> ndarray: Per category f-score, shape (num_classes, ).
<Precision> ndarray: Per category precision, shape (num_classes, ).
<Recall> ndarray: Per category recall, shape (num_classes, ).
"""
fscore_result = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mFscore'],
nan_to_num=nan_to_num,
label_map=label_map,
reduce_zero_label=reduce_zero_label,
beta=beta)
return fscore_result
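A hedged usage sketch for the new mean_fscore wrapper; the class count, shapes, and random data below are invented for illustration:

```python
# Illustrative only: random predictions and labels, assuming this branch of mmseg.
import numpy as np

from mmseg.core.evaluation import mean_fscore

num_classes = 4
preds = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]
gts = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]

result = mean_fscore(preds, gts, num_classes=num_classes, ignore_index=255)
# Keys per the docstring above: 'aAcc', 'Fscore', 'Precision', 'Recall'.
print(result['aAcc'])           # overall accuracy, scalar
print(result['Fscore'].shape)   # per-class f-score, (num_classes,)
```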


def eval_metrics(results,
@@ -195,7 +261,8 @@ def eval_metrics(results,
metrics=['mIoU'],
nan_to_num=None,
label_map=dict(),
reduce_zero_label=False):
reduce_zero_label=False,
beta=1):
"""Calculate evaluation metrics
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
@@ -210,13 +277,13 @@ def eval_metrics(results,
label_map (dict): Mapping old labels to new labels. Default: dict().
reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
"""
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice']
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics))

@@ -225,19 +292,35 @@ def eval_metrics(results,
results, gt_seg_maps, num_classes, ignore_index, label_map,
reduce_zero_label)
all_acc = total_area_intersect.sum() / total_area_label.sum()
acc = total_area_intersect / total_area_label
ret_metrics = [all_acc, acc]
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
ret_metrics.append(iou)
acc = total_area_intersect / total_area_label
ret_metrics['IoU'] = iou
ret_metrics['Acc'] = acc
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
ret_metrics.append(dice)
ret_metrics = [metric.numpy() for metric in ret_metrics]
acc = total_area_intersect / total_area_label
ret_metrics['Dice'] = dice
ret_metrics['Acc'] = acc
elif metric == 'mFscore':
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor(
[f_score(x[0], x[1], beta) for x in zip(precision, recall)])
ret_metrics['Fscore'] = f_value
ret_metrics['Precision'] = precision
ret_metrics['Recall'] = recall

ret_metrics = {
metric: value.numpy()
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = [
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
]
ret_metrics = OrderedDict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
})
return ret_metrics
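For reference, a hedged sketch of the OrderedDict now returned by eval_metrics when several metrics are requested together; the inputs are random and the keys follow the branches above:

```python
# Sketch of the dict-style return value of the reworked eval_metrics.
# Random inputs; assumes this branch of mmseg is installed.
import numpy as np

from mmseg.core.evaluation import eval_metrics

num_classes = 3
preds = [np.random.randint(0, num_classes, size=(16, 16))]
gts = [np.random.randint(0, num_classes, size=(16, 16))]

ret = eval_metrics(
    preds, gts, num_classes=num_classes, ignore_index=255,
    metrics=['mIoU', 'mFscore'], nan_to_num=0)

# Expected keys: 'aAcc', 'IoU', 'Acc', 'Fscore', 'Precision', 'Recall';
# per-class values are numpy arrays of shape (num_classes,).
for name, value in ret.items():
    print(name, np.asarray(value).shape)
```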
84 changes: 50 additions & 34 deletions mmseg/datasets/custom.py
@@ -1,11 +1,12 @@
import os
import os.path as osp
from collections import OrderedDict
from functools import reduce

import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from prettytable import PrettyTable
from torch.utils.data import Dataset

from mmseg.core import eval_metrics
@@ -312,8 +313,8 @@ def evaluate(self,

Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
'mDice' are supported.
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.

@@ -323,7 +324,7 @@ def evaluate(self,

if isinstance(metric, str):
metric = [metric]
allowed_metrics = ['mIoU', 'mDice']
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metric).issubset(set(allowed_metrics)):
raise KeyError('metric {} is not supported'.format(metric))
eval_results = {}
@@ -341,42 +342,57 @@ def evaluate(self,
metric,
label_map=self.label_map,
reduce_zero_label=self.reduce_zero_label)
class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]

if self.CLASSES is None:
class_names = tuple(range(num_classes))
else:
class_names = self.CLASSES
ret_metrics_round = [
np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
]
for i in range(num_classes):
class_table_data.append([class_names[i]] +
[m[i] for m in ret_metrics_round[2:]] +
[ret_metrics_round[1][i]])
summary_table_data = [['Scope'] +
['m' + head
for head in class_table_data[0][1:]] + ['aAcc']]
ret_metrics_mean = [
np.round(np.nanmean(ret_metric) * 100, 2)
for ret_metric in ret_metrics
]
summary_table_data.append(['global'] + ret_metrics_mean[2:] +
[ret_metrics_mean[1]] +
[ret_metrics_mean[0]])

# summary table
ret_metrics_summary = OrderedDict({
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})

# each class table
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)

# for logger
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)

summary_table_data = PrettyTable()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
summary_table_data.add_column(key, [val])
else:
summary_table_data.add_column('m' + key, [val])

print_log('per class results:', logger)
table = AsciiTable(class_table_data)
print_log('\n' + table.table, logger=logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
print_log('Summary:', logger)
table = AsciiTable(summary_table_data)
print_log('\n' + table.table, logger=logger)

for i in range(1, len(summary_table_data[0])):
eval_results[summary_table_data[0]
[i]] = summary_table_data[1][i] / 100.0
for idx, sub_metric in enumerate(class_table_data[0][1:], 1):
for item in class_table_data[1:]:
eval_results[str(sub_metric) + '.' +
str(item[0])] = item[idx] / 100.0
print_log('\n' + summary_table_data.get_string(), logger=logger)

# each metric dict
for key, value in ret_metrics_summary.items():
if key == 'aAcc':
eval_results[key] = value / 100.0
else:
eval_results['m' + key] = value / 100.0

ret_metrics_class.pop('Class', None)
for key, value in ret_metrics_class.items():
eval_results.update({
key + '.' + str(name): value[idx] / 100.0
for idx, name in enumerate(class_names)
})

if mmcv.is_list_of(results, str):
for file_name in results:
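To illustrate the switch from AsciiTable to PrettyTable in evaluate(), a standalone hedged sketch of the per-class table construction; the class names and numbers are invented, not taken from any dataset:

```python
# Standalone illustration of the PrettyTable layout used in evaluate() above.
# Values are invented; the real tables are filled from eval_metrics output.
import numpy as np
from prettytable import PrettyTable

ret_metrics_class = {
    'Class': ('road', 'car'),
    'IoU': np.array([92.1, 81.4]),
    'Acc': np.array([96.3, 88.7]),
}

class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
    class_table_data.add_column(key, val)

print(class_table_data.get_string())
```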
2 changes: 1 addition & 1 deletion requirements/runtime.txt
@@ -1,3 +1,3 @@
matplotlib
numpy
terminaltables
prettytable
2 changes: 1 addition & 1 deletion setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,seaborn,terminaltables,torch
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
14 changes: 12 additions & 2 deletions tests/test_data/test_dataset.py
@@ -159,7 +159,7 @@ def test_custom_dataset():
for gt_seg_map in gt_seg_maps:
h, w = gt_seg_map.shape
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
@@ -193,13 +193,23 @@ def test_custom_dataset():
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results

eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
assert isinstance(eval_results, dict)
assert 'mRecall' in eval_results
assert 'mPrecision' in eval_results
assert 'mFscore' in eval_results
assert 'aAcc' in eval_results

eval_results = train_dataset.evaluate(
pseudo_results, metric=['mIoU', 'mDice'])
pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mDice' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
assert 'mFscore' in eval_results
assert 'mPrecision' in eval_results
assert 'mRecall' in eval_results


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)