dev(mmeval/core): add BaseMetric
ice-tong committed Sep 15, 2022
1 parent f892eb2 commit b4b4c23
Showing 4 changed files with 285 additions and 1 deletion.
3 changes: 3 additions & 0 deletions mmeval/__init__.py
@@ -1,3 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmeval import core
from .version import __version__ # noqa: F401

__all__ = ['core', '__version__']
3 changes: 2 additions & 1 deletion mmeval/core/__init__.py
@@ -1,11 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmeval.core import dist_backends
from mmeval.core.base_metric import BaseMetric
from mmeval.core.dispatcher import dispatch
from mmeval.core.dist import (get_dist_backend, list_all_backends,
set_default_dist_backend)

__all__ = [
'dist_backends', 'get_dist_backend', 'set_default_dist_backend',
- 'list_all_backends', 'dispatch'
+ 'list_all_backends', 'dispatch', 'BaseMetric'
]
162 changes: 162 additions & 0 deletions mmeval/core/base_metric.py
@@ -0,0 +1,162 @@
# Copyright (c) OpenMMLab. All rights reserved.

from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Optional

from mmeval.core.dist import get_dist_backend


class BaseMetric(metaclass=ABCMeta):
"""Base class for metric.
To implement a metric, you should subclass ``BaseMetric`` and then
implement the ``add`` and ``compute_metric`` methods. ``BaseMetric`` will
automatically complete distributed synchronization between processes.
Each metric will maintain a list named ``self._results`` to store
intermediate results. When computing the final metric result, the
``self._results`` will be synchronized between processes.
Args:
dataset_meta (dict, optional): Meta information of the dataset, this is
required for some metrics that require dataset information.
Defaults to None.
dist_merge_method (str, optional): The method for merging the results
collected from all processes. This depends on how the
distributed data is split. Currently only 'unzip' and 'cat' are
available. For PyTorch's ``DistributedSampler``, 'unzip' should
be used. Defaults to 'unzip'.
dist_backend (str, optional): The name of the distributed communication
backend, you can get all the backend names through
``mmeval.core.list_all_backends()``.
If None, use the default backend. Defaults to None.
An example of implementing an accuracy metric:
>>> import numpy as np
>>> from mmeval.core import BaseMetric
>>> class Accuracy(BaseMetric):
... def add(self, predictions, labels):
... self._results.append((predictions, labels))
... def compute_metric(self, results):
... predictions = np.concatenate([res[0] for res in results])
... labels = np.concatenate([res[1] for res in results])
... correct = (predictions == labels)
... accuracy = sum(correct) / len(predictions)
... return {'accuracy': accuracy}
>>> accuracy = Accuracy()
>>> accuracy(predictions=[1, 2, 3, 4], labels=[1, 2, 3, 1])
{'accuracy': 0.75}
"""

def __init__(self,
dataset_meta: Optional[Dict] = None,
dist_merge_method: str = 'unzip',
dist_backend: Optional[str] = None):
self.dataset_meta = dataset_meta
assert dist_merge_method in ('cat', 'unzip')
self.dist_merge_method = dist_merge_method
self.dist_comm = get_dist_backend(dist_backend)
self._results: List[Any] = []

@property
def dataset_meta(self) -> Optional[Dict]:
"""Meta information of the dataset."""
if self._dataset_meta is None:
return self._dataset_meta
else:
return self._dataset_meta.copy()

@dataset_meta.setter
def dataset_meta(self, dataset_meta: Optional[Dict]) -> None:
"""Set the dataset meta information to the metric."""
if dataset_meta is None:
self._dataset_meta = dataset_meta
else:
self._dataset_meta = dataset_meta.copy()

@property
def name(self) -> str:
"""The metric name, defaults to the name of the class."""
return self.__class__.__name__

def reset(self) -> None:
"""Clear the metric stored results."""
self._results.clear()

def __call__(self, *args, **kwargs) -> Dict:
"""Stateless call for a metric compute."""
cache_results = self._results
self._results = []
self.add(*args, **kwargs)
metric_result = self.compute_metric(self._results)
self._results = cache_results
return metric_result

def compute(self, size: Optional[int] = None) -> Dict:
"""Synchronize intermediate results and then call
``self.compute_metric``.
Args:
size (int, optional): The length of the entire dataset; it is only
used in distributed evaluation. When batch size > 1, the
dataloader may pad some data samples to make sure all ranks
have dataset slices of the same length. ``compute`` will
drop the padded samples based on this size.
If None, do nothing. Defaults to None.
Returns:
dict: The computed metric results.
"""
if not self.dist_comm.is_initialized or self.dist_comm.world_size == 1:
return self.compute_metric(self._results)

global_results = self.dist_comm.all_gather_object(self._results)

collected_results: List[Any]
if self.dist_merge_method == 'cat':
# use `sum` to concatenate list
# e.g. sum([[1, 3], [2, 4]], []) = [1, 3, 2, 4]
collected_results = sum(global_results, [])
else:
collected_results = []
for partial_result in zip(*global_results):
collected_results.extend(list(partial_result))

# NOTE: Needs discussion or investigation @yancong at 9/13/2022, 4:20:48 PM # noqa: E501
# If the intermediate results stored in ``self._results`` do not
# correspond one-to-one with the samples (e.g. a total confusion matrix),
# the size here may not work anymore.
if size is not None:
collected_results = collected_results[:size]

if self.dist_comm.rank == 0:
metric_result = self.compute_metric(collected_results)
else:
metric_result = None # type: ignore

global_metric_result = self.dist_comm.broadcast_object(
metric_result, 0)
return global_metric_result

@abstractmethod
def add(self, *args, **kwargs):
"""Override this method to add the intermediate results to
``self._results``.
Note:
For performance reasons, what you add to ``self._results``
should be as simple as possible. But be aware that the intermediate
results stored in ``self._results`` should correspond one-to-one
with the samples, because the padded samples need to be dropped to
get the most accurate result.
"""

@abstractmethod
def compute_metric(self, results: List[Any]) -> Dict:
"""Override this method to compute the metric result from collectd
intermediate results.
The returned result of the metric compute should be a dictionary.
"""
118 changes: 118 additions & 0 deletions tests/test_core/test_base_metric.py
@@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.

import pytest

from mmeval.core.base_metric import BaseMetric

try:
import torch
except ImportError:
torch = None


class Metric(BaseMetric):

def add(self, num):
self._results.append(num)

def compute_metric(self, results):
return {'results': results}


def test_dataset_meta():
dataset_meta = {'CLASSES': ['test1', 'test2']}

metric = Metric(dataset_meta=None)
assert metric.dataset_meta is None

metric.dataset_meta = dataset_meta
assert metric.dataset_meta == dataset_meta


def test_metric_reset():
metric = Metric()
metric.add(1)
assert len(metric._results) == 1

metric.reset()
assert len(metric._results) == 0


def test_metric_call():
metric = Metric()
results = metric(1)
assert results == {'results': [1]}

metric.add(2)

# stateless call
results = metric(1)
assert results == {'results': [1]}


def test_metric_compute():
metric = Metric()

for i in range(10):
metric.add(i)

results = metric.compute()
assert results == {'results': [i for i in range(10)]}


def _init_torch_dist(rank, world_size, comm_backend, port):
torch.distributed.init_process_group(
backend=comm_backend,
init_method=f'tcp://127.0.0.1:{port}',
world_size=world_size,
rank=rank)

if comm_backend == 'nccl':
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)


def _test_metric_compute(rank, world_size, port, dist_merge_method):
_init_torch_dist(rank, world_size, comm_backend='gloo', port=port)

metric = Metric(
dist_backend='torch_cpu', dist_merge_method=dist_merge_method)

if dist_merge_method == 'unzip':
data_slice = range(rank, 10 * world_size, world_size)
else:
data_slice = range(rank * 10, (rank + 1) * 10)

for i in data_slice:
metric.add(i)

results = metric.compute()
assert results == {'results': [i for i in range(10 * world_size)]}

if world_size == 1:
return

# Both merge methods drop the padded tail in the same way, so a single
# check covers 'unzip' and 'cat'.
results = metric.compute(size=(10 * world_size - 1))
assert results == {'results': [i for i in range(10 * world_size - 1)]}


@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
@pytest.mark.skipif(
not torch.distributed.is_available(),
reason='torch.distributed is not available!')
@pytest.mark.parametrize(
argnames=['process_num', 'comm_port', 'dist_merge_method'],
argvalues=[(1, 2346, 'unzip'), (4, 2346, 'unzip'), (4, 2346, 'cat')])
def test_metric_compute_dist(process_num, comm_port, dist_merge_method):
torch.multiprocessing.spawn(
_test_metric_compute,
nprocs=process_num,
args=(process_num, comm_port, dist_merge_method))


if __name__ == '__main__':
pytest.main([__file__, '-v', '--capture=no'])
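
As a usage note, here is a hedged single-process sketch (not part of this commit; the ``Accuracy`` subclass and the sample batches are illustrative) of accumulating batches with ``add``, finishing with ``compute``, and using the stateless call:

import numpy as np

from mmeval.core import BaseMetric


class Accuracy(BaseMetric):
    """Toy accuracy metric mirroring the example in the BaseMetric docstring."""

    def add(self, predictions, labels):
        # Keep the stored intermediate results as simple as possible.
        self._results.append((np.asarray(predictions), np.asarray(labels)))

    def compute_metric(self, results):
        predictions = np.concatenate([res[0] for res in results])
        labels = np.concatenate([res[1] for res in results])
        return {'accuracy': float((predictions == labels).mean())}


accuracy = Accuracy()
for preds, labels in [([1, 2], [1, 2]), ([3, 4], [3, 1])]:
    accuracy.add(preds, labels)  # accumulate per-batch results
print(accuracy.compute())        # {'accuracy': 0.75}
print(accuracy([1], [0]))        # stateless call -> {'accuracy': 0.0}
print(accuracy.compute())        # stored results untouched -> {'accuracy': 0.75}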
