diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml
index d11dcdc9..1d47e06a 100644
--- a/.pre-commit-config-zh-cn.yaml
+++ b/.pre-commit-config-zh-cn.yaml
@@ -1,7 +1,7 @@
 exclude: ^tests/data/
 repos:
   - repo: https://gitee.com/openmmlab/mirrors-flake8
-    rev: 3.8.3
+    rev: 5.0.4
     hooks:
       - id: flake8
   - repo: https://gitee.com/openmmlab/mirrors-isort
@@ -9,11 +9,11 @@ repos:
     hooks:
       - id: isort
   - repo: https://gitee.com/openmmlab/mirrors-yapf
-    rev: v0.30.0
+    rev: v0.32.0
     hooks:
       - id: yapf
   - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
-    rev: v3.1.0
+    rev: v4.3.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -26,7 +26,7 @@ repos:
       - id: mixed-line-ending
         args: ["--fix=lf"]
   - repo: https://gitee.com/openmmlab/mirrors-codespell
-    rev: v2.1.0
+    rev: v2.2.1
     hooks:
       - id: codespell
   - repo: https://gitee.com/openmmlab/mirrors-mdformat
@@ -44,7 +44,7 @@ repos:
       - id: docformatter
         args: ["--in-place", "--wrap-descriptions", "79"]
   - repo: https://github.com/asottile/pyupgrade
-    rev: v2.32.1
+    rev: v3.0.0
     hooks:
       - id: pyupgrade
         args: ["--py36-plus"]
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f3b2606e..de19e618 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,7 @@
 exclude: ^tests/data/
 repos:
   - repo: https://github.com/PyCQA/flake8
-    rev: 3.8.3
+    rev: 5.0.4
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/isort
@@ -9,11 +9,11 @@ repos:
     hooks:
      - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
-    rev: v0.30.0
+    rev: v0.32.0
     hooks:
       - id: yapf
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.1.0
+    rev: v4.3.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -26,7 +26,7 @@ repos:
       - id: mixed-line-ending
         args: ["--fix=lf"]
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.1.0
+    rev: v2.2.1
     hooks:
       - id: codespell
   - repo: https://github.com/executablebooks/mdformat
@@ -44,7 +44,7 @@ repos:
       - id: docformatter
         args: ["--in-place", "--wrap-descriptions", "79"]
   - repo: https://github.com/asottile/pyupgrade
-    rev: v2.32.1
+    rev: v3.0.0
     hooks:
       - id: pyupgrade
         args: ["--py36-plus"]
diff --git a/mmeval/__init__.py b/mmeval/__init__.py
index e337c6d4..6f41879b 100644
--- a/mmeval/__init__.py
+++ b/mmeval/__init__.py
@@ -2,6 +2,7 @@
 
 # flake8: noqa
 
+from .classification import *
 from .core import *
 from .segmentation import *
 from .version import __version__
diff --git a/mmeval/classification/__init__.py b/mmeval/classification/__init__.py
new file mode 100644
index 00000000..8055f3bb
--- /dev/null
+++ b/mmeval/classification/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from .accuracy import Accuracy
+
+__all__ = ['Accuracy']
diff --git a/mmeval/classification/accuracy.py b/mmeval/classification/accuracy.py
new file mode 100644
index 00000000..c4b04966
--- /dev/null
+++ b/mmeval/classification/accuracy.py
@@ -0,0 +1,308 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import numpy as np
+from typing import (Dict, Iterable, List, Optional, Sequence, Tuple, Union,
+                    overload)
+
+from mmeval.core.base_metric import BaseMetric
+from mmeval.core.dispatcher import dispatch
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+@overload
+@dispatch
+def _is_scalar(obj: np.number):  # type: ignore
+    """Check if the ``numpy.number`` is a scalar number."""
+    return True
+
+
+@overload
+@dispatch
+def _is_scalar(obj: np.ndarray):  # type: ignore
+    """Check if a ``numpy.ndarray`` is a scalar."""
+    return obj.ndim == 0
+
+
+@overload
+@dispatch
+def _is_scalar(obj: 'torch.Tensor'):  # type: ignore
+    """Check if a ``torch.Tensor`` is a scalar."""
+    return obj.ndim == 0
+
+
+@dispatch
+def _is_scalar(obj):
+    """Check if an object is a scalar."""
+    try:
+        float(obj)  # type: ignore
+        return True
+    except Exception:
+        return False
+
+
+def _torch_topk(inputs: 'torch.Tensor',
+                k: int,
+                dim: Optional[int] = None) -> Tuple:
+    """Invoke the PyTorch topk."""
+    return inputs.topk(k, dim=dim)
+
+
+def _numpy_topk(inputs: np.ndarray,
+                k: int,
+                axis: Optional[int] = None) -> Tuple:
+    """An implementation of numpy top-k.
+
+    This implementation returns the values and indices of the k largest
+    elements along a given axis.
+
+    Args:
+        inputs (numpy.ndarray): The input numpy array.
+        k (int): The k in `top-k`.
+        axis (int, optional): The axis to sort along.
+
+    Returns:
+        tuple: The values and indices of the k largest elements.
+
+    Note:
+        If PyTorch is available, ``_torch_topk`` will be used.
+    """
+    if torch is not None:
+        values, indices = _torch_topk(torch.from_numpy(inputs), k, dim=axis)
+        return values.numpy(), indices.numpy()
+
+    indices = np.argsort(inputs, axis=axis)
+    indices = np.take(indices, np.arange(k), axis=axis)
+    values = np.take_along_axis(inputs, indices, axis=axis)
+    return values, indices
+
+
+class Accuracy(BaseMetric):
+    """Top-k accuracy evaluation metric.
+
+    This metric computes the accuracy based on the given topk and thresholds.
+
+    Currently, this metric supports 2 kinds of inputs, i.e. ``numpy.ndarray``
+    and ``torch.Tensor``, and the implementation for the calculation depends
+    on the input type.
+
+    Args:
+        topk (int | Sequence[int]): If the target label is within the top-k
+            predictions, the prediction is regarded as correct.
+            Defaults to 1.
+        thrs (Sequence[float | None] | float | None): Predictions with scores
+            under the thresholds are considered negative. None means no
+            thresholds. Defaults to 0.
+        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
+
+    Examples:
+
+        >>> from mmeval import Accuracy
+        >>> accuracy = Accuracy()
+
+        Use NumPy implementation:
+
+        >>> import numpy as np
+        >>> labels = np.asarray([0, 1, 2, 3])
+        >>> preds = np.asarray([0, 2, 1, 3])
+        >>> accuracy(preds, labels)
+        {'top1': 0.5}
+
+        Use PyTorch implementation:
+
+        >>> import torch
+        >>> labels = torch.Tensor([0, 1, 2, 3])
+        >>> preds = torch.Tensor([0, 2, 1, 3])
+        >>> accuracy(preds, labels)
+        {'top1': 0.5}
+
+        Computing top-k accuracy with specified thresholds:
+
+        >>> labels = np.asarray([0, 1, 2, 3])
+        >>> preds = np.asarray([
+            [0.7, 0.1, 0.1, 0.1],
+            [0.1, 0.3, 0.4, 0.2],
+            [0.3, 0.4, 0.2, 0.1],
+            [0.0, 0.0, 0.1, 0.9]])
+        >>> accuracy = Accuracy(topk=(1, 2, 3))
+        >>> accuracy(preds, labels)
+        {'top1': 0.5, 'top2': 0.75, 'top3': 1.0}
+        >>> accuracy = Accuracy(topk=2, thrs=(0.1, 0.5))
+        >>> accuracy(preds, labels)
+        {'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5}
+
+        Accumulate batch:
+
+        >>> for i in range(10):
+        ...     labels = torch.randint(0, 4, size=(100, ))
+        ...     predicts = torch.randint(0, 4, size=(100, ))
+        ...     accuracy.add(predicts, labels)
+        >>> accuracy.compute()  # doctest: +SKIP
+    """
+
+    def __init__(self,
+                 topk: Union[int, Sequence[int]] = (1, ),
+                 thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        if isinstance(topk, int):
+            self.topk = (topk, )
+        else:
+            self.topk = tuple(topk)  # type: ignore
+        self.maxk = max(self.topk)
+
+        if isinstance(thrs, float) or thrs is None:
+            self.thrs = (thrs, )
+        else:
+            self.thrs = tuple(thrs)  # type: ignore
+
+    def add(self, predictions: Sequence, labels: Sequence) -> None:  # type: ignore # yapf: disable # noqa: E501
+        """Add the intermediate results to ``self._results``.
+
+        Args:
+            predictions (Sequence): Predictions from the model. It can be
+                label indices of shape (N, ) or class scores of shape (N, C).
+            labels (Sequence): The ground truth labels with shape (N, ).
+        """
+        corrects = self._compute_corrects(predictions, labels)
+        for correct in corrects:
+            self._results.append(correct)
+
+    @overload  # type: ignore
+    @dispatch
+    def _compute_corrects(
+        self, predictions: Union['torch.Tensor', Sequence['torch.Tensor']],
+        labels: Union['torch.Tensor',
+                      Sequence['torch.Tensor']]) -> 'torch.Tensor':
+        """Compute the correct number per topk and threshold with PyTorch.
+
+        Args:
+            predictions (torch.Tensor | Sequence): Predictions from the
+                model. Same as ``self.add``.
+            labels (torch.Tensor | Sequence): The ground truth labels. Same as
+                ``self.add``.
+
+        Returns:
+            torch.Tensor: Correct numbers in one of the following 2 shapes.
+
+            - (N, ): If ``predictions`` is a label tensor instead of scores.
+              Only a top-1 correct tensor is returned, and the arguments
+              ``topk`` and ``thrs`` are ignored.
+            - (N, num_topk, num_thr): If ``predictions`` is a score tensor
+              (number of dimensions is 2). Return the correct number for
+              each ``topk`` and ``thrs``.
+        """
+        if not isinstance(predictions, torch.Tensor):
+            predictions = torch.stack(predictions)
+        if not isinstance(labels, torch.Tensor):
+            labels = torch.stack(labels)
+
+        if predictions.ndim == 1:
+            corrects = (predictions.int() == labels)
+            return corrects.float()
+
+        pred_scores, pred_label = _torch_topk(predictions, self.maxk, dim=1)
+        pred_label = pred_label.t()
+
+        corrects = (pred_label == labels.view(1, -1).expand_as(pred_label))
+
+        # compute the corrects corresponding to all topk and thrs per sample
+        corrects_per_sample = torch.zeros(
+            (len(predictions), len(self.topk), len(self.thrs)))
+        for i, k in enumerate(self.topk):
+            for j, thr in enumerate(self.thrs):
+                # Only prediction scores larger than thr count as correct
+                if thr is not None:
+                    thr_corrects = corrects & (pred_scores.t() > thr)
+                else:
+                    thr_corrects = corrects
+                corrects_per_sample[:, i, j] = thr_corrects[:k].sum(
+                    0, keepdim=True).float()
+        return corrects_per_sample
+
+    @dispatch
+    def _compute_corrects(
+        self, predictions: Union[np.ndarray, Sequence[np.ndarray]],
+        labels: Union[np.ndarray, Sequence[np.ndarray]]) -> np.ndarray:
+        """Compute the correct number per topk and threshold with NumPy.
+
+        Args:
+            predictions (numpy.ndarray | Sequence): Predictions from the
+                model. Same as ``self.add``.
+            labels (numpy.ndarray | Sequence): The ground truth labels. Same as
+                ``self.add``.
+
+        Returns:
+            numpy.ndarray: Correct numbers in one of the following 2 shapes.
+
+            - (N, ): If ``predictions`` is a label array instead of scores.
+              Only a top-1 correct array is returned, and the arguments
+              ``topk`` and ``thrs`` are ignored.
+            - (N, num_topk, num_thr): If ``predictions`` is a score array
+              (number of dimensions is 2). Return the correct number for
+              each ``topk`` and ``thrs``.
+        """
+        if not isinstance(predictions, np.ndarray):
+            predictions = np.stack(predictions)
+        if not isinstance(labels, np.ndarray):
+            labels = np.stack(labels)
+
+        if predictions.ndim == 1:
+            corrects = (predictions == labels)
+            return corrects.astype(np.int32)
+
+        pred_scores, pred_label = _numpy_topk(predictions, self.maxk, axis=1)
+        pred_label = pred_label.T
+
+        # broadcast `labels` to the shape of `pred_label`
+        labels = np.broadcast_to(labels.reshape(1, -1), pred_label.shape)
+        # compute the correct array
+        corrects = (pred_label == labels)
+
+        # compute the corrects corresponding to all topk and thrs per sample
+        corrects_per_sample = np.zeros(
+            (len(predictions), len(self.topk), len(self.thrs)))
+        for i, k in enumerate(self.topk):
+            for j, thr in enumerate(self.thrs):
+                # Only prediction scores larger than thr count as correct
+                if thr is not None:
+                    thr_corrects = corrects & (pred_scores.T > thr)
+                else:
+                    thr_corrects = corrects
+                corrects_per_sample[:, i, j] = thr_corrects[:k].sum(
+                    0, keepdims=True).astype(np.int32)
+        return corrects_per_sample
+
+    def compute_metric(
+        self, results: List[Union[Iterable, Union[np.number, 'torch.Tensor']]]
+    ) -> Dict[str, float]:
+        """Compute the accuracy metric.
+
+        This method is invoked by ``BaseMetric.compute`` after
+        distributed synchronization.
+
+        Args:
+            results (List[Union[Iterable, Union[np.number, torch.Tensor]]]): A
+                list consisting of the correct numbers. This list has already
+                been synced across all ranks.
+
+        Returns:
+            Dict[str, float]: The computed accuracy metric.
+        """
+        if _is_scalar(results[0]):
+            return {'top1': float(sum(results) / len(results))}  # type: ignore
+
+        metric_results = {}
+        for i, k in enumerate(self.topk):
+            for j, thr in enumerate(self.thrs):
+                corrects = [result[i][j] for result in results]  # type: ignore
+                acc = float(sum(corrects) / len(corrects))
+                name = f'top{k}'
+                if len(self.thrs) > 1:
+                    name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
+                metric_results[name] = acc
+        return metric_results
diff --git a/setup.cfg b/setup.cfg
index 0457227a..df4c3d20 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -15,3 +15,6 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
 [codespell]
 skip = *.ipynb
 quiet-level = 3
+
+[mypy]
+allow_redefinition = True
diff --git a/tests/test_classification/test_accuracy.py b/tests/test_classification/test_accuracy.py
new file mode 100644
index 00000000..462066e6
--- /dev/null
+++ b/tests/test_classification/test_accuracy.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# yapf: disable
+
+import numpy as np
+import pytest
+
+from mmeval.classification.accuracy import Accuracy
+from mmeval.core.base_metric import BaseMetric
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+@pytest.mark.parametrize(
+    argnames='metric_kwargs',
+    argvalues=[
+        {},
+        {'topk': 1, 'thrs': 0.1},
+        {'topk': (1, 2), 'thrs': (0.1, 0.2)},
+    ]
+)
+def test_metric_interface_numpy(metric_kwargs):
+    accuracy = Accuracy(**metric_kwargs)
+    assert isinstance(accuracy, BaseMetric)
+    assert isinstance(accuracy.topk, tuple)
+    assert isinstance(accuracy.thrs, tuple)
+
+    results = accuracy(np.asarray([1, 2, 3]), np.asarray([3, 2, 1]))
+    assert isinstance(results, dict)
+    results = accuracy(
+        np.asarray([[0.1, 0.9], [0.5, 0.5]]), np.asarray([0, 1]))
+    assert isinstance(results, dict)
+
+
+@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
+def test_metric_interface_torch():
+    accuracy = Accuracy()
+    results = accuracy(torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1]))
+    assert isinstance(results, dict)
+    results = accuracy(
+        torch.Tensor([[0.1, 0.9], [0.5, 0.5]]), torch.Tensor([0, 1]))
+    assert isinstance(results, dict)
+
+
+@pytest.mark.parametrize(
+    argnames=['metric_kwargs', 'predictions', 'labels', 'results'],
+    argvalues=[
+        ({}, [0, 2, 1, 3], [0, 1, 2, 3], {'top1': 0.5}),
+        (
+            {'topk': (1, 2, 3)},
+            [
+                [0.7, 0.1, 0.1, 0.1],
+                [0.1, 0.3, 0.4, 0.2],
+                [0.3, 0.4, 0.2, 0.1],
+                [0.0, 0.0, 0.1, 0.9]
+            ],
+            [0, 1, 2, 3],
+            {'top1': 0.5, 'top2': 0.75, 'top3': 1.}
+        ),
+        (
+            {'topk': 2, 'thrs': (0.1, 0.5)},
+            [
+                [0.7, 0.1, 0.1, 0.1],
+                [0.1, 0.3, 0.4, 0.2],
+                [0.3, 0.4, 0.2, 0.1],
+                [0.0, 0.0, 0.1, 0.9]
+            ],
+            [0, 1, 2, 3],
+            {'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5}
+        )
+    ]
+)
+def test_metric_accurate(metric_kwargs, predictions, labels, results):
+    accuracy = Accuracy(**metric_kwargs)
+    assert accuracy(np.asarray(predictions), np.asarray(labels)) == results
+
+
+@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
+@pytest.mark.parametrize(
+    argnames=('metric_kwargs', 'classes_num', 'length'),
+    argvalues=[
+        ({}, 100, 100),
+        ({'topk': 1, 'thrs': 0.1}, 1000, 100),
+        ({'topk': (1, 2, 3), 'thrs': (0.1, 0.2)}, 1000, 10000),
+        ({'topk': (1, 2, 3), 'thrs': (0.1, None)}, 999, 10002)
+    ]
+)
+def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length):
+    """Metamorphic testing for NumPy and PyTorch implementation."""
+    accuracy = Accuracy(**metric_kwargs)
+
+    predictions = np.random.rand(length, classes_num)
+    labels = np.random.randint(0, classes_num, length)
+
+    np_acc_results = accuracy(predictions, labels)
+
+    predictions = torch.from_numpy(predictions)
+    labels = torch.from_numpy(labels)
+    torch_acc_results = accuracy(predictions, labels)
+
+    assert np_acc_results.keys() == torch_acc_results.keys()
+
+    for key in np_acc_results:
+        np.testing.assert_allclose(np_acc_results[key], torch_acc_results[key])
+
+
+if __name__ == '__main__':
+    pytest.main([__file__, '-v', '--capture=no'])
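
Usage sketch for reviewers (not part of the patch): the snippet below mirrors the docstring examples added in mmeval/classification/accuracy.py and shows how the new Accuracy metric is called with score predictions; the values and the topk=(1, 2) choice are illustrative only.

# Minimal sketch, assuming the package is installed with this patch applied.
import numpy as np

from mmeval import Accuracy

accuracy = Accuracy(topk=(1, 2))
labels = np.asarray([0, 1, 2, 3])
preds = np.asarray([
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.3, 0.4, 0.2],
    [0.3, 0.4, 0.2, 0.1],
    [0.0, 0.0, 0.1, 0.9]])
# Per the class docstring, this should print {'top1': 0.5, 'top2': 0.75}.
print(accuracy(preds, labels))

The accompanying tests can be run with: pytest tests/test_classification/test_accuracy.py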