-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
242 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from mmeval import core | ||
from .version import __version__ # noqa: F401 | ||
|
||
__all__ = ['core', '__version__'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from mmeval.core import dist_backends | ||
from mmeval.core.base_metric import BaseMetric | ||
from mmeval.core.dispatcher import dispatch | ||
from mmeval.core.dist import (get_dist_backend, list_all_backends, | ||
set_default_dist_backend) | ||
|
||
__all__ = [ | ||
'dist_backends', 'get_dist_backend', 'set_default_dist_backend', | ||
'list_all_backends', 'dispatch' | ||
'list_all_backends', 'dispatch', 'BaseMetric' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from abc import ABCMeta, abstractmethod | ||
from typing import Any, Dict, List, Optional | ||
|
||
from mmeval.core.dist import get_dist_backend | ||
|
||
|
||
class BaseMetric(metaclass=ABCMeta): | ||
"""Base class for metric. | ||
To implement a metric, you should subclass ``BaseMetric`` and then | ||
implement the ``add`` and ``compute_metric`` methods. ``BaseMetric`` will | ||
automatically complete distributed synchronization between processes. | ||
Each metric will maintain a list named `self._results` to store | ||
intermediate results. When computing the final metric result, the | ||
`self._results` will be synchronized between processes. | ||
Args: | ||
dataset_meta (dict, optional): Meta information of the dataset, this is | ||
required for some metrics that require dataset information. | ||
Defaults to None. | ||
dist_merge_method (str, optional): The method of concatenating the | ||
collected synchronization results. This depends on how the | ||
distributed data is split. Currently only 'unzip' and 'cat' are | ||
available. For PyTorch's `DistributedSampler`, 'unzip' should | ||
be used. Defaults to 'unzip'. | ||
dist_backend (str, optional): The name of the distributed communication | ||
backend, you can get all the backend names through | ||
`mmeval.core.list_all_backends()`. | ||
If None, use the default backend. Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
dataset_meta: Optional[Dict] = None, | ||
dist_merge_method: str = 'unzip', | ||
dist_backend: Optional[str] = None): | ||
self.dataset_meta = dataset_meta | ||
assert dist_merge_method in ('cat', 'unzip') | ||
self.dist_merge_method = dist_merge_method | ||
self.dist_comm = get_dist_backend(dist_backend) | ||
self._results: List[Any] = [] | ||
|
||
@property | ||
def dataset_meta(self) -> Optional[Dict]: | ||
"""Meta information of the dataset.""" | ||
return self._dataset_meta | ||
|
||
@dataset_meta.setter | ||
def dataset_meta(self, dataset_meta: Optional[Dict]) -> None: | ||
"""Set the dataset meta information to the metric.""" | ||
self._dataset_meta = dataset_meta | ||
|
||
@property | ||
def name(self) -> str: | ||
"""The metric name, defaults to the name of the class.""" | ||
return self.__class__.__name__ | ||
|
||
def reset(self) -> None: | ||
"""Clear the metric stored results.""" | ||
self._results.clear() | ||
|
||
def __call__(self, *args, **kwargs) -> Dict: | ||
"""Stateless call for a metric compute.""" | ||
cache_results = self._results | ||
self._results = [] | ||
self.add(*args, **kwargs) | ||
metric_result = self.compute() | ||
self._results = cache_results | ||
return metric_result | ||
|
||
def compute(self, size: Optional[int] = None) -> Dict: | ||
"""Synchronize intermediate results and then call | ||
`self.compute_metric`. | ||
Args: | ||
size (int, optional): The length of the entire dataset. When batch | ||
size > 1, the dataloader may pad some data samples to make | ||
sure all ranks have the same length of dataset slice. The | ||
`compute` will drop the padded data based on this size. | ||
If None, do nothing. Defaults to None. | ||
Returns: | ||
dict: The computed metric results. | ||
""" | ||
if not self.dist_comm.is_initialized or self.dist_comm.world_size == 1: | ||
return self.compute_metric(self._results) | ||
|
||
global_results = self.dist_comm.all_gather_object(self._results) | ||
|
||
collected_results: List[Any] | ||
if self.dist_merge_method == 'cat': | ||
# use `sum` to concatenate list | ||
# e.g. sum([[1, 3], [2, 4]], []) = [1, 3, 2, 4] | ||
collected_results = sum(global_results, []) | ||
else: | ||
collected_results = [] | ||
for partial_result in zip(*global_results): | ||
collected_results.extend(list(partial_result)) | ||
|
||
if size is not None: | ||
collected_results = collected_results[:size] | ||
|
||
if self.dist_comm.rank == 0: | ||
metric_result = self.compute_metric(collected_results) | ||
else: | ||
metric_result = None # type: ignore | ||
|
||
global_metric_result = self.dist_comm.broadcast_object( | ||
metric_result, 0) | ||
return global_metric_result | ||
|
||
@abstractmethod | ||
def add(self, *args, **kwargs): | ||
"""Override this method to add the intermediate results to | ||
`self._results`. | ||
For performance issue, what you add to the `self._results` should be as | ||
simple as possible. | ||
""" | ||
|
||
@abstractmethod | ||
def compute_metric(self, results: List[Any]) -> Dict: | ||
"""Override this method to compute the metric result from collectd | ||
intermediate results. | ||
The returned result of the metric compute should be a dictionary. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import pytest | ||
|
||
from mmeval.core.base_metric import BaseMetric | ||
|
||
try: | ||
import torch | ||
except ImportError: | ||
torch = None | ||
|
||
|
||
class Mertic(BaseMetric): | ||
|
||
def add(self, num): | ||
self._results.append(num) | ||
|
||
def compute_metric(self, results): | ||
return {'results': results} | ||
|
||
|
||
def test_dataset_meta(): | ||
datset_meta = {'CLASSES': ['test1', 'test2']} | ||
|
||
metric = Mertic(dataset_meta=None) | ||
assert metric.dataset_meta is None | ||
|
||
metric.dataset_meta = datset_meta | ||
assert metric.dataset_meta is datset_meta | ||
|
||
|
||
def test_metric_reset(): | ||
metric = Mertic() | ||
metric.add(1) | ||
assert len(metric._results) == 1 | ||
|
||
metric.reset() | ||
assert len(metric._results) == 0 | ||
|
||
|
||
def test_metric_call(): | ||
metric = Mertic() | ||
results = metric(1) | ||
assert results == {'results': [1]} | ||
|
||
metric.add(2) | ||
|
||
# stateless call | ||
results = metric(1) | ||
assert results == {'results': [1]} | ||
|
||
|
||
def test_metric_compute(): | ||
metric = Mertic() | ||
|
||
for i in range(10): | ||
metric.add(i) | ||
|
||
results = metric.compute() | ||
assert results == {'results': [i for i in range(10)]} | ||
|
||
|
||
def _init_torch_dist(rank, world_size, comm_backend, port): | ||
torch.distributed.init_process_group( | ||
backend=comm_backend, | ||
init_method=f'tcp://127.0.0.1:{port}', | ||
world_size=world_size, | ||
rank=rank) | ||
|
||
if comm_backend == 'nccl': | ||
num_gpus = torch.cuda.device_count() | ||
torch.cuda.set_device(rank % num_gpus) | ||
|
||
|
||
def _test_metric_compute(rank, world_size, port, dist_merge_method): | ||
_init_torch_dist(rank, world_size, comm_backend='gloo', port=port) | ||
|
||
metric = Mertic( | ||
dist_backend='torch_cpu', dist_merge_method=dist_merge_method) | ||
|
||
if dist_merge_method == 'unzip': | ||
data_slice = range(rank, 10 * world_size, world_size) | ||
else: | ||
data_slice = range(rank * 10, (rank + 1) * 10) | ||
|
||
for i in data_slice: | ||
metric.add(i) | ||
|
||
results = metric.compute() | ||
assert results == {'results': [i for i in range(10 * world_size)]} | ||
|
||
|
||
@pytest.mark.skipif(torch is None, reason='PyTorch is not available!') | ||
@pytest.mark.skipif( | ||
not torch.distributed.is_available(), | ||
reason='torch.distributed is not available!') | ||
@pytest.mark.parametrize( | ||
argnames=['process_num', 'comm_port', 'dist_merge_method'], | ||
argvalues=[(1, 2346, 'unzip'), (4, 2346, 'unzip'), (4, 2346, 'cat')]) | ||
def test_metric_compute_dist(process_num, comm_port, dist_merge_method): | ||
torch.multiprocessing.spawn( | ||
_test_metric_compute, | ||
nprocs=process_num, | ||
args=(process_num, comm_port, dist_merge_method)) | ||
|
||
|
||
if __name__ == '__main__': | ||
pytest.main([__file__, '-v', '--capture=no']) |