diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py index 655753b3b..5e72ef7ff 100644 --- a/basicsr/utils/registry.py +++ b/basicsr/utils/registry.py @@ -35,12 +35,15 @@ def __init__(self, name): self._name = name self._obj_map = {} - def _do_register(self, name, obj): + def _do_register(self, name, obj, suffix=None): + if isinstance(suffix, str): + name = name + '_' + suffix + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " f"in '{self._name}' registry!") self._obj_map[name] = obj - def register(self, obj=None): + def register(self, obj=None, suffix=None): """ Register the given object under the the name `obj.__name__`. Can be used as either a decorator or not. @@ -50,17 +53,20 @@ def register(self, obj=None): # used as a decorator def deco(func_or_class): name = func_or_class.__name__ - self._do_register(name, func_or_class) + self._do_register(name, func_or_class, suffix) return func_or_class return deco # used as a function call name = obj.__name__ - self._do_register(name, obj) + self._do_register(name, obj, suffix) - def get(self, name): + def get(self, name, suffix='basicsr'): ret = self._obj_map.get(name) + if ret is None: + ret = self._obj_map.get(name + '_' + suffix) + print(f'Name {name} is not found, use name: {name}_{suffix}!') if ret is None: raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") return ret