Skip to content

Commit

Permalink
[Enhance] Add build function for scheduler. (open-mmlab#372)
Browse files Browse the repository at this point in the history
* add build function for scheduler

* add unit test

add unit test

* handle convert_to_iter in build_scheduler_from_cfg

* restore deleted code

* format import

* fix lint
  • Loading branch information
HAOCHENYE authored Aug 8, 2022
1 parent 99de095 commit a07a063
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 218 deletions.
5 changes: 3 additions & 2 deletions mmengine/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .build_functions import (build_from_cfg, build_model_from_cfg,
build_runner_from_cfg)
build_runner_from_cfg, build_scheduler_from_cfg)
from .default_scope import DefaultScope
from .registry import Registry
from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
Expand All @@ -17,5 +17,6 @@
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg'
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
'build_scheduler_from_cfg'
]
82 changes: 79 additions & 3 deletions mmengine/registry/build_functions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import logging
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import torch.nn as nn

from ..config import Config, ConfigDict
from ..utils import ManagerMixin
from .registry import Registry

if TYPE_CHECKING:
from ..optim.scheduler import _ParamScheduler
from ..runner import Runner


def build_from_cfg(
cfg: Union[dict, ConfigDict, Config],
Expand Down Expand Up @@ -131,7 +137,7 @@ def build_from_cfg(


def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
registry: Registry) -> Any:
registry: Registry) -> 'Runner':
"""Build a Runner object.
Examples:
>>> from mmengine.registry import Registry, build_runner_from_cfg
Expand Down Expand Up @@ -203,7 +209,11 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
f'{cls_location}.py: {e}')


def build_model_from_cfg(cfg, registry, default_args=None):
def build_model_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
nn.Module:
"""Build a PyTorch model from config dict(s). Different from
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
Expand All @@ -226,3 +236,69 @@ def build_model_from_cfg(cfg, registry, default_args=None):
return Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)


def build_scheduler_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
'_ParamScheduler':
"""Builds a ``ParamScheduler`` instance from config.
``ParamScheduler`` supports building instance by its constructor or
method ``build_iter_from_epoch``. Therefore, its registry needs a build
function to handle both cases.
Args:
cfg (dict or ConfigDict or Config): Config dictionary. If it contains
the key ``convert_to_iter_based``, instance will be built by method
``convert_to_iter_based``, otherwise instance will be built by its
constructor.
registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry.
default_args (dict or ConfigDict or Config, optional): Default
initialization arguments. It must contain key ``optimizer``. If
``convert_to_iter_based`` is defined in ``cfg``, it must
additionally contain key ``epoch_length``. Defaults to None.
Returns:
object: The constructed ``ParamScheduler``.
"""
assert isinstance(
cfg,
(dict, ConfigDict, Config
)), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
assert isinstance(
registry, Registry), ('registry should be a mmengine.Registry object',
f'but got {type(registry)}')

args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
scope = args.pop('_scope_', None)
with registry.switch_scope_and_registry(scope) as registry:
convert_to_iter = args.pop('convert_to_iter_based', False)
if convert_to_iter:
scheduler_type = args.pop('type')
assert 'epoch_length' in args and args.get('by_epoch', True), (
'Only epoch-based parameter scheduler can be converted to '
'iter-based, and `epoch_length` should be set')
if isinstance(scheduler_type, str):
scheduler_cls = registry.get(scheduler_type)
if scheduler_cls is None:
raise KeyError(
f'{scheduler_type} is not in the {registry.name} '
'registry. Please check whether the value of '
f'`{scheduler_type}` is correct or it was registered '
'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
)
elif inspect.isclass(scheduler_type):
scheduler_cls = scheduler_type
else:
raise TypeError('type must be a str or valid type, but got '
f'{type(scheduler_type)}')
return scheduler_cls.build_iter_from_epoch( # type: ignore
**args)
else:
args.pop('epoch_length', None)
return build_from_cfg(args, registry)
6 changes: 4 additions & 2 deletions mmengine/registry/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""

from mmengine.registry import build_model_from_cfg, build_runner_from_cfg
from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
build_scheduler_from_cfg)
from .registry import Registry

# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
Expand Down Expand Up @@ -37,7 +38,8 @@
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor')
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry('parameter scheduler')
PARAM_SCHEDULERS = Registry(
'parameter scheduler', build_func=build_scheduler_from_cfg)

# manage all kinds of metrics
METRICS = Registry('metric')
Expand Down
28 changes: 5 additions & 23 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,34 +1134,16 @@ def _build_param_scheduler(
f'The `end` of {_scheduler["type"]} is not set. '
'Use the max epochs/iters of train loop as default.')

convert_to_iter = _scheduler.pop('convert_to_iter_based',
False)
if convert_to_iter:
assert _scheduler.get(
'by_epoch',
True), ('only epoch-based parameter scheduler can be '
'converted to iter-based')
assert isinstance(self._train_loop, BaseLoop), \
'Scheduler can only be converted to iter-based ' \
'when train loop is built.'
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
param_schedulers.append(
cls.build_iter_from_epoch( # type: ignore
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(
optimizer=optim_wrapper,
**_scheduler,
epoch_length=len(
self.train_dataloader), # type: ignore
))
else:
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(optimizer=optim_wrapper)))
epoch_length=len(self.train_dataloader))))
else:
raise TypeError(
'scheduler should be a _ParamScheduler object or dict, '
f'but got {scheduler}')

return param_schedulers

def build_param_scheduler(
Expand Down
Loading

0 comments on commit a07a063

Please sign in to comment.