-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dev(mmeval/core/dispatch): add multiple dispatch
- Loading branch information
Showing
4 changed files
with
334 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from mmeval.core import dist_backends | ||
from mmeval.core.dispatcher import dispatch | ||
from mmeval.core.dist import (get_dist_backend, list_all_backends, | ||
set_default_dist_backend) | ||
|
||
__all__ = [ | ||
'dist_backends', 'get_dist_backend', 'set_default_dist_backend', | ||
'list_all_backends' | ||
'list_all_backends', 'dispatch' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
"""This module introduces a multiple dispatch mechanism into mmeval. | ||
Some mmeval metrics may have different calculation methods depending on the | ||
deep learning framework or numeric computing libraries used, such as PyTorch | ||
and NumPy. | ||
In order to deal with the dispatch issue of different calculation methods, we | ||
adopt a dynamic multiple dispatch mechanism based on type hints. | ||
A simple example of multiple dispatch based on type hints is as follows: | ||
``` | ||
@dispatch | ||
def compute(x: int, y: int): | ||
print('this is int') | ||
@dispatch | ||
def compute(x: str, y: str): | ||
print('this is str') | ||
``` | ||
Currently, we employ plum (a multiple dispatch library) to implement multiple | ||
dispatch mechanism in mmeval. | ||
In this module, we optimized the execution speed of plum through the following | ||
two tricks: | ||
- Caching plum Type instances | ||
- Caching plum Type hash value | ||
Benefit from the tricks above, plum dispatch got twice faster as before. | ||
More detail can be found at: https://github.com/wesselb/plum/issues/53 | ||
Besides, we implement `MMEvalDispatcher` to extend plum dispatch for better | ||
support of `typing.ForwardRef`. | ||
""" | ||
|
||
import importlib | ||
import inspect | ||
import logging | ||
import plum | ||
import typing | ||
from typing import Any, Callable, Dict, Hashable, Optional, Type | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _singleton_patch() -> None: | ||
"""A monkey patch that makes `plum.type.TypeMeta` become singleton.""" | ||
origin_call = plum.type.TypeMeta.__call__ | ||
plum.type.TypeMeta._instances = {} | ||
|
||
def __call__(cls, *args, **kwargs): | ||
assert not kwargs | ||
key = (cls, args) | ||
if key not in cls._instances: | ||
cls._instances[key] = origin_call(cls, *args, **kwargs) | ||
return cls._instances[key] | ||
|
||
plum.type.TypeMeta.__call__ = __call__ | ||
|
||
|
||
try: | ||
# Since a lot of instance creation in `plum.type_of` can be expensive, | ||
# using the singleton Type can speed up a lot. | ||
_singleton_patch() | ||
except Exception as e: | ||
logger.warning( | ||
f'Patch `plum.type.TypeMeta` with singleton failed, raise error: {e}.') | ||
|
||
|
||
def _hash_cache_patch(hashable_type: Type[Hashable]) -> None: | ||
"""A monkey patch that make class caching hash value. | ||
This is a very useful trick to optimize runtime speed for classes that | ||
frequently call hash methods. | ||
Args: | ||
hashable_type (Type[Hashable]): Hashable type that wants to cache hash | ||
value. | ||
""" | ||
hash_core = hashable_type.__hash__ | ||
hashable_type._hash = None # type: ignore | ||
|
||
def __hash__(self): | ||
if self._hash is None: | ||
self._hash = hash_core(self) | ||
return self._hash | ||
|
||
hashable_type.__hash__ = __hash__ # type: ignore | ||
|
||
|
||
try: | ||
# The hash method of plum Type and Parametric would be called frequently. | ||
# Caching hash value can speed up a lot. | ||
_hash_cache_patch(plum.type.Type) | ||
_hash_cache_patch(plum.type.VarArgs) | ||
_hash_cache_patch(plum.type.Union) | ||
from plum.parametric import Tuple as plum_Tuple | ||
_hash_cache_patch(plum_Tuple) | ||
from plum.parametric import List as plum_List | ||
_hash_cache_patch(plum_List) | ||
from plum.parametric import Dict as plum_Dict | ||
_hash_cache_patch(plum_Dict) | ||
from plum.parametric import Sequence as plum_Sequence | ||
_hash_cache_patch(plum_Sequence) | ||
from plum.parametric import Iterable as plum_Iterable | ||
_hash_cache_patch(plum_Iterable) | ||
except Exception as e: | ||
logger.warning( | ||
f'Patch plum Type with hash value cache failed, raise error: {e}.') | ||
|
||
|
||
class MMEvalDispatcher(plum.Dispatcher): | ||
"""A Dispatcher inherited from `plum.Dispatcher` that resolve ForwardRef. | ||
This dispatcher tries to use `importlib.import_moudle` to import ForwardRef | ||
type and convert unimportable type as a placeholder. | ||
""" | ||
|
||
_unimportable_types: Dict[str, Type] = {} | ||
|
||
def _resolve_importable_type(self, importable_name: str) -> Type: | ||
"""Resolve the given importable name and returns a type. | ||
The given importable name should be a string contains at least one dot, | ||
so that we can split it as module path and type attribute name. | ||
e.g. 'torch.Tensor' and 'numpy.ndarray'. | ||
Args: | ||
importable_name (str): An importable string that wants to resolve. | ||
Returns: | ||
Type: The resolved type or an placeholder for unimportable type. | ||
""" | ||
assert '.' in importable_name, 'The importable name should contain `.`' | ||
module_name, _, module_attr_basename = importable_name.rpartition('.') | ||
try: | ||
module = importlib.import_module(module_name) | ||
resolved_type = getattr(module, module_attr_basename) | ||
except Exception as e: | ||
if importable_name not in self._unimportable_types: | ||
logger.warning( | ||
f"Unimportable: '{importable_name}', raise error: {e}.") | ||
resolved_type = type(importable_name, (), {}) | ||
self._unimportable_types[importable_name] = resolved_type | ||
else: | ||
resolved_type = self._unimportable_types[importable_name] | ||
return resolved_type | ||
|
||
def _traverse_type_hints(self, annotation: Any) -> Any: | ||
"""Traverse nested type hints, and resolve importable ForwardRef. | ||
Note: | ||
In general, we want a type hint, but be aware that function | ||
annotations could be anything. See PEP 3107 and PEP 484 for more. | ||
Args: | ||
annotation (Annotated): The function annotation that wants to | ||
resolve ForwardRef. | ||
Returns: | ||
Annotated: The traversed function annotation. | ||
""" | ||
if isinstance(annotation, (typing.ForwardRef, str)): | ||
# NOTE: ForwardRef could be a string directly. | ||
# https://docs.python.org/3/library/typing.html#typing.ForwardRef | ||
if isinstance(annotation, typing.ForwardRef): | ||
forward_ref_name = annotation.__forward_arg__ | ||
else: | ||
forward_ref_name = annotation | ||
# Currently, we only hold ForwardRef that contain `.` | ||
# In the case of self type, plum has considered that. | ||
if '.' in forward_ref_name: | ||
return self._resolve_importable_type(forward_ref_name) | ||
else: | ||
return annotation | ||
|
||
# Recursively traverse nested type hints. | ||
if isinstance(annotation, typing._GenericAlias): # type: ignore | ||
new_tp_args = [] | ||
for tp_arg in annotation.__args__: | ||
new_tp_arg = self._traverse_type_hints(tp_arg) | ||
new_tp_args.append(new_tp_arg) | ||
annotation.__args__ = tuple(new_tp_args) | ||
|
||
return annotation | ||
|
||
def __call__(self, | ||
method: Optional[Callable] = None, | ||
**kwargs) -> Callable: | ||
"""Process the function annotations and resolve type hints that in | ||
ForwardRef form.""" | ||
if method is not None: | ||
signature = inspect.signature(method) | ||
for param in signature.parameters.values(): | ||
param._annotation = self._traverse_type_hints( # type: ignore | ||
param.annotation) | ||
method.__signature__ = signature # type: ignore | ||
return super().__call__(method=method, **kwargs) | ||
|
||
|
||
dispatch = MMEvalDispatcher() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
numpy | ||
plum-dispatch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import numpy as np | ||
import pytest | ||
from typing import Dict, List, overload | ||
|
||
from mmeval.core.dispatcher import dispatch | ||
|
||
try: | ||
import torch | ||
except ImportError: | ||
torch = None | ||
|
||
|
||
def test_resolve_importable_type(): | ||
assert dispatch._resolve_importable_type('numpy.ndarray') is np.ndarray | ||
|
||
# Got warning: No module named 'np'. | ||
assert dispatch._resolve_importable_type('np.ndarray') is not np.ndarray | ||
|
||
# The placeholder should be a type | ||
assert type(dispatch._resolve_importable_type('np.ndarray')) is type | ||
|
||
# The importable name should contain `.` | ||
with pytest.raises(AssertionError): | ||
dispatch._resolve_importable_type('numpy') | ||
|
||
|
||
def test_multiple_dispatch_function(): | ||
|
||
@overload | ||
@dispatch | ||
def fn(x: int, y: int): | ||
"""In the case of int, return 1.""" | ||
return 1 | ||
|
||
@overload | ||
@dispatch | ||
def fn(x: float, y: float): | ||
"""In the case of float, return 2.""" | ||
return 2 | ||
|
||
@overload | ||
@dispatch | ||
def fn(x: 'numpy.int8', y: 'numpy.int8'): # noqa: F821 | ||
"""In the case of 'numpy.int8', return 3.""" | ||
return 3 | ||
|
||
@overload | ||
@dispatch | ||
def fn(x: 'xxxx.int8', y: 'xxxx.int8'): # noqa: F821 | ||
"""This function test dispatch resolve unimportable type.""" | ||
|
||
@overload | ||
@dispatch | ||
def fn(x: List[int], y: List[int]): | ||
"""In the case of List[int], return 4.""" | ||
return 4 | ||
|
||
@dispatch | ||
def fn(x: Dict[str, int], y: Dict[str, int]): | ||
"""In the case of Dict[str, int], return 5.""" | ||
return 5 | ||
|
||
assert fn(1, 2) == 1 | ||
assert fn(1.0, 2.0) == 2 | ||
assert fn(np.int8(1), np.int8(2)) == 3 | ||
assert fn([1, 2], [3, 4]) == 4 | ||
assert fn({'1': 2}, {'3': 4}) == 5 | ||
|
||
with pytest.raises(LookupError): | ||
fn('1', '2') | ||
|
||
# Test if the type inference is accurate | ||
with pytest.raises(LookupError): | ||
fn({'1': 2}, {'3': 4, 5: '6'}) | ||
|
||
|
||
def test_multiple_dispatch_class_method(): | ||
|
||
class Tester: | ||
|
||
@overload | ||
@dispatch | ||
def __call__(self, x: int, y: int): | ||
"""In the case of int, this is a minimum method.""" | ||
if x < y: | ||
return x | ||
else: | ||
return y | ||
|
||
@dispatch | ||
def __call__(self, x: float, y: float): | ||
"""In the case of int, this is a maximum method.""" | ||
if x > y: | ||
return x | ||
else: | ||
return y | ||
|
||
tester = Tester() | ||
assert tester(1, 2) == 1 | ||
assert tester(1.0, 2.0) == 2.0 | ||
|
||
with pytest.raises(LookupError): | ||
tester('1', '2') | ||
|
||
|
||
@pytest.mark.skipif(torch is None, reason='PyTorch is not available!') | ||
def test_multiple_dispatch_tensor(): | ||
|
||
@overload | ||
@dispatch | ||
def fn(x: 'torch.Tensor', y: 'torch.Tensor'): | ||
"""In the case of 'torch.Tensor', return 1.""" | ||
return 1 | ||
|
||
@dispatch | ||
def fn(x: 'numpy.int8', y: 'numpy.int8'): # noqa: F821 | ||
"""In the case of 'numpy.int8', return 2.""" | ||
return 2 | ||
|
||
assert fn(torch.Tensor([1]), torch.Tensor([2])) == 1 | ||
assert fn(np.int8(1), np.int8(2)) == 2 | ||
|
||
|
||
if __name__ == '__main__': | ||
pytest.main([__file__, '-v', '--capture=no']) |