dev(mmeval/core/dispatch): add multiple dispatch
ice-tong committed Sep 6, 2022
1 parent 12e17f4 commit e054de0
Showing 4 changed files with 334 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmeval/core/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.

 from mmeval.core import dist_backends
+from mmeval.core.dispatcher import dispatch
 from mmeval.core.dist import (get_dist_backend, list_all_backends,
                               set_default_dist_backend)

 __all__ = [
     'dist_backends', 'get_dist_backend', 'set_default_dist_backend',
-    'list_all_backends'
+    'list_all_backends', 'dispatch'
 ]
204 changes: 204 additions & 0 deletions mmeval/core/dispatcher.py
@@ -0,0 +1,204 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This module introduces a multiple dispatch mechanism into mmeval.
Some mmeval metrics may have different calculation methods depending on the
deep learning framework or numeric computing libraries used, such as PyTorch
and NumPy.
In order to deal with the dispatch issue of different calculation methods, we
adopt a dynamic multiple dispatch mechanism based on type hints.
A simple example of multiple dispatch based on type hints is as follows:
```
@dispatch
def compute(x: int, y: int):
print('this is int')
@dispatch
def compute(x: str, y: str):
print('this is str')
```
Currently, we employ plum (a multiple dispatch library) to implement multiple
dispatch mechanism in mmeval.
In this module, we optimized the execution speed of plum through the following
two tricks:
- Caching plum Type instances
- Caching plum Type hash value
Benefit from the tricks above, plum dispatch got twice faster as before.
More detail can be found at: https://github.com/wesselb/plum/issues/53
Besides, we implement `MMEvalDispatcher` to extend plum dispatch for better
support of `typing.ForwardRef`.
"""

import importlib
import inspect
import logging
import plum
import typing
from typing import Any, Callable, Dict, Hashable, Optional, Type

logger = logging.getLogger(__name__)


def _singleton_patch() -> None:
    """A monkey patch that makes `plum.type.TypeMeta` cache its instances as
    singletons."""
    origin_call = plum.type.TypeMeta.__call__
    plum.type.TypeMeta._instances = {}

    def __call__(cls, *args, **kwargs):
        assert not kwargs
        key = (cls, args)
        if key not in cls._instances:
            cls._instances[key] = origin_call(cls, *args, **kwargs)
        return cls._instances[key]

    plum.type.TypeMeta.__call__ = __call__


try:
    # Since a lot of the instance creation in `plum.type_of` can be
    # expensive, using singleton Types can speed things up a lot.
    _singleton_patch()
except Exception as e:
    logger.warning(
        f'Patch `plum.type.TypeMeta` with singleton failed, raised error: {e}.')
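
# For example (a minimal sketch, assuming the patch above succeeded):
# constructing a plum Type with the same arguments twice now returns one
# cached instance, so e.g. `plum.type_of(1) is plum.type_of(2)` holds,
# since both resolve to the singleton Type for `int`.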


def _hash_cache_patch(hashable_type: Type[Hashable]) -> None:
    """A monkey patch that makes a class cache its hash value.

    This is a very useful trick to optimize the runtime speed of classes
    whose hash methods are called frequently.

    Args:
        hashable_type (Type[Hashable]): The hashable type whose hash value
            should be cached.
    """
    hash_core = hashable_type.__hash__
    hashable_type._hash = None  # type: ignore

    def __hash__(self):
        if self._hash is None:
            self._hash = hash_core(self)
        return self._hash

    hashable_type.__hash__ = __hash__  # type: ignore


try:
    # The hash methods of plum Type and Parametric are called frequently.
    # Caching the hash value can speed things up a lot.
    _hash_cache_patch(plum.type.Type)
    _hash_cache_patch(plum.type.VarArgs)
    _hash_cache_patch(plum.type.Union)
    from plum.parametric import Tuple as plum_Tuple
    _hash_cache_patch(plum_Tuple)
    from plum.parametric import List as plum_List
    _hash_cache_patch(plum_List)
    from plum.parametric import Dict as plum_Dict
    _hash_cache_patch(plum_Dict)
    from plum.parametric import Sequence as plum_Sequence
    _hash_cache_patch(plum_Sequence)
    from plum.parametric import Iterable as plum_Iterable
    _hash_cache_patch(plum_Iterable)
except Exception as e:
    logger.warning(
        f'Patch plum Type with hash value cache failed, raised error: {e}.')
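
# For example (a minimal sketch, assuming the patch above succeeded): the
# first `hash()` call on a plum Type computes the value and stores it in
# `_hash`, so subsequent calls return the cached value directly:
#
#   tp = plum.type_of(0)
#   hash(tp)                     # computes and caches the hash value
#   assert tp._hash is not None  # later `hash(tp)` calls reuse this cache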


class MMEvalDispatcher(plum.Dispatcher):
    """A Dispatcher inherited from `plum.Dispatcher` that resolves ForwardRef.

    This dispatcher tries to use `importlib.import_module` to import a
    ForwardRef type, and converts unimportable types to placeholders.
    """

    _unimportable_types: Dict[str, Type] = {}

    def _resolve_importable_type(self, importable_name: str) -> Type:
        """Resolve the given importable name and return a type.

        The given importable name should be a string containing at least one
        dot, so that it can be split into a module path and a type attribute
        name, e.g. 'torch.Tensor' and 'numpy.ndarray'.

        Args:
            importable_name (str): The importable name to resolve.

        Returns:
            Type: The resolved type, or a placeholder for an unimportable
            type.
        """
        assert '.' in importable_name, 'The importable name should contain `.`'
        module_name, _, module_attr_basename = importable_name.rpartition('.')
        try:
            module = importlib.import_module(module_name)
            resolved_type = getattr(module, module_attr_basename)
        except Exception as e:
            if importable_name not in self._unimportable_types:
                logger.warning(
                    f"Unimportable: '{importable_name}', raised error: {e}.")
                resolved_type = type(importable_name, (), {})
                self._unimportable_types[importable_name] = resolved_type
            else:
                resolved_type = self._unimportable_types[importable_name]
        return resolved_type

    def _traverse_type_hints(self, annotation: Any) -> Any:
        """Traverse nested type hints and resolve importable ForwardRefs.

        Note:
            In general, we want a type hint, but be aware that function
            annotations could be anything. See PEP 3107 and PEP 484 for more.

        Args:
            annotation (Annotated): The function annotation in which to
                resolve ForwardRefs.

        Returns:
            Annotated: The traversed function annotation.
        """
        if isinstance(annotation, (typing.ForwardRef, str)):
            # NOTE: A ForwardRef could be a string directly.
            # https://docs.python.org/3/library/typing.html#typing.ForwardRef
            if isinstance(annotation, typing.ForwardRef):
                forward_ref_name = annotation.__forward_arg__
            else:
                forward_ref_name = annotation
            # Currently, we only resolve ForwardRefs that contain a `.`.
            # The self-referencing type case is already handled by plum.
            if '.' in forward_ref_name:
                return self._resolve_importable_type(forward_ref_name)
            else:
                return annotation

        # Recursively traverse nested type hints.
        if isinstance(annotation, typing._GenericAlias):  # type: ignore
            new_tp_args = []
            for tp_arg in annotation.__args__:
                new_tp_arg = self._traverse_type_hints(tp_arg)
                new_tp_args.append(new_tp_arg)
            annotation.__args__ = tuple(new_tp_args)

        return annotation

    def __call__(self,
                 method: Optional[Callable] = None,
                 **kwargs) -> Callable:
        """Process the function annotations and resolve type hints that are
        in ForwardRef form."""
        if method is not None:
            signature = inspect.signature(method)
            for param in signature.parameters.values():
                param._annotation = self._traverse_type_hints(  # type: ignore
                    param.annotation)
            method.__signature__ = signature  # type: ignore
        return super().__call__(method=method, **kwargs)


dispatch = MMEvalDispatcher()
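
To illustrate the flow above: when a function is decorated, `MMEvalDispatcher.__call__` rewrites every dotted string (ForwardRef) annotation in its signature to the real imported type before registering the method with plum. A minimal usage sketch, assuming NumPy is installed (the `to_list` function is hypothetical, not part of this commit):

```
import numpy as np

from mmeval.core import dispatch


@dispatch
def to_list(x: 'numpy.ndarray'):  # noqa: F821
    # The dotted string annotation is resolved to the real numpy.ndarray type.
    return x.tolist()


@dispatch
def to_list(x: list):
    # Plain annotations are passed through to plum unchanged.
    return x


assert to_list(np.array([1, 2])) == [1, 2]
assert to_list([1, 2]) == [1, 2]
```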
1 change: 1 addition & 0 deletions requirements/runtime.txt
@@ -1 +1,2 @@
 numpy
+plum-dispatch
127 changes: 127 additions & 0 deletions tests/test_core/test_dispatcher.py
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
import pytest
from typing import Dict, List, overload

from mmeval.core.dispatcher import dispatch

try:
    import torch
except ImportError:
    torch = None


def test_resolve_importable_type():
    assert dispatch._resolve_importable_type('numpy.ndarray') is np.ndarray

    # Got a warning: No module named 'np'.
    assert dispatch._resolve_importable_type('np.ndarray') is not np.ndarray

    # The placeholder should be a type.
    assert type(dispatch._resolve_importable_type('np.ndarray')) is type

    # The importable name should contain `.`.
    with pytest.raises(AssertionError):
        dispatch._resolve_importable_type('numpy')


def test_multiple_dispatch_function():

    @overload
    @dispatch
    def fn(x: int, y: int):
        """In the case of int, return 1."""
        return 1

    @overload
    @dispatch
    def fn(x: float, y: float):
        """In the case of float, return 2."""
        return 2

    @overload
    @dispatch
    def fn(x: 'numpy.int8', y: 'numpy.int8'):  # noqa: F821
        """In the case of 'numpy.int8', return 3."""
        return 3

    @overload
    @dispatch
    def fn(x: 'xxxx.int8', y: 'xxxx.int8'):  # noqa: F821
        """This overload tests that dispatch resolves unimportable types."""

    @overload
    @dispatch
    def fn(x: List[int], y: List[int]):
        """In the case of List[int], return 4."""
        return 4

    @dispatch
    def fn(x: Dict[str, int], y: Dict[str, int]):
        """In the case of Dict[str, int], return 5."""
        return 5

    assert fn(1, 2) == 1
    assert fn(1.0, 2.0) == 2
    assert fn(np.int8(1), np.int8(2)) == 3
    assert fn([1, 2], [3, 4]) == 4
    assert fn({'1': 2}, {'3': 4}) == 5

    with pytest.raises(LookupError):
        fn('1', '2')

    # Test whether the type inference is accurate.
    with pytest.raises(LookupError):
        fn({'1': 2}, {'3': 4, 5: '6'})


def test_multiple_dispatch_class_method():

    class Tester:

        @overload
        @dispatch
        def __call__(self, x: int, y: int):
            """In the case of int, this is a minimum method."""
            if x < y:
                return x
            else:
                return y

        @dispatch
        def __call__(self, x: float, y: float):
            """In the case of float, this is a maximum method."""
            if x > y:
                return x
            else:
                return y

    tester = Tester()
    assert tester(1, 2) == 1
    assert tester(1.0, 2.0) == 2.0

    with pytest.raises(LookupError):
        tester('1', '2')


@pytest.mark.skipif(torch is None, reason='PyTorch is not available!')
def test_multiple_dispatch_tensor():

    @overload
    @dispatch
    def fn(x: 'torch.Tensor', y: 'torch.Tensor'):
        """In the case of 'torch.Tensor', return 1."""
        return 1

    @dispatch
    def fn(x: 'numpy.int8', y: 'numpy.int8'):  # noqa: F821
        """In the case of 'numpy.int8', return 2."""
        return 2

    assert fn(torch.Tensor([1]), torch.Tensor([2])) == 1
    assert fn(np.int8(1), np.int8(2)) == 2


if __name__ == '__main__':
    pytest.main([__file__, '-v', '--capture=no'])
