
Commit e16e0e3

add metric mFscore (#509)

* add mFscore and refactor the metrics return value
* fix linting
* some docstring and name fix

1 parent cf2cb54 commit e16e0e3

File tree: 7 files changed (+318, -85 lines)

mmseg/core/evaluation/__init__.py (+3, -3)
@@ -1,8 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .metrics import eval_metrics, mean_dice, mean_iou
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou

 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
-    'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+    'eval_metrics', 'get_classes', 'get_palette'
 ]
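With this change, mean_fscore joins the package's public evaluation API. A minimal import sketch, assuming an installed mmsegmentation build that contains this commit:

# Minimal sketch: the evaluation entry points exposed after this commit.
# Assumes an installed mmsegmentation that includes this change.
from mmseg.core.evaluation import (eval_metrics, mean_dice, mean_fscore,
                                   mean_iou)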

mmseg/core/evaluation/metrics.py (+106, -23)
@@ -1,8 +1,27 @@
+from collections import OrderedDict
+
 import mmcv
 import numpy as np
 import torch


+def f_score(precision, recall, beta=1):
+    """Calculate the f-score value.
+
+    Args:
+        precision (float | torch.Tensor): The precision value.
+        recall (float | torch.Tensor): The recall value.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        [torch.Tensor]: The f-score value.
+    """
+    score = (1 + beta**2) * (precision * recall) / (
+        (beta**2 * precision) + recall)
+    return score
+
+
 def intersect_and_union(pred_label,
                         label,
                         num_classes,
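As a quick sanity check of the new helper (a usage sketch, not part of the commit): with beta=1 the formula reduces to the harmonic mean of precision and recall, and a larger beta weights recall more heavily.

# Usage sketch for the f_score helper added above; the input values are
# arbitrary illustration numbers.
import torch

from mmseg.core.evaluation.metrics import f_score

precision = torch.tensor(0.8)
recall = torch.tensor(0.6)

f1 = f_score(precision, recall, beta=1)  # (2 * 0.48) / 1.4 ~= 0.686
f2 = f_score(precision, recall, beta=2)  # (5 * 0.48) / 3.8 ~= 0.632, recall weighted more
print(float(f1), float(f2))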
@@ -133,11 +152,12 @@ def mean_iou(results,
         reduce_zero_label (bool): Wether ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category IoU, shape (num_classes, ).
+        dict[str, float | ndarray]:
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <IoU> ndarray: Per category IoU, shape (num_classes, ).
     """
-    all_acc, acc, iou = eval_metrics(
+    iou_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -146,7 +166,7 @@ def mean_iou(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, iou
+    return iou_result


 def mean_dice(results,
@@ -171,12 +191,13 @@ def mean_dice(results,
         reduce_zero_label (bool): Wether ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category dice, shape (num_classes, ).
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <Dice> ndarray: Per category dice, shape (num_classes, ).
     """

-    all_acc, acc, dice = eval_metrics(
+    dice_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -185,7 +206,52 @@ def mean_dice(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, dice
+    return dice_result
+
+
+def mean_fscore(results,
+                gt_seg_maps,
+                num_classes,
+                ignore_index,
+                nan_to_num=None,
+                label_map=dict(),
+                reduce_zero_label=False,
+                beta=1):
+    """Calculate Mean F-Score (mFscore)
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
+            <Precision> ndarray: Per category precision, shape (num_classes, ).
+            <Recall> ndarray: Per category recall, shape (num_classes, ).
+    """
+    fscore_result = eval_metrics(
+        results=results,
+        gt_seg_maps=gt_seg_maps,
+        num_classes=num_classes,
+        ignore_index=ignore_index,
+        metrics=['mFscore'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label,
+        beta=beta)
+    return fscore_result


 def eval_metrics(results,
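A usage sketch of the new mean_fscore wrapper (the array sizes and class count below are made-up illustration values): it forwards to eval_metrics with metrics=['mFscore'] and returns a dict of named results rather than a tuple.

# Sketch: calling mean_fscore on random predictions and labels.
import numpy as np

from mmseg.core.evaluation import mean_fscore

num_classes = 4
preds = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]
gts = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]

ret = mean_fscore(preds, gts, num_classes=num_classes, ignore_index=255)
print(list(ret.keys()))   # ['aAcc', 'Fscore', 'Precision', 'Recall']
print(ret['Fscore'])      # per-class f-score, shape (num_classes,)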
@@ -195,7 +261,8 @@ def eval_metrics(results,
                  metrics=['mIoU'],
                  nan_to_num=None,
                  label_map=dict(),
-                 reduce_zero_label=False):
+                 reduce_zero_label=False,
+                 beta=1):
     """Calculate evaluation metrics
     Args:
         results (list[ndarray] | list[str]): List of prediction segmentation
@@ -210,13 +277,13 @@ def eval_metrics(results,
         label_map (dict): Mapping old labels to new labels. Default: dict().
         reduce_zero_label (bool): Wether ignore zero label. Default: False.
     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category evaluation metrics, shape (num_classes, ).
+        float: Overall accuracy on all images.
+        ndarray: Per category accuracy, shape (num_classes, ).
+        ndarray: Per category evaluation metrics, shape (num_classes, ).
     """
     if isinstance(metrics, str):
         metrics = [metrics]
-    allowed_metrics = ['mIoU', 'mDice']
+    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
     if not set(metrics).issubset(set(allowed_metrics)):
         raise KeyError('metrics {} is not supported'.format(metrics))

@@ -225,19 +292,35 @@ def eval_metrics(results,
         results, gt_seg_maps, num_classes, ignore_index, label_map,
         reduce_zero_label)
     all_acc = total_area_intersect.sum() / total_area_label.sum()
-    acc = total_area_intersect / total_area_label
-    ret_metrics = [all_acc, acc]
+    ret_metrics = OrderedDict({'aAcc': all_acc})
     for metric in metrics:
         if metric == 'mIoU':
             iou = total_area_intersect / total_area_union
-            ret_metrics.append(iou)
+            acc = total_area_intersect / total_area_label
+            ret_metrics['IoU'] = iou
+            ret_metrics['Acc'] = acc
         elif metric == 'mDice':
             dice = 2 * total_area_intersect / (
                 total_area_pred_label + total_area_label)
-            ret_metrics.append(dice)
-    ret_metrics = [metric.numpy() for metric in ret_metrics]
+            acc = total_area_intersect / total_area_label
+            ret_metrics['Dice'] = dice
+            ret_metrics['Acc'] = acc
+        elif metric == 'mFscore':
+            precision = total_area_intersect / total_area_pred_label
+            recall = total_area_intersect / total_area_label
+            f_value = torch.tensor(
+                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
+            ret_metrics['Fscore'] = f_value
+            ret_metrics['Precision'] = precision
+            ret_metrics['Recall'] = recall
+
+    ret_metrics = {
+        metric: value.numpy()
+        for metric, value in ret_metrics.items()
+    }
     if nan_to_num is not None:
-        ret_metrics = [
-            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
-        ]
+        ret_metrics = OrderedDict({
+            metric: np.nan_to_num(metric_value, nan=nan_to_num)
+            for metric, metric_value in ret_metrics.items()
+        })
     return ret_metrics
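The main behavioural change in eval_metrics itself is the return type: instead of a positional list [all_acc, acc, <metric>, ...], it now returns a mapping from metric name to numpy array, so callers such as CustomDataset.evaluate can look results up by key. A hedged sketch on random illustration data:

# Sketch: the refactored eval_metrics returns name -> ndarray instead of a
# positional list. The data below is random illustration input.
import numpy as np

from mmseg.core.evaluation import eval_metrics

num_classes = 4
preds = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]
gts = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]

ret = eval_metrics(
    results=preds,
    gt_seg_maps=gts,
    num_classes=num_classes,
    ignore_index=255,
    metrics=['mIoU', 'mDice', 'mFscore'],
    beta=1)

# Expected keys: 'aAcc', 'IoU', 'Acc', 'Dice', 'Fscore', 'Precision', 'Recall'
for name, value in ret.items():
    print(name, value)

Old callers that unpacked all_acc, acc, iou positionally now need the dict-style access shown above; the mean_iou, mean_dice and mean_fscore wrappers return the same kind of dict.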

mmseg/datasets/custom.py (+50, -34)
@@ -1,11 +1,12 @@
 import os
 import os.path as osp
+from collections import OrderedDict
 from functools import reduce

 import mmcv
 import numpy as np
 from mmcv.utils import print_log
-from terminaltables import AsciiTable
+from prettytable import PrettyTable
 from torch.utils.data import Dataset

 from mmseg.core import eval_metrics
@@ -312,8 +313,8 @@ def evaluate(self,

         Args:
             results (list): Testing results of the dataset.
-            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
-                'mDice' are supported.
+            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+                'mDice' and 'mFscore' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.
@@ -323,7 +324,7 @@ def evaluate(self,

         if isinstance(metric, str):
             metric = [metric]
-        allowed_metrics = ['mIoU', 'mDice']
+        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
         if not set(metric).issubset(set(allowed_metrics)):
             raise KeyError('metric {} is not supported'.format(metric))
         eval_results = {}
@@ -341,42 +342,57 @@ def evaluate(self,
             metric,
             label_map=self.label_map,
             reduce_zero_label=self.reduce_zero_label)
-        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
+
         if self.CLASSES is None:
             class_names = tuple(range(num_classes))
         else:
             class_names = self.CLASSES
-        ret_metrics_round = [
-            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
-        ]
-        for i in range(num_classes):
-            class_table_data.append([class_names[i]] +
-                                    [m[i] for m in ret_metrics_round[2:]] +
-                                    [ret_metrics_round[1][i]])
-        summary_table_data = [['Scope'] +
-                              ['m' + head
-                               for head in class_table_data[0][1:]] + ['aAcc']]
-        ret_metrics_mean = [
-            np.round(np.nanmean(ret_metric) * 100, 2)
-            for ret_metric in ret_metrics
-        ]
-        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
-                                  [ret_metrics_mean[1]] +
-                                  [ret_metrics_mean[0]])
+
+        # summary table
+        ret_metrics_summary = OrderedDict({
+            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+
+        # each class table
+        ret_metrics.pop('aAcc', None)
+        ret_metrics_class = OrderedDict({
+            ret_metric: np.round(ret_metric_value * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+        ret_metrics_class.update({'Class': class_names})
+        ret_metrics_class.move_to_end('Class', last=False)
+
+        # for logger
+        class_table_data = PrettyTable()
+        for key, val in ret_metrics_class.items():
+            class_table_data.add_column(key, val)
+
+        summary_table_data = PrettyTable()
+        for key, val in ret_metrics_summary.items():
+            if key == 'aAcc':
+                summary_table_data.add_column(key, [val])
+            else:
+                summary_table_data.add_column('m' + key, [val])
+
         print_log('per class results:', logger)
-        table = AsciiTable(class_table_data)
-        print_log('\n' + table.table, logger=logger)
+        print_log('\n' + class_table_data.get_string(), logger=logger)
         print_log('Summary:', logger)
-        table = AsciiTable(summary_table_data)
-        print_log('\n' + table.table, logger=logger)
-
-        for i in range(1, len(summary_table_data[0])):
-            eval_results[summary_table_data[0]
-                         [i]] = summary_table_data[1][i] / 100.0
-        for idx, sub_metric in enumerate(class_table_data[0][1:], 1):
-            for item in class_table_data[1:]:
-                eval_results[str(sub_metric) + '.' +
-                             str(item[0])] = item[idx] / 100.0
+        print_log('\n' + summary_table_data.get_string(), logger=logger)
+
+        # each metric dict
+        for key, value in ret_metrics_summary.items():
+            if key == 'aAcc':
+                eval_results[key] = value / 100.0
+            else:
+                eval_results['m' + key] = value / 100.0
+
+        ret_metrics_class.pop('Class', None)
+        for key, value in ret_metrics_class.items():
+            eval_results.update({
+                key + '.' + str(name): value[idx] / 100.0
+                for idx, name in enumerate(class_names)
+            })

         if mmcv.is_list_of(results, str):
             for file_name in results:
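For reference, a sketch of the new logging pattern in CustomDataset.evaluate: each metric becomes one PrettyTable column built from the per-class dict (the class names and values below are made up for illustration). The flattened eval_results dict then carries summary keys such as 'aAcc', 'mIoU' or 'mFscore' plus per-class keys of the form '<Metric>.<class_name>'.

# Sketch of the PrettyTable formatting used above; class names and metric
# values are invented illustration data.
from collections import OrderedDict

import numpy as np
from prettytable import PrettyTable

ret_metrics_class = OrderedDict({
    'Class': ('road', 'car'),
    'IoU': np.array([92.1, 80.4]),
    'Acc': np.array([96.3, 88.7]),
})

class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
    class_table_data.add_column(key, val)

# Prints an ASCII table with one row per class and one column per metric.
print(class_table_data.get_string())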

requirements/runtime.txt (+1, -1)
@@ -1,3 +1,3 @@
 matplotlib
 numpy
-terminaltables
+prettytable

setup.cfg (+1, -1)
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,seaborn,terminaltables,torch
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

tests/test_data/test_dataset.py (+12, -2)
@@ -159,7 +159,7 @@ def test_custom_dataset():
     for gt_seg_map in gt_seg_maps:
         h, w = gt_seg_map.shape
         pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
-    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
+    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mAcc' in eval_results
@@ -193,13 +193,23 @@ def test_custom_dataset():
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results

+    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
+    assert isinstance(eval_results, dict)
+    assert 'mRecall' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'aAcc' in eval_results
+
     eval_results = train_dataset.evaluate(
-        pseudo_results, metric=['mIoU', 'mDice'])
+        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mDice' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mRecall' in eval_results


 @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
