Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Support registering function #1858

Merged
merged 13 commits into from
May 2, 2022
Merged
18 changes: 12 additions & 6 deletions mmcv/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def build_from_cfg(cfg, registry, default_args=None):
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
274869388 marked this conversation as resolved.
Show resolved Hide resolved
obj_cls = obj_type
else:
raise TypeError(
Expand All @@ -56,16 +56,21 @@ def build_from_cfg(cfg, registry, default_args=None):


class Registry:
"""A registry to map strings to classes.
"""A registry to map strings to classes or functions.

Registered object could be built from registry.
Registered object could be built from registry. Meanwhile, registered
functions could be called from registry.

Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(type='ResNet'))
>>> @MODELS.register_module()
>>> def resnet50():
>>> pass
>>> resnet = MODELS.build(dict(type='resnet50'))

Please refer to
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
Expand Down Expand Up @@ -233,8 +238,9 @@ def _add_children(self, registry):
self.children[registry.scope] = registry

def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
if not inspect.isclass(module_class) and not inspect.isfunction(
module_class):
raise TypeError('module must be a class or a function, '
f'but got {type(module_class)}')

if module_name is None:
Expand Down Expand Up @@ -286,7 +292,7 @@ def register_module(self, name=None, force=False, module=None):
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
module (type): Module class or function to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
Expand Down
19 changes: 15 additions & 4 deletions tests/test_utils/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,23 @@ class SphynxCat:
with pytest.raises(TypeError):
CATS.register_module(0)

# can only decorate a class
@CATS.register_module()
def muchkin():
pass

assert CATS.get('muchkin') is muchkin
assert 'muchkin' in CATS

# can only decorate a class or a function
with pytest.raises(TypeError):

@CATS.register_module()
def some_method():
pass
class Demo:

def some_method(self):
pass

method = Demo().some_method
CATS.register_module(name='some_method', module=method)
274869388 marked this conversation as resolved.
Show resolved Hide resolved

# begin: test old APIs
with pytest.warns(DeprecationWarning):
Expand Down