diff --git a/mmeval/__init__.py b/mmeval/__init__.py index 22d1134f..057a63ea 100644 --- a/mmeval/__init__.py +++ b/mmeval/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmeval import core -from .version import __version__ # noqa: F401 +# flake8: noqa -__all__ = ['core', '__version__'] +from .classification import * +from .core import * +from .version import __version__ diff --git a/mmeval/classification/__init__.py b/mmeval/classification/__init__.py new file mode 100644 index 00000000..8055f3bb --- /dev/null +++ b/mmeval/classification/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .accuracy import Accuracy + +__all__ = ['Accuracy'] diff --git a/mmeval/classification/accuracy.py b/mmeval/classification/accuracy.py new file mode 100644 index 00000000..b9f90dab --- /dev/null +++ b/mmeval/classification/accuracy.py @@ -0,0 +1,268 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import numpy as np +from typing import Dict, List, Optional, Sequence, Tuple, Union, overload + +from mmeval.core.base_metric import BaseMetric +from mmeval.core.dispatcher import dispatch + +try: + import torch +except ImportError: + torch = None + + +def _numpy_topk(inputs: np.ndarray, + k: int, + axis: Optional[int] = None) -> Tuple: + """A implementation of numpy top-k. + + This implementation returns the values and indices of the k largest + elements along a given axis. + + Note: + If PyTorch is available, the `torch.topk` would be used. + + Args: + inputs (nump.ndarray): The input numpy array. + k (int): The k in `top-k`. + axis (int, optional): The axis to sort along. + + Returns: + tuple: The values and indices of the k largest elements. + """ + if torch is not None: + values, indices = torch.from_numpy(inputs).topk(k, dim=axis) + return values.numpy(), indices.numpy() + + indices = np.argsort(inputs, axis=axis) + indices = np.take(indices, np.arange(k), axis=axis) + values = np.take_along_axis(inputs, indices, axis=axis) + return values, indices + + +NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.int64], np.int64] +TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor'] + + +class Accuracy(BaseMetric): + """Top-k accuracy evaluation metric. + + This metric computes the accuracy based on the given topk and thresholds. + + Currently, there are 2 implementations of this metric: NumPy and PyTorch. + Which implementation to use is determined by the type of the calling + parameters. e.g. `numpy.ndarray` or `torch.Tensor`. + + Args: + topk (int | Sequence[int]): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thrs (Sequence[float | None] | float | None): Predictions with scores + under the thresholds are considered negative. None means no + thresholds. Defaults to 0. + + Examples: + + >>> from mmeval import Accuracy + >>> accuracy = Accuracy() + + Use NumPy implementation: + + >>> import numpy as np + >>> labels = np.asarray([0, 1, 2, 3]) + >>> preds = np.asarray([0, 2, 1, 3]) + >>> accuracy(preds, labels) + {'top1': 0.5} + + Use PyTorch implementation: + + >>> import torch + >>> labels = torch.Tensor([0, 1, 2, 3]) + >>> preds = torch.Tensor([0, 2, 1, 3]) + >>> accuracy(preds, labels) + {'top1': 0.5} + + Computing top-k accuracy with specified threold: + + >>> labels = np.asarray([0, 1, 2, 3]) + >>> preds = np.asarray([ + [0.7, 0.1, 0.1, 0.1], + [0.1, 0.3, 0.4, 0.2], + [0.3, 0.4, 0.2, 0.1], + [0.0, 0.0, 0.1, 0.9]]) + >>> accuracy = Accuracy(topk=(1, 2, 3)) + >>> accuracy(preds, labels) + {'top1': 0.5, 'top2': 0.75, 'top3': 1.0} + >>> accuracy = Accuracy(topk=2, thrs=(0.1, 0.5)) + >>> accuracy(preds, labels) + {'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5} + + Accumulate batch: + + >>> for i in range(10): + ... labels = torch.randint(0, 4, size=(100, )) + ... predicts = torch.randint(0, 4, size=(100, )) + ... accuracy.add(predicts, labels) + >>> accuracy.compute() # doctest: +SKIP + """ + + def __init__(self, + topk: Union[int, Sequence[int]] = (1, ), + thrs: Union[float, Sequence[Union[float, None]], None] = 0., + **kwargs) -> None: + super().__init__(**kwargs) + + if isinstance(topk, int): + self.topk = (topk, ) + else: + self.topk = tuple(topk) # type: ignore + + if isinstance(thrs, float) or thrs is None: + self.thrs = (thrs, ) + else: + self.thrs = tuple(thrs) # type: ignore + + def add(self, predictions: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501 + """Add the intermediate results to `self._results`. + + Args: + predictions (Sequence): Predictions from the model. It can be + labels (N, ), or scores of every class (N, C). + labels (Sequence): The ground truth labels. It should be (N, ). + """ + # FIXME: Needing refactor or cleanup @yancong at 9/1/2022, 1:47:59 PM + # Instead of directly storing all predicted values and labels in + # `self._results`, we should use more appropriate intermediate results. + # e.g. confusion matrix. + for pred, label in zip(predictions, labels): + self._results.append((pred, label)) + + def _format_metric_results(self, results_per_topk: List[List]) -> Dict: + """Format the given metric results into a dictionary. + + Args: + results_per_topk (list): A list of per topk and thrs accuracy. + Returns: + dict: The formatted dictionary. + """ + metric_results = {} + for k, result_per_topk in zip(self.topk, results_per_topk): + for thr, result_per_thr in zip(self.thrs, result_per_topk): + name = f'top{k}' + if len(self.thrs) > 1: + name += '_no-thr' if thr is None else f'_thr-{thr:.2f}' + metric_results[name] = result_per_thr + return metric_results + + @overload # type: ignore + @dispatch + def _compute_metric(self, + results: List[TORCH_IMPL_HINTS]) -> Dict[str, float]: + """A PyTorch implementation that compute the accuracy metric.""" + # Concatenating the intermediate results arcoss all ranks. + labels = torch.stack([res[1] for res in results]) + predictions = torch.stack([res[0] for res in results]) + total_length = labels.size(0) + + # In the case where the prediction is a label (N, ), the accuracy is + # calculated directly without considering `topk` and `thrs`. + if predictions.ndim == 1: + correct = (predictions.int() == labels).sum(0, keepdim=True) + acc = correct.float() / total_length + return {'top1': acc.item()} + + # compute the max topk + maxk = max(self.topk) + # NOTE: The torch.topk is non-deterministic with duplicates values. + # See: https://github.com/pytorch/pytorch/issues/27542 + pred_score, pred_label = predictions.topk(maxk, dim=1) + pred_label = pred_label.t() + + # Broadcast `labels` to the shape of `pred_label` and then compute + # correct tensor. + correct = (pred_label == labels.view(1, -1).expand_as(pred_label)) + + # compute the accuracy corresponding to all topk and thrs + results_per_topk = [] + for k in self.topk: + results_per_thr = [] + for thr in self.thrs: + # Only prediction socres larger than thr are counted as correct + if thr is not None: + thr_correct = correct & (pred_score.t() > thr) + else: + thr_correct = correct + topk_thr_correct = thr_correct[:k].reshape(-1).sum( + 0, keepdim=True) + acc = topk_thr_correct.float() / total_length + results_per_thr.append(acc.item()) + results_per_topk.append(results_per_thr) + + return self._format_metric_results(results_per_topk) + + @dispatch + def _compute_metric(self, + results: List[NUMPY_IMPL_HINTS]) -> Dict[str, float]: + """A NumPy implementation that compute the accuracy metric.""" + # Concatenating the intermediate results arcoss all ranks. + labels = np.stack([res[1] for res in results]) + predictions = np.stack([res[0] for res in results]) + total_length = labels.size + + # In the case where the prediction is a label (N, ), the accuracy is + # calculated directly without considering `topk` and `thrs`. + if predictions.ndim == 1: + predictions = predictions.astype(np.int32) + correct = (predictions == labels).sum(0, keepdims=True) + acc = correct / total_length + return {'top1': float(acc)} + + # compute the max topk + maxk = max(self.topk) + pred_score, pred_label = _numpy_topk(predictions, maxk, 1) + pred_label = pred_label.T + + # broadcast `labels` to the shape of `pred_label` + labels = np.broadcast_to(labels.reshape(1, -1), pred_label.shape) + # compute correct tensor + correct = (pred_label == labels) + + # compute the accuracy corresponding to all topk and thrs + results_per_topk = [] + for k in self.topk: + results_per_thr = [] + for thr in self.thrs: + # Only socres greater than thr are counted as correct. + if thr is not None: + thr_correct = correct & (pred_score.T > thr) + else: + thr_correct = correct + topk_thr_correct = thr_correct[:k].reshape(-1).sum( + 0, keepdims=True) + acc = topk_thr_correct / total_length + results_per_thr.append(float(acc)) + results_per_topk.append(results_per_thr) + return self._format_metric_results(results_per_topk) + + def compute_metric( + self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]] + ) -> Dict[str, float]: + """Compute the accuracy metric. + + Currently, there are 2 implementations of this method: NumPy and + PyTorch. Which implementation to use is determined by the type of the + calling parameters. e.g. `numpy.ndarray` or `torch.Tensor`. + + This method would be invoked in `BaseMetric.compute` after distributed + synchronization. + + Args: + results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]]): A list + of tuples that consisting the prediction and label. This list + has already been synced across all ranks. + + Returns: + Dict[str, float]: The computed accuracy metric. + """ + return self._compute_metric(results) diff --git a/tests/test_classification/test_accuracy.py b/tests/test_classification/test_accuracy.py new file mode 100644 index 00000000..02118a36 --- /dev/null +++ b/tests/test_classification/test_accuracy.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# yapf: disable + +import numpy as np +import pytest + +from mmeval.classification.accuracy import Accuracy +from mmeval.core.base_metric import BaseMetric + +try: + import torch +except ImportError: + torch = None + + +@pytest.mark.parametrize( + argnames='metric_kwargs', + argvalues=[ + {}, + {'topk': 1, 'thrs': 0.1}, + {'topk': (1, 2), 'thrs': (0.1, 0.2)}, + ] +) +def test_metric_interface(metric_kwargs): + accuracy = Accuracy(**metric_kwargs) + assert isinstance(accuracy, BaseMetric) + assert isinstance(accuracy.topk, tuple) + assert isinstance(accuracy.thrs, tuple) + + results = accuracy(np.asarray([1, 2, 3]), np.asarray([3, 2, 1])) + assert isinstance(results, dict) + results = accuracy( + np.asarray([[0.1, 0.9], [0.5, 0.5]]), np.asarray([0, 1])) + assert isinstance(results, dict) + + +@pytest.mark.skipif(torch is None, reason='PyTorch is not available!') +def test_metric_interface_torch(): + accuracy = Accuracy() + results = accuracy(torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1])) + assert isinstance(results, dict) + results = accuracy( + torch.Tensor([[0.1, 0.9], [0.5, 0.5]]), torch.Tensor([0, 1])) + assert isinstance(results, dict) + + +@pytest.mark.parametrize( + argnames=['metric_kwargs', 'preditions', 'labels', 'results'], + argvalues=[ + ({}, [0, 2, 1, 3], [0, 1, 2, 3], {'top1': 0.5}), + ( + {'topk': (1, 2, 3)}, + [ + [0.7, 0.1, 0.1, 0.1], + [0.1, 0.3, 0.4, 0.2], + [0.3, 0.4, 0.2, 0.1], + [0.0, 0.0, 0.1, 0.9] + ], + [0, 1, 2, 3], + {'top1': 0.5, 'top2': 0.75, 'top3': 1.} + ), + ( + {'topk': 2, 'thrs': (0.1, 0.5)}, + [ + [0.7, 0.1, 0.1, 0.1], + [0.1, 0.3, 0.4, 0.2], + [0.3, 0.4, 0.2, 0.1], + [0.0, 0.0, 0.1, 0.9] + ], + [0, 1, 2, 3], + {'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5} + ) + ] +) +def test_metric_accurate(metric_kwargs, preditions, labels, results): + accuracy = Accuracy(**metric_kwargs) + assert accuracy(np.asarray(preditions), np.asarray(labels)) == results + + +@pytest.mark.skipif(torch is None, reason='PyTorch is not available!') +@pytest.mark.parametrize( + argnames=('metric_kwargs', 'classes_num', 'length'), + argvalues=[ + ({}, 100, 100), + ({'topk': 1, 'thrs': 0.1}, 1000, 100), + ({'topk': (1, 2, 3), 'thrs': (0.1, 0.2)}, 1000, 10000), + ({'topk': (1, 2, 3), 'thrs': (0.1, None)}, 999, 10002) + ] +) +def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length): + """Metamorphic testing for NumPy and PyTorch implementation.""" + accuracy = Accuracy(**metric_kwargs) + + predictions = np.random.rand(length, classes_num) + labels = np.random.randint(0, classes_num, length) + + np_acc_results = accuracy(predictions, labels) + + predictions = torch.from_numpy(predictions) + labels = torch.from_numpy(labels) + torch_acc_results = accuracy(predictions, labels) + + assert np_acc_results.keys() == torch_acc_results.keys() + + for key in np_acc_results: + np.testing.assert_allclose(np_acc_results[key], torch_acc_results[key]) + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--capture=no'])