-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(mmeval/segmentation): add MeanIoU
- Loading branch information
Showing
5 changed files
with
469 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from mmeval import core | ||
from .version import __version__ # noqa: F401 | ||
# flake8: noqa | ||
|
||
__all__ = ['core', '__version__'] | ||
from .core import * | ||
from .segmentation import * | ||
from .version import __version__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from .mean_iou import MeanIoU | ||
|
||
__all__ = ['MeanIoU'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,348 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import numpy as np | ||
from typing import (Dict, List, Optional, Sequence, Union, no_type_check, | ||
overload) | ||
|
||
from mmeval.core.base_metric import BaseMetric | ||
from mmeval.core.dispatcher import dispatch | ||
|
||
try: | ||
import torch | ||
except ImportError: | ||
torch = None | ||
|
||
|
||
class MeanIoU(BaseMetric): | ||
"""MeanIoU evaluation metric. | ||
MeanIou is a widely used evaluation metric for image semantic segmentation. | ||
In addition to mean iou, it will also compute and return accuracy, mean | ||
accuracy, mean dice, mean precision, mean recall and mean f-score. | ||
Currently, there are 2 implementations of this metric: NumPy and PyTorch. | ||
Which implementation to use is determined by the type of the calling | ||
parameters. e.g. `numpy.ndarray` or `torch.Tensor`. | ||
Args: | ||
num_classes (int, optional): The number of classes. If None, it will be | ||
obtained from the 'num_classes' or 'classes' field in | ||
`self.dataset_meta`. Defaults to None. | ||
ignore_index (int, optional): Index that will be ignored in evaluation. | ||
Defaults to 255. | ||
nan_to_num (int, optional): If specified, NaN values will be replaced | ||
by the numbers defined by the user. Defaults to None. | ||
beta (int, optional): Determines the weight of recall in the F-score. | ||
Defaults to 1. | ||
verbose_results (bool, optional): If True, the metric results per class | ||
are added to the returned results. Defaults to False. | ||
Examples: | ||
>>> from mmeval import MeanIoU | ||
>>> miou = MeanIoU(num_classes=4) | ||
Use NumPy implementation: | ||
>>> import numpy as np | ||
>>> labels = np.asarray([[0, 1, 1], [2, 3, 2]]) | ||
>>> preds = np.asarray([[0, 2, 1], [1, 3, 2]]) | ||
>>> miou(preds, labels) | ||
{'aAcc': 0.6666666666666666, 'mIoU': 0.6666666666666666, 'mAcc': 0.75, | ||
'mDice': 0.75, 'mPrecision': 0.75, 'mRecall': 0.75, 'mFscore': 0.75} | ||
Use PyTorch implementation: | ||
>>> import torch | ||
>>> labels = torch.Tensor([[0, 1, 1], [2, 3, 2]]) | ||
>>> preds = torch.Tensor([[0, 2, 1], [1, 3, 2]]) | ||
>>> miou(preds, labels) | ||
{'aAcc': 0.6666666666666666, 'mIoU': 0.6666666666666666, 'mAcc': 0.75, | ||
'mDice': 0.75, 'mPrecision': 0.75, 'mRecall': 0.75, 'mFscore': 0.75} | ||
Accumulate batch: | ||
>>> for i in range(10): | ||
... labels = torch.randint(0, 4, size=(100, 100)) | ||
... predicts = torch.randint(0, 4, size=(100, 100)) | ||
... miou.add(predicts, labels) | ||
>>> miou.compute() # doctest: +SKIP | ||
""" | ||
|
||
def __init__(self, | ||
num_classes: Optional[int] = None, | ||
ignore_index: int = 255, | ||
nan_to_num: Optional[int] = None, | ||
beta: int = 1, | ||
verbose_results: bool = False, | ||
**kwargs) -> None: | ||
super().__init__(**kwargs) | ||
|
||
self._num_classes = num_classes | ||
self.ignore_index = ignore_index | ||
self.nan_to_num = nan_to_num | ||
self.beta = beta | ||
self.verbose_results = verbose_results | ||
|
||
@property | ||
def num_classes(self) -> int: | ||
"""Returns the number of classes. | ||
The number of classes should be set during initialization, otherwise it | ||
will be obtained from the 'classes' or 'num_classes' field in | ||
`self.dataset_meta`. | ||
Raises: | ||
RuntimeError: If the num_classes is not set. | ||
Returns: | ||
int: The number of classes. | ||
""" | ||
if self._num_classes is not None: | ||
return self._num_classes | ||
if self.dataset_meta and 'num_classes' in self.dataset_meta: | ||
self._num_classes = self.dataset_meta['num_classes'] | ||
elif self.dataset_meta and 'classes' in self.dataset_meta: | ||
self._num_classes = len(self.dataset_meta['classes']) | ||
else: | ||
raise RuntimeError( | ||
'The `num_claases` is required, and not found in ' | ||
f'dataset_meta: {self.dataset_meta}') | ||
return self._num_classes | ||
|
||
def add(self, predictions: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501 | ||
"""Process one batch of data and predictions. | ||
Calculate the confusion matrix from the inputs and update the total | ||
confusion matrix stored in `self._results[0]`. | ||
Args: | ||
data_batch (Sequence): A batch of data from the dataloader. | ||
predictions (Sequence): A batch of outputs from the model. | ||
""" | ||
cm = self.compute_confusion_matrix(predictions, labels, | ||
self.num_classes) | ||
# update the total confusion matrix stored in `self._results[0]` | ||
if len(self._results) == 0: | ||
self._results.append(cm) | ||
else: | ||
# Cumulative the confusion matrix. | ||
total_cm = self._results[0] | ||
self._results[0] = total_cm + cm | ||
|
||
@overload # type: ignore | ||
@dispatch | ||
def compute_confusion_matrix(self, predictions: np.ndarray, | ||
labels: np.ndarray, | ||
num_classes: int) -> np.ndarray: | ||
"""Computing confusion matrix with NumPy. | ||
Args: | ||
predictions (numpy.ndarray): The predicition. | ||
labels (numpy.ndarray): The ground truth. | ||
num_classes (int): The number of classes. | ||
Returns: | ||
numpy.ndarray: The confusion matrix. | ||
""" | ||
# IDEA: Possible implementations - @yancong at 8/26/2022, 12:26:05 PM | ||
# Maybe we can implement some general methods for computing confusion | ||
# matrix in `mmeval.functional`. | ||
mask = (labels != self.ignore_index) | ||
predictions, labels = predictions[mask], labels[mask] | ||
confusion_matrix_1d = np.bincount( | ||
num_classes * labels + predictions, minlength=num_classes**2) | ||
return confusion_matrix_1d.reshape(num_classes, num_classes) | ||
|
||
@dispatch | ||
def compute_confusion_matrix(self, predictions: 'torch.Tensor', | ||
labels: 'torch.Tensor', | ||
num_classes: int) -> 'torch.Tensor': | ||
"""Computing confusion matrix with PyTorch. | ||
Args: | ||
predictions (torch.Tensor): The predicition. | ||
labels (torch.Tensor): The ground truth. | ||
num_classes (int): The number of classes. | ||
Returns: | ||
torch.Tensor: The confusion matrix. | ||
""" | ||
mask = (labels != self.ignore_index) | ||
predictions, labels = predictions[mask], labels[mask] | ||
confusion_matrix_1d = torch.bincount( | ||
(num_classes * labels + predictions).long(), | ||
minlength=num_classes**2) | ||
return confusion_matrix_1d.reshape(num_classes, num_classes) | ||
|
||
@no_type_check | ||
def _compute_core(self, intersect: Sequence, row_wise_sum: Sequence, | ||
col_wise_sum: Sequence) -> Dict: | ||
"""Returns the computed metric. | ||
Args: | ||
intersect (Sequence): The intersect area, namely the true positive | ||
in the confusion matrix. | ||
row_wise_sum (Sequence): The row-wise sum of the confusion matrix, | ||
namely the number of each class in the ground truth. | ||
col_wise_sum (Sequence): The col-wise sum of the confusion matrix, | ||
namely the number of each class in the prediction. | ||
Returns: | ||
Dict: The computed class-wise metric, with following keys: | ||
- IoU, the iou per class. | ||
- Acc, the accuracy per class. | ||
- Dice, the dice per class. | ||
- Precision, the precision per class. | ||
- Recall, the recall per class. | ||
- Fscore, the f-score per class. | ||
""" | ||
# compute iou per class | ||
union = col_wise_sum + row_wise_sum - intersect | ||
iou = intersect / union | ||
|
||
# compute accuracy per class | ||
accuracy = intersect / row_wise_sum | ||
|
||
# compute dice per class | ||
dice = 2 * intersect / (col_wise_sum + row_wise_sum) | ||
|
||
# compute precision, recall and f1 score per class | ||
precision = intersect / col_wise_sum | ||
recall = intersect / row_wise_sum | ||
f_score = (1 + self.beta**2) * (precision * recall) / ( | ||
(self.beta**2 * precision) + recall) | ||
|
||
metric_results = { | ||
'IoU': iou, | ||
'Acc': accuracy, | ||
'Dice': dice, | ||
'Precision': precision, | ||
'Recall': recall, | ||
'Fscore': f_score, | ||
} | ||
return metric_results | ||
|
||
@overload # type: ignore | ||
@dispatch | ||
def _compute_metric(self, results: List[np.ndarray]) -> Dict: | ||
"""A NumPy implementation that compute the MeanIoU metric.""" | ||
# Accumulate the confusion matrix of all ranks. | ||
confusion_matrix: np.ndarray = sum(results) | ||
row_wise_sum = confusion_matrix.sum(1) | ||
col_wise_sum = confusion_matrix.sum(0) | ||
intersect = np.diag(confusion_matrix) | ||
|
||
class_wise_results = self._compute_core(intersect, row_wise_sum, | ||
col_wise_sum) | ||
|
||
metric_results = dict() | ||
# Computing overall accuracy. | ||
metric_results['aAcc'] = intersect.sum() / row_wise_sum.sum() | ||
|
||
# Average class-wise metric. | ||
for key, value in class_wise_results.items(): | ||
if self.nan_to_num is not None: | ||
value = np.nan_to_num(value, nan=self.nan_to_num) | ||
# We prefix the averaged metrics with an 'm'. eg. 'IoU' -> 'mIoU' | ||
metric_results['m' + key] = np.nanmean(value) | ||
|
||
# Add the class-wise metric to the returned results. | ||
if self.verbose_results: | ||
metric_results.update(class_wise_results) | ||
return metric_results | ||
|
||
@dispatch | ||
def _compute_metric(self, results: List['torch.Tensor']) -> Dict: | ||
"""A PyTorch implementation that compute the MeanIoU metric.""" | ||
# Accumulate the confusion matrix of all ranks. | ||
confusion_matrix: 'torch.Tensor' = sum(results) | ||
# NOTE: In PyTorch 1.6, integers cannot be directly divided by `/`, | ||
# so we convert confusion matrix to float here. | ||
if torch.__version__.startswith('1.6.'): | ||
confusion_matrix = confusion_matrix.float() | ||
row_wise_sum = confusion_matrix.sum(1) | ||
col_wise_sum = confusion_matrix.sum(0) | ||
intersect = torch.diag(confusion_matrix) | ||
|
||
class_wise_results = self._compute_core(intersect, row_wise_sum, | ||
col_wise_sum) | ||
|
||
metric_results = dict() | ||
# Computing overall accuracy. | ||
metric_results['aAcc'] = (intersect.sum() / row_wise_sum.sum()).item() | ||
|
||
# Average class-wise metric. | ||
for key, value in class_wise_results.items(): | ||
if self.nan_to_num is not None: | ||
value = torch.nan_to_num(value, nan=self.nan_to_num) | ||
# We prefix the averaged metrics with an 'm'. e.g. 'IoU' -> 'mIoU' | ||
# NOTE: For PyTorch version compatibility, | ||
# use `torch.mean(a[~a.isnan()])` instead of `torch.nanmean(a)`. | ||
metric_results['m' + key] = torch.mean( | ||
value[~value.isnan()]).item() | ||
|
||
# Add the class-wise metric to the returned results. | ||
if self.verbose_results: | ||
metric_results.update(class_wise_results) | ||
return metric_results | ||
|
||
def compute_metric( | ||
self, results: List[Union[np.ndarray, 'torch.Tensor']]) -> Dict: | ||
"""Compute the MeanIoU metric. | ||
Currently, there are 2 implementations of this method: NumPy and | ||
PyTorch. Which implementation to use is determined by the type of the | ||
calling parameters. e.g. `numpy.ndarray` or `torch.Tensor`. | ||
This method would be invoked in `BaseMetric.compute` after distributed | ||
synchronization. | ||
Args: | ||
results (list): A list of confusion matrices to be accumulated. | ||
This list has already been synced across all ranks. | ||
Returns: | ||
Dict: The computed metric, with following keys: | ||
- aAcc, the overall accuracy. | ||
- mIoU, the mean iou. | ||
- mAcc, the mean accuracy. | ||
- mDice, the mean dice. | ||
- mPrecision, the mean precision. | ||
- mRecall, the mean recall. | ||
- mFscore, the mean f-score. | ||
- Keys from `self._compute_core` if in verbose results mode. | ||
""" | ||
return self._compute_metric(results) | ||
|
||
|
||
# The code below is temporary for test, will be removed. | ||
|
||
if __name__ == '__main__': | ||
|
||
from _miou import IoUMetric as _IoUMetric | ||
torch.manual_seed(0) | ||
|
||
num_classes = 2 | ||
high, width = (224, 224) | ||
batch_size = 48 | ||
np_miou = MeanIoU(num_classes=num_classes, verbose_results=True) | ||
miou = MeanIoU( | ||
dataset_meta={'classes': [i for i in range(num_classes)]}, | ||
verbose_results=True) | ||
_miou = _IoUMetric( | ||
iou_metrics=['mIoU', 'mDice', 'mFscore'], | ||
dataset_meta={'classes': [i for i in range(num_classes)]}) | ||
|
||
for i in range(10): | ||
labels = torch.randint(0, num_classes, size=(batch_size, high, width)) | ||
predicts = torch.randint( | ||
0, num_classes, size=(batch_size, high, width)) | ||
np_miou.add(predicts.numpy(), labels.numpy()) | ||
miou.add(predicts, labels) | ||
_miou.add(predicts, labels) | ||
|
||
print(miou.compute()) | ||
print(np_miou.compute()) | ||
print(_miou.compute()) |
Oops, something went wrong.