Commit
feat(mmeval/classification): add accuracy metric
Showing 8 changed files with 737 additions and 3 deletions.
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmeval import core
from .version import __version__  # noqa: F401
# flake8: noqa

__all__ = ['core', '__version__']
from .classification import *
from .core import *
from .version import __version__
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .accuracy import Accuracy

__all__ = ['Accuracy']
@@ -0,0 +1,268 @@
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
from typing import Dict, List, Optional, Sequence, Tuple, Union, overload

from mmeval.core.base_metric import BaseMetric
from mmeval.core.dispatcher import dispatch

try:
    import torch
except ImportError:
    torch = None


def _numpy_topk(inputs: np.ndarray,
                k: int,
                axis: Optional[int] = None) -> Tuple:
    """A NumPy implementation of top-k.

    This implementation returns the values and indices of the k largest
    elements along a given axis.

    Note:
        If PyTorch is available, `torch.topk` will be used instead.

    Args:
        inputs (numpy.ndarray): The input numpy array.
        k (int): The k in `top-k`.
        axis (int, optional): The axis to sort along.

    Returns:
        tuple: The values and indices of the k largest elements.
    """
    if torch is not None:
        values, indices = torch.from_numpy(inputs).topk(k, dim=axis)
        return values.numpy(), indices.numpy()

    # Negate the inputs so that an ascending argsort yields the indices of
    # the k largest elements.
    indices = np.argsort(-inputs, axis=axis)
    indices = np.take(indices, np.arange(k), axis=axis)
    values = np.take_along_axis(inputs, indices, axis=axis)
    return values, indices
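
# Illustrative usage sketch (editorial addition, not part of the committed
# file): for a (2, 4) score matrix, top-2 along axis 1 returns the two
# largest scores per row together with their column indices.
#
# >>> scores = np.array([[0.1, 0.7, 0.2, 0.0],
# ...                    [0.3, 0.1, 0.4, 0.2]])
# >>> values, indices = _numpy_topk(scores, k=2, axis=1)
# >>> values.tolist()
# [[0.7, 0.2], [0.4, 0.3]]
# >>> indices.tolist()
# [[1, 2], [2, 0]]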

NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.int64], np.int64]
TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor']


class Accuracy(BaseMetric):
    """Top-k accuracy evaluation metric.

    This metric computes the accuracy based on the given topk and thresholds.

    Currently, there are two implementations of this metric: NumPy and
    PyTorch. Which implementation to use is determined by the type of the
    calling parameters, e.g. `numpy.ndarray` or `torch.Tensor`.

    Args:
        topk (int | Sequence[int]): If one of the best ``topk`` predictions
            matches the target, the prediction is regarded as correct.
            Defaults to 1.
        thrs (Sequence[float | None] | float | None): Predictions with scores
            under the thresholds are considered negative. None means no
            thresholds. Defaults to 0.

    Examples:

        >>> from mmeval import Accuracy
        >>> accuracy = Accuracy()

        Use the NumPy implementation:

        >>> import numpy as np
        >>> labels = np.asarray([0, 1, 2, 3])
        >>> preds = np.asarray([0, 2, 1, 3])
        >>> accuracy(preds, labels)
        {'top1': 0.5}

        Use the PyTorch implementation:

        >>> import torch
        >>> labels = torch.Tensor([0, 1, 2, 3])
        >>> preds = torch.Tensor([0, 2, 1, 3])
        >>> accuracy(preds, labels)
        {'top1': 0.5}

        Compute top-k accuracy with specified thresholds:

        >>> labels = np.asarray([0, 1, 2, 3])
        >>> preds = np.asarray([
        ...     [0.7, 0.1, 0.1, 0.1],
        ...     [0.1, 0.3, 0.4, 0.2],
        ...     [0.3, 0.4, 0.2, 0.1],
        ...     [0.0, 0.0, 0.1, 0.9]])
        >>> accuracy = Accuracy(topk=(1, 2, 3))
        >>> accuracy(preds, labels)
        {'top1': 0.5, 'top2': 0.75, 'top3': 1.0}
        >>> accuracy = Accuracy(topk=2, thrs=(0.1, 0.5))
        >>> accuracy(preds, labels)
        {'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5}

        Accumulate batches:

        >>> for i in range(10):
        ...     labels = torch.randint(0, 4, size=(100, ))
        ...     predicts = torch.randint(0, 4, size=(100, ))
        ...     accuracy.add(predicts, labels)
        >>> accuracy.compute()  # doctest: +SKIP
    """

    def __init__(self,
                 topk: Union[int, Sequence[int]] = (1, ),
                 thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        if isinstance(topk, int):
            self.topk = (topk, )
        else:
            self.topk = tuple(topk)  # type: ignore

        if isinstance(thrs, float) or thrs is None:
            self.thrs = (thrs, )
        else:
            self.thrs = tuple(thrs)  # type: ignore
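
    # Illustrative sketch (editorial addition, not part of the committed
    # file): both ``topk`` and ``thrs`` are normalized to tuples, so a scalar
    # argument behaves like a one-element sequence.
    #
    # >>> Accuracy(topk=2, thrs=0.1).topk
    # (2,)
    # >>> Accuracy(topk=2, thrs=0.1).thrs
    # (0.1,)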

    def add(self, predictions: Sequence, labels: Sequence) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add the intermediate results to `self._results`.

        Args:
            predictions (Sequence): Predictions from the model. They can be
                labels (N, ), or class scores (N, C).
            labels (Sequence): The ground truth labels. They should be (N, ).
        """
        # FIXME: Needing refactor or cleanup @yancong at 9/1/2022, 1:47:59 PM
        # Instead of directly storing all predicted values and labels in
        # `self._results`, we should use more appropriate intermediate
        # results, e.g. a confusion matrix.
        for pred, label in zip(predictions, labels):
            self._results.append((pred, label))

    def _format_metric_results(self, results_per_topk: List[List]) -> Dict:
        """Format the given metric results into a dictionary.

        Args:
            results_per_topk (list): A list of accuracy results, one entry
                per ``topk`` value, each holding one result per threshold.

        Returns:
            dict: The formatted dictionary.
        """
        metric_results = {}
        for k, result_per_topk in zip(self.topk, results_per_topk):
            for thr, result_per_thr in zip(self.thrs, result_per_topk):
                name = f'top{k}'
                if len(self.thrs) > 1:
                    name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
                metric_results[name] = result_per_thr
        return metric_results
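
    # Illustrative sketch (editorial addition, not part of the committed
    # file): the key layout produced above. With a single threshold the keys
    # are plain 'top{k}'; with multiple thresholds each key also encodes the
    # threshold.
    #
    #   topk=(1, 2), thrs=(0.,)       -> {'top1': ..., 'top2': ...}
    #   topk=(1, 2), thrs=(0.1, None) -> {'top1_thr-0.10': ..., 'top1_no-thr': ...,
    #                                     'top2_thr-0.10': ..., 'top2_no-thr': ...}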

    @overload  # type: ignore
    @dispatch
    def _compute_metric(self,
                        results: List[TORCH_IMPL_HINTS]) -> Dict[str, float]:
        """A PyTorch implementation that computes the accuracy metric."""
        # Stack the intermediate results collected across all ranks.
        labels = torch.stack([res[1] for res in results])
        predictions = torch.stack([res[0] for res in results])
        total_length = labels.size(0)

        # In the case where the prediction is a label (N, ), the accuracy is
        # calculated directly without considering `topk` and `thrs`.
        if predictions.ndim == 1:
            correct = (predictions.int() == labels).sum(0, keepdim=True)
            acc = correct.float() / total_length
            return {'top1': acc.item()}

        # Compute the max topk.
        maxk = max(self.topk)
        # NOTE: torch.topk is non-deterministic with duplicate values.
        # See: https://github.com/pytorch/pytorch/issues/27542
        pred_score, pred_label = predictions.topk(maxk, dim=1)
        pred_label = pred_label.t()

        # Broadcast `labels` to the shape of `pred_label` and then compute
        # the correct tensor.
        correct = (pred_label == labels.view(1, -1).expand_as(pred_label))

        # Compute the accuracy corresponding to all topk and thrs.
        results_per_topk = []
        for k in self.topk:
            results_per_thr = []
            for thr in self.thrs:
                # Only prediction scores larger than thr are counted as
                # correct.
                if thr is not None:
                    thr_correct = correct & (pred_score.t() > thr)
                else:
                    thr_correct = correct
                topk_thr_correct = thr_correct[:k].reshape(-1).sum(
                    0, keepdim=True)
                acc = topk_thr_correct.float() / total_length
                results_per_thr.append(acc.item())
            results_per_topk.append(results_per_thr)

        return self._format_metric_results(results_per_topk)
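
    # Illustrative shape trace (editorial addition, not part of the committed
    # file) for the score-matrix branch above, with N samples, C classes and
    # maxk = max(self.topk):
    #
    #   predictions        : (N, C)
    #   pred_score/label   : (N, maxk), then transposed to (maxk, N)
    #   correct            : (maxk, N) boolean; row i marks samples whose
    #                        label equals their (i+1)-th ranked prediction
    #   thr_correct[:k]    : keeps the k best rows, so the sum counts samples
    #                        whose label is among the top-k predictions with
    #                        score above thr.
    #
    # For the 4-sample example in the class docstring with topk=(1, 2, 3)
    # this yields accuracies 0.5, 0.75 and 1.0.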

    @dispatch
    def _compute_metric(self,
                        results: List[NUMPY_IMPL_HINTS]) -> Dict[str, float]:
        """A NumPy implementation that computes the accuracy metric."""
        # Stack the intermediate results collected across all ranks.
        labels = np.stack([res[1] for res in results])
        predictions = np.stack([res[0] for res in results])
        total_length = labels.size

        # In the case where the prediction is a label (N, ), the accuracy is
        # calculated directly without considering `topk` and `thrs`.
        if predictions.ndim == 1:
            predictions = predictions.astype(np.int32)
            correct = (predictions == labels).sum(0, keepdims=True)
            acc = correct / total_length
            return {'top1': float(acc)}

        # Compute the max topk.
        maxk = max(self.topk)
        pred_score, pred_label = _numpy_topk(predictions, maxk, 1)
        pred_label = pred_label.T

        # Broadcast `labels` to the shape of `pred_label`.
        labels = np.broadcast_to(labels.reshape(1, -1), pred_label.shape)
        # Compute the correct tensor.
        correct = (pred_label == labels)

        # Compute the accuracy corresponding to all topk and thrs.
        results_per_topk = []
        for k in self.topk:
            results_per_thr = []
            for thr in self.thrs:
                # Only scores greater than thr are counted as correct.
                if thr is not None:
                    thr_correct = correct & (pred_score.T > thr)
                else:
                    thr_correct = correct
                topk_thr_correct = thr_correct[:k].reshape(-1).sum(
                    0, keepdims=True)
                acc = topk_thr_correct / total_length
                results_per_thr.append(float(acc))
            results_per_topk.append(results_per_thr)
        return self._format_metric_results(results_per_topk)

    def compute_metric(
        self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]]
    ) -> Dict[str, float]:
        """Compute the accuracy metric.

        Currently, there are two implementations of this method: NumPy and
        PyTorch. Which implementation to use is determined by the type of the
        calling parameters, e.g. `numpy.ndarray` or `torch.Tensor`.

        This method would be invoked in `BaseMetric.compute` after distributed
        synchronization.

        Args:
            results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]]): A list
                of tuples, each consisting of a prediction and a label. This
                list has already been synced across all ranks.

        Returns:
            Dict[str, float]: The computed accuracy metric.
        """
        return self._compute_metric(results)
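
Taken together, the metric can either be called directly on a full batch of
predictions or fed batch by batch via `add` and finalized with `compute`. A
minimal usage sketch (editorial addition, not part of this diff), reusing the
NumPy inputs from the class docstring:

    import numpy as np
    from mmeval import Accuracy

    preds = np.asarray([[0.7, 0.1, 0.1, 0.1],
                        [0.1, 0.3, 0.4, 0.2],
                        [0.3, 0.4, 0.2, 0.1],
                        [0.0, 0.0, 0.1, 0.9]])
    labels = np.asarray([0, 1, 2, 3])

    # One-shot evaluation: the argument types select the NumPy implementation.
    accuracy = Accuracy(topk=(1, 2))
    print(accuracy(preds, labels))  # {'top1': 0.5, 'top2': 0.75}

    # Batch-wise accumulation: add() stores (prediction, label) pairs and
    # compute() invokes compute_metric() after any distributed sync.
    accuracy = Accuracy(topk=(1, 2))
    for pred, label in zip(preds, labels):
        accuracy.add([pred], [label])
    print(accuracy.compute())       # same result as above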
@@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.

# yapf: disable

import numpy as np
import pytest

from mmeval.classification.accuracy import Accuracy
from mmeval.core.base_metric import BaseMetric

try:
    import torch
except ImportError:
    torch = None


@pytest.mark.parametrize(
    argnames='metric_kwargs',
    argvalues=[
        {},
        {'topk': 1, 'thrs': 0.1},
        {'topk': (1, 2), 'thrs': (0.1, 0.2)},
    ]
)
def test_metric_interface(metric_kwargs):
    accuracy = Accuracy(**metric_kwargs)
    assert isinstance(accuracy, BaseMetric)
    assert isinstance(accuracy.topk, tuple)
    assert isinstance(accuracy.thrs, tuple)

    results = accuracy(np.asarray([1, 2, 3]), np.asarray([3, 2, 1]))
    assert isinstance(results, dict)
    results = accuracy(
        np.asarray([[0.1, 0.9], [0.5, 0.5]]), np.asarray([0, 1]))
    assert isinstance(results, dict)


@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
def test_metric_interface_torch():
    accuracy = Accuracy()
    results = accuracy(torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1]))
    assert isinstance(results, dict)
    results = accuracy(
        torch.Tensor([[0.1, 0.9], [0.5, 0.5]]), torch.Tensor([0, 1]))
    assert isinstance(results, dict)


@pytest.mark.parametrize(
    argnames=['metric_kwargs', 'predictions', 'labels', 'results'],
    argvalues=[
        ({}, [0, 2, 1, 3], [0, 1, 2, 3], {'top1': 0.5}),
        (
            {'topk': (1, 2, 3)},
            [
                [0.7, 0.1, 0.1, 0.1],
                [0.1, 0.3, 0.4, 0.2],
                [0.3, 0.4, 0.2, 0.1],
                [0.0, 0.0, 0.1, 0.9]
            ],
            [0, 1, 2, 3],
            {'top1': 0.5, 'top2': 0.75, 'top3': 1.}
        ),
        (
            {'topk': 2, 'thrs': (0.1, 0.5)},
            [
                [0.7, 0.1, 0.1, 0.1],
                [0.1, 0.3, 0.4, 0.2],
                [0.3, 0.4, 0.2, 0.1],
                [0.0, 0.0, 0.1, 0.9]
            ],
            [0, 1, 2, 3],
            {'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5}
        )
    ]
)
def test_metric_accurate(metric_kwargs, predictions, labels, results):
    accuracy = Accuracy(**metric_kwargs)
    assert accuracy(np.asarray(predictions), np.asarray(labels)) == results


@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
@pytest.mark.parametrize(
    argnames=('metric_kwargs', 'classes_num', 'length'),
    argvalues=[
        ({}, 100, 100),
        ({'topk': 1, 'thrs': 0.1}, 1000, 100),
        ({'topk': (1, 2, 3), 'thrs': (0.1, 0.2)}, 1000, 10000),
        ({'topk': (1, 2, 3), 'thrs': (0.1, None)}, 999, 10002)
    ]
)
def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length):
    """Metamorphic testing for the NumPy and PyTorch implementations."""
    accuracy = Accuracy(**metric_kwargs)

    predictions = np.random.rand(length, classes_num)
    labels = np.random.randint(0, classes_num, length)

    np_acc_results = accuracy(predictions, labels)

    predictions = torch.from_numpy(predictions)
    labels = torch.from_numpy(labels)
    torch_acc_results = accuracy(predictions, labels)

    assert np_acc_results.keys() == torch_acc_results.keys()

    for key in np_acc_results:
        np.testing.assert_allclose(np_acc_results[key], torch_acc_results[key])


if __name__ == '__main__':
    pytest.main([__file__, '-v', '--capture=no'])
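
The tests can also be run through pytest directly, e.g. to execute only the
metamorphic NumPy/PyTorch comparison (the test file path below is assumed; it
is not shown in this diff):

    pytest -v -k test_metamorphic_numpy_pytorch tests/test_accuracy.py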