Skip to content

Commit

Permalink
feat(mmeval/classification): add accuracy metric
Browse files Browse the repository at this point in the history
  • Loading branch information
ice-tong committed Sep 13, 2022
1 parent cb0d999 commit d951c4c
Show file tree
Hide file tree
Showing 4 changed files with 388 additions and 3 deletions.
7 changes: 4 additions & 3 deletions mmeval/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmeval import core
from .version import __version__ # noqa: F401
# flake8: noqa

__all__ = ['core', '__version__']
from .classification import *
from .core import *
from .version import __version__
5 changes: 5 additions & 0 deletions mmeval/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .accuracy import Accuracy

__all__ = ['Accuracy']
268 changes: 268 additions & 0 deletions mmeval/classification/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
from typing import Dict, List, Optional, Sequence, Tuple, Union, overload

from mmeval.core.base_metric import BaseMetric
from mmeval.core.dispatcher import dispatch

try:
import torch
except ImportError:
torch = None


def _numpy_topk(inputs: np.ndarray,
k: int,
axis: Optional[int] = None) -> Tuple:
"""A implementation of numpy top-k.
This implementation returns the values and indices of the k largest
elements along a given axis.
Note:
If PyTorch is available, the `torch.topk` would be used.
Args:
inputs (nump.ndarray): The input numpy array.
k (int): The k in `top-k`.
axis (int, optional): The axis to sort along.
Returns:
tuple: The values and indices of the k largest elements.
"""
if torch is not None:
values, indices = torch.from_numpy(inputs).topk(k, dim=axis)
return values.numpy(), indices.numpy()

indices = np.argsort(inputs, axis=axis)
indices = np.take(indices, np.arange(k), axis=axis)
values = np.take_along_axis(inputs, indices, axis=axis)
return values, indices


NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.int64], np.int64]
TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor']


class Accuracy(BaseMetric):
"""Top-k accuracy evaluation metric.
This metric computes the accuracy based on the given topk and thresholds.
Currently, there are 2 implementations of this metric: NumPy and PyTorch.
Which implementation to use is determined by the type of the calling
parameters. e.g. `numpy.ndarray` or `torch.Tensor`.
Args:
topk (int | Sequence[int]): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thrs (Sequence[float | None] | float | None): Predictions with scores
under the thresholds are considered negative. None means no
thresholds. Defaults to 0.
Examples:
>>> from mmeval import Accuracy
>>> accuracy = Accuracy()
Use NumPy implementation:
>>> import numpy as np
>>> labels = np.asarray([0, 1, 2, 3])
>>> preds = np.asarray([0, 2, 1, 3])
>>> accuracy(preds, labels)
{'top1': 0.5}
Use PyTorch implementation:
>>> import torch
>>> labels = torch.Tensor([0, 1, 2, 3])
>>> preds = torch.Tensor([0, 2, 1, 3])
>>> accuracy(preds, labels)
{'top1': 0.5}
Computing top-k accuracy with specified threold:
>>> labels = np.asarray([0, 1, 2, 3])
>>> preds = np.asarray([
[0.7, 0.1, 0.1, 0.1],
[0.1, 0.3, 0.4, 0.2],
[0.3, 0.4, 0.2, 0.1],
[0.0, 0.0, 0.1, 0.9]])
>>> accuracy = Accuracy(topk=(1, 2, 3))
>>> accuracy(preds, labels)
{'top1': 0.5, 'top2': 0.75, 'top3': 1.0}
>>> accuracy = Accuracy(topk=2, thrs=(0.1, 0.5))
>>> accuracy(preds, labels)
{'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5}
Accumulate batch:
>>> for i in range(10):
... labels = torch.randint(0, 4, size=(100, ))
... predicts = torch.randint(0, 4, size=(100, ))
... accuracy.add(predicts, labels)
>>> accuracy.compute() # doctest: +SKIP
"""

def __init__(self,
topk: Union[int, Sequence[int]] = (1, ),
thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
**kwargs) -> None:
super().__init__(**kwargs)

if isinstance(topk, int):
self.topk = (topk, )
else:
self.topk = tuple(topk) # type: ignore

if isinstance(thrs, float) or thrs is None:
self.thrs = (thrs, )
else:
self.thrs = tuple(thrs) # type: ignore

def add(self, predictions: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501
"""Add the intermediate results to `self._results`.
Args:
predictions (Sequence): Predictions from the model. It can be
labels (N, ), or scores of every class (N, C).
labels (Sequence): The ground truth labels. It should be (N, ).
"""
# FIXME: Needing refactor or cleanup @yancong at 9/1/2022, 1:47:59 PM
# Instead of directly storing all predicted values and labels in
# `self._results`, we should use more appropriate intermediate results.
# e.g. confusion matrix.
for pred, label in zip(predictions, labels):
self._results.append((pred, label))

def _format_metric_results(self, results_per_topk: List[List]) -> Dict:
"""Format the given metric results into a dictionary.
Args:
results_per_topk (list): A list of per topk and thrs accuracy.
Returns:
dict: The formatted dictionary.
"""
metric_results = {}
for k, result_per_topk in zip(self.topk, results_per_topk):
for thr, result_per_thr in zip(self.thrs, result_per_topk):
name = f'top{k}'
if len(self.thrs) > 1:
name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
metric_results[name] = result_per_thr
return metric_results

@overload # type: ignore
@dispatch
def _compute_metric(self,
results: List[TORCH_IMPL_HINTS]) -> Dict[str, float]:
"""A PyTorch implementation that compute the accuracy metric."""
# Concatenating the intermediate results arcoss all ranks.
labels = torch.stack([res[1] for res in results])
predictions = torch.stack([res[0] for res in results])
total_length = labels.size(0)

# In the case where the prediction is a label (N, ), the accuracy is
# calculated directly without considering `topk` and `thrs`.
if predictions.ndim == 1:
correct = (predictions.int() == labels).sum(0, keepdim=True)
acc = correct.float() / total_length
return {'top1': acc.item()}

# compute the max topk
maxk = max(self.topk)
# NOTE: The torch.topk is non-deterministic with duplicates values.
# See: https://github.com/pytorch/pytorch/issues/27542
pred_score, pred_label = predictions.topk(maxk, dim=1)
pred_label = pred_label.t()

# Broadcast `labels` to the shape of `pred_label` and then compute
# correct tensor.
correct = (pred_label == labels.view(1, -1).expand_as(pred_label))

# compute the accuracy corresponding to all topk and thrs
results_per_topk = []
for k in self.topk:
results_per_thr = []
for thr in self.thrs:
# Only prediction socres larger than thr are counted as correct
if thr is not None:
thr_correct = correct & (pred_score.t() > thr)
else:
thr_correct = correct
topk_thr_correct = thr_correct[:k].reshape(-1).sum(
0, keepdim=True)
acc = topk_thr_correct.float() / total_length
results_per_thr.append(acc.item())
results_per_topk.append(results_per_thr)

return self._format_metric_results(results_per_topk)

@dispatch
def _compute_metric(self,
results: List[NUMPY_IMPL_HINTS]) -> Dict[str, float]:
"""A NumPy implementation that compute the accuracy metric."""
# Concatenating the intermediate results arcoss all ranks.
labels = np.stack([res[1] for res in results])
predictions = np.stack([res[0] for res in results])
total_length = labels.size

# In the case where the prediction is a label (N, ), the accuracy is
# calculated directly without considering `topk` and `thrs`.
if predictions.ndim == 1:
predictions = predictions.astype(np.int32)
correct = (predictions == labels).sum(0, keepdims=True)
acc = correct / total_length
return {'top1': float(acc)}

# compute the max topk
maxk = max(self.topk)
pred_score, pred_label = _numpy_topk(predictions, maxk, 1)
pred_label = pred_label.T

# broadcast `labels` to the shape of `pred_label`
labels = np.broadcast_to(labels.reshape(1, -1), pred_label.shape)
# compute correct tensor
correct = (pred_label == labels)

# compute the accuracy corresponding to all topk and thrs
results_per_topk = []
for k in self.topk:
results_per_thr = []
for thr in self.thrs:
# Only socres greater than thr are counted as correct.
if thr is not None:
thr_correct = correct & (pred_score.T > thr)
else:
thr_correct = correct
topk_thr_correct = thr_correct[:k].reshape(-1).sum(
0, keepdims=True)
acc = topk_thr_correct / total_length
results_per_thr.append(float(acc))
results_per_topk.append(results_per_thr)
return self._format_metric_results(results_per_topk)

def compute_metric(
self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]]
) -> Dict[str, float]:
"""Compute the accuracy metric.
Currently, there are 2 implementations of this method: NumPy and
PyTorch. Which implementation to use is determined by the type of the
calling parameters. e.g. `numpy.ndarray` or `torch.Tensor`.
This method would be invoked in `BaseMetric.compute` after distributed
synchronization.
Args:
results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]]): A list
of tuples that consisting the prediction and label. This list
has already been synced across all ranks.
Returns:
Dict[str, float]: The computed accuracy metric.
"""
return self._compute_metric(results)
111 changes: 111 additions & 0 deletions tests/test_classification/test_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.

# yapf: disable

import numpy as np
import pytest

from mmeval.classification.accuracy import Accuracy
from mmeval.core.base_metric import BaseMetric

try:
import torch
except ImportError:
torch = None


@pytest.mark.parametrize(
argnames='metric_kwargs',
argvalues=[
{},
{'topk': 1, 'thrs': 0.1},
{'topk': (1, 2), 'thrs': (0.1, 0.2)},
]
)
def test_metric_interface(metric_kwargs):
accuracy = Accuracy(**metric_kwargs)
assert isinstance(accuracy, BaseMetric)
assert isinstance(accuracy.topk, tuple)
assert isinstance(accuracy.thrs, tuple)

results = accuracy(np.asarray([1, 2, 3]), np.asarray([3, 2, 1]))
assert isinstance(results, dict)
results = accuracy(
np.asarray([[0.1, 0.9], [0.5, 0.5]]), np.asarray([0, 1]))
assert isinstance(results, dict)


@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
def test_metric_interface_torch():
accuracy = Accuracy()
results = accuracy(torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 1]))
assert isinstance(results, dict)
results = accuracy(
torch.Tensor([[0.1, 0.9], [0.5, 0.5]]), torch.Tensor([0, 1]))
assert isinstance(results, dict)


@pytest.mark.parametrize(
argnames=['metric_kwargs', 'preditions', 'labels', 'results'],
argvalues=[
({}, [0, 2, 1, 3], [0, 1, 2, 3], {'top1': 0.5}),
(
{'topk': (1, 2, 3)},
[
[0.7, 0.1, 0.1, 0.1],
[0.1, 0.3, 0.4, 0.2],
[0.3, 0.4, 0.2, 0.1],
[0.0, 0.0, 0.1, 0.9]
],
[0, 1, 2, 3],
{'top1': 0.5, 'top2': 0.75, 'top3': 1.}
),
(
{'topk': 2, 'thrs': (0.1, 0.5)},
[
[0.7, 0.1, 0.1, 0.1],
[0.1, 0.3, 0.4, 0.2],
[0.3, 0.4, 0.2, 0.1],
[0.0, 0.0, 0.1, 0.9]
],
[0, 1, 2, 3],
{'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5}
)
]
)
def test_metric_accurate(metric_kwargs, preditions, labels, results):
accuracy = Accuracy(**metric_kwargs)
assert accuracy(np.asarray(preditions), np.asarray(labels)) == results


@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
@pytest.mark.parametrize(
argnames=('metric_kwargs', 'classes_num', 'length'),
argvalues=[
({}, 100, 100),
({'topk': 1, 'thrs': 0.1}, 1000, 100),
({'topk': (1, 2, 3), 'thrs': (0.1, 0.2)}, 1000, 10000),
({'topk': (1, 2, 3), 'thrs': (0.1, None)}, 999, 10002)
]
)
def test_metamorphic_numpy_pytorch(metric_kwargs, classes_num, length):
"""Metamorphic testing for NumPy and PyTorch implementation."""
accuracy = Accuracy(**metric_kwargs)

predictions = np.random.rand(length, classes_num)
labels = np.random.randint(0, classes_num, length)

np_acc_results = accuracy(predictions, labels)

predictions = torch.from_numpy(predictions)
labels = torch.from_numpy(labels)
torch_acc_results = accuracy(predictions, labels)

assert np_acc_results.keys() == torch_acc_results.keys()

for key in np_acc_results:
np.testing.assert_allclose(np_acc_results[key], torch_acc_results[key])


if __name__ == '__main__':
pytest.main([__file__, '-v', '--capture=no'])

0 comments on commit d951c4c

Please sign in to comment.