Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support auto import modules from registry #1621

Merged
merged 2 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions configs/_base_/datasets/cifar10_noaug.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# custom_imports = dict(
# imports=['mmcls.datasets.transforms'], allow_failed_imports=False)
cifar_pipeline = [dict(type='PackEditInputs')]
cifar_dataset = dict(
type='CIFAR10',
Expand Down
4 changes: 2 additions & 2 deletions configs/disco_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ Running the following codes, you can get a text-generated image.

```python
from mmengine import Config, MODELS
from mmedit.utils import register_all_modules
from mmengine.registry import init_default_scope
from torchvision.utils import save_image

register_all_modules()
init_default_scope('mmedit')

disco = MODELS.build(
Config.fromfile('configs/disco_diffusion/disco-baseline.py').model).cuda().eval()
Expand Down
8 changes: 4 additions & 4 deletions configs/disco_diffusion/tutorials.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@
"source": [
"import torch\n",
"from mmengine import Config, MODELS\n",
"from mmedit.utils import register_all_modules\n",
"from mmengine.registry import init_default_scope\n",
"from mmedit.registry import MODULES\n",
"from mmcv import tensor2imgs\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from torchvision.transforms import ToPILImage, Normalize, Compose\n",
"from IPython.display import Image\n",
"\n",
"register_all_modules()\n",
"init_default_scope('mmedit')\n",
"\n",
"\n",
"def show_tensor(image_tensor, index=0):\n",
Expand Down Expand Up @@ -1417,7 +1417,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 64-bit",
"display_name": "Python 3.8.13 ('mmedit2': conda)",
"language": "python",
"name": "python3"
},
Expand All @@ -1436,7 +1436,7 @@
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "ab5d8c58fa4ba0eabce645db4cbe37b4cabea4c937ae154ab72fe7cc84b68be8"
"hash": "5d897c5c52e082b514ee6f95b827618ca631a30c2bcba2887693dc2fed97e1f9"
}
}
},
Expand Down
4 changes: 2 additions & 2 deletions configs/stable_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ Running the following codes, you can get a text-generated image.
from mmengine import MODELS, Config
from torchvision import utils

from mmedit.utils import register_all_modules
from mmengine.registry import init_default_scope

register_all_modules()
init_default_scope('mmedit')

config = 'configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py'
StableDiffuser = MODELS.build(Config.fromfile(config).model)
Expand Down
4 changes: 2 additions & 2 deletions demo/gradio-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import numpy as np
import torch
import yaml
from mmengine.registry import init_default_scope

from mmedit.apis.inferencers.inpainting_inferencer import InpaintingInferencer
from mmedit.utils import register_all_modules


class InpaintingGradio:
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(self,
extra_parameters: Dict = None,
seed: int = 2022,
**kwargs) -> None:
register_all_modules(init_default_scope=True)
init_default_scope('mmedit')
InpaintingGradio.init_inference_supported_models_cfg()
self.model_name = model_name
self.model_setting = model_setting
Expand Down
19 changes: 10 additions & 9 deletions mmedit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine

from .version import __version__, version_info

Expand All @@ -25,19 +26,19 @@ def digit_version(version_str):
mmcv_max_version = digit_version(MMCV_MAX)
mmcv_version = digit_version(mmcv.__version__)

# MMENGINE_MIN = '0.1.0'
# MMENGINE_MAX = '0.2.0'
# mmengine_min_version = digit_version(MMENGINE_MIN)
# mmengine_max_version = digit_version(MMENGINE_MAX)
# mmengine_version = digit_version(mmengine.__version__)
MMENGINE_MIN = '0.4.0'
MMENGINE_MAX = '1.0.0'
mmengine_min_version = digit_version(MMENGINE_MIN)
mmengine_max_version = digit_version(MMENGINE_MAX)
mmengine_version = digit_version(mmengine.__version__)

assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
f'mmcv=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv-full>={mmcv_min_version}, <{mmcv_max_version}.'

# assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \
# f'mmengine=={mmengine.__version__} is used but incompatible. ' \
# f'Please install mmengine>={mmengine_min_version}, ' \
# f'<{mmengine_max_version}.'
assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \
f'mmengine=={mmengine.__version__} is used but incompatible. ' \
f'Please install mmengine>={mmengine_min_version}, ' \
f'<{mmengine_max_version}.'

__all__ = ['__version__', 'version_info']
5 changes: 3 additions & 2 deletions mmedit/apis/inferencers/inference_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from mmengine.dataset import Compose
from mmengine.dataset.utils import default_collate as collate
from mmengine.fileio import get_file_backend
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.runner import set_random_seed as set_random_seed_engine
from mmengine.utils import ProgressBar
from torch.nn.parallel import scatter

from mmedit.models.base_models import BaseTranslationModel
from mmedit.registry import MODELS
from mmedit.utils import register_all_modules

VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi')
FILE_CLIENT = get_file_backend(backend_args={'backend': 'local'})
Expand Down Expand Up @@ -81,7 +81,8 @@ def init_model(config, checkpoint=None, device='cuda:0'):
# config.test_cfg.metrics = None
delete_cfg(config.model, 'init_cfg')

register_all_modules()
init_default_scope(config.get('default_scope', 'mmedit'))

model = MODELS.build(config.model)

if checkpoint is not None:
Expand Down
4 changes: 2 additions & 2 deletions mmedit/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import torch
import yaml
from mmengine.registry import init_default_scope

from mmedit.apis.inferencers import MMEditInferencer
from mmedit.apis.inferencers.base_mmedit_inferencer import InputsType
from mmedit.utils import register_all_modules


class MMEdit:
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(self,
extra_parameters: Dict = None,
seed: int = 2022,
**kwargs) -> None:
register_all_modules(init_default_scope=True)
init_default_scope('mmedit')
MMEdit.init_inference_supported_models_cfg()
inferencer_kwargs = {}
inferencer_kwargs.update(
Expand Down
59 changes: 41 additions & 18 deletions mmedit/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,31 @@
from mmengine.registry import Registry

# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=registry.RUNNERS)
RUNNERS = Registry(
'runner', parent=registry.RUNNERS, locations=['mmedit.engine.runner'])
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=registry.RUNNER_CONSTRUCTORS)
'runner constructor',
parent=registry.RUNNER_CONSTRUCTORS,
locations=['mmedit.engine.runner'])
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=registry.LOOPS)
LOOPS = Registry(
'loop', parent=registry.LOOPS, locations=['mmedit.engine.runner'])
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=registry.HOOKS)
HOOKS = Registry(
'hook', parent=registry.HOOKS, locations=['mmedit.engine.hooks'])

# manage data-related modules
DATASETS = Registry('dataset', parent=registry.DATASETS)
DATASETS = Registry(
'dataset', parent=registry.DATASETS, locations=['mmedit.datasets'])
DATA_SAMPLERS = Registry('data sampler', parent=registry.DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=registry.TRANSFORMS)
TRANSFORMS = Registry(
'transform',
parent=registry.TRANSFORMS,
locations=['mmedit.datasets.transforms'])

# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=registry.MODELS)
MODELS = Registry('model', parent=registry.MODELS, locations=['mmedit.models'])
MODULES = BACKBONES = COMPONENTS = LOSSES = MODELS
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=registry.MODEL_WRAPPERS)
Expand All @@ -37,31 +46,45 @@

# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=registry.OPTIMIZERS)
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=registry.OPTIM_WRAPPERS)
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer wrapper constructor',
parent=registry.OPTIM_WRAPPER_CONSTRUCTORS)
parent=registry.OPTIM_WRAPPER_CONSTRUCTORS,
locations=['mmedit.engine.optimizers'])
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=registry.PARAM_SCHEDULERS)
'parameter scheduler',
parent=registry.PARAM_SCHEDULERS,
locations=['mmedit.engine.schedulers'])
# manage all kinds of metrics
METRICS = Registry('metric', parent=registry.METRICS)
METRICS = Registry(
'metric', parent=registry.METRICS, locations=['mmedit.evaluation'])
# manage all kinds of evaluators
EVALUATORS = Registry('evaluator', parent=registry.EVALUATOR)
EVALUATORS = Registry(
'evaluator', parent=registry.EVALUATOR, locations=['mmedit.evaluation'])

# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util', parent=registry.TASK_UTILS)

# manage visualizer
VISUALIZERS = Registry('visualizer', parent=registry.VISUALIZERS)
VISUALIZERS = Registry(
'visualizer',
parent=registry.VISUALIZERS,
locations=['mmedit.visualization'])
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=registry.VISBACKENDS)
VISBACKENDS = Registry(
'vis_backend',
parent=registry.VISBACKENDS,
locations=['mmedit.visualization'])

# manage logprocessor
LOG_PROCESSORS = Registry('log_processor', parent=registry.LOG_PROCESSORS)

# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=registry.OPTIM_WRAPPERS)
LOG_PROCESSORS = Registry(
'log_processor',
parent=registry.LOG_PROCESSORS,
locations=['mmedit.engine.runner'])

# manage diffusion_schedulers
DIFFUSION_SCHEDULERS = Registry('diffusion scheduler')
DIFFUSION_SCHEDULERS = Registry(
'diffusion scheduler', locations=['mmedit.models.editors'])
4 changes: 2 additions & 2 deletions projects/glide/configs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ You can run glide as follows:
```python
import torch
from mmedit.apis import init_model
from mmedit.utils import register_all_modules
from mmengine.registry import init_default_scope
from projects.glide.models import *

register_all_modules()
init_default_scope('mmedit')

config = 'projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py'
ckpt = 'https://download.openmmlab.com/mmediting/glide/glide_laion-64x64-02afff47.pth'
Expand Down
2 changes: 1 addition & 1 deletion requirements/mminstall.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mmcv>=2.0.0rc1
mmengine
mmengine>=0.4.0
7 changes: 4 additions & 3 deletions tools/analysis_tools/get_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import torch
from mmengine import Config
from mmengine.registry import init_default_scope

from mmedit.registry import MODELS
from mmedit.utils import register_all_modules

try:
from mmcv.cnn import get_model_complexity_info
Expand Down Expand Up @@ -39,9 +39,10 @@ def main():
else:
raise ValueError('invalid input shape')

register_all_modules()

cfg = Config.fromfile(args.config)

init_default_scope(cfg.get('default_scope', 'mmedit'))

model = MODELS.build(cfg.model)
if torch.cuda.is_available():
model.cuda()
Expand Down
6 changes: 3 additions & 3 deletions tools/model_converters/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from mmcv.onnx import register_extra_symbolics
from mmengine import Config
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint

from mmedit.apis import delete_cfg
from mmedit.registry import MODELS
from mmedit.utils import register_all_modules


def pytorch2onnx(model,
Expand Down Expand Up @@ -190,15 +190,15 @@ def parse_args():
config = Config.fromfile(args.config)
delete_cfg(config, key='init_cfg')

init_default_scope(config.get('default_scope', 'mmedit'))

# ONNX does not support spectral norm
if model_type == 'mattor':
if hasattr(config.model.backbone.encoder, 'with_spectral_norm'):
config.model.backbone.encoder.with_spectral_norm = False
config.model.backbone.decoder.with_spectral_norm = False
config.test_cfg.metrics = None

register_all_modules()

# build the model
model = MODELS.build(config.model)
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
Expand Down
6 changes: 1 addition & 5 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmedit.utils import print_colored_log, register_all_modules
from mmedit.utils import print_colored_log


# TODO: support fuse_conv_bn, visualization, and format_only
Expand Down Expand Up @@ -45,10 +45,6 @@ def parse_args():
def main():
args = parse_args()

# register all modules in mmedit into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
Expand Down
6 changes: 1 addition & 5 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mmengine.config import Config, DictAction
from mmengine.runner import Runner

from mmedit.utils import print_colored_log, register_all_modules
from mmedit.utils import print_colored_log


def parse_args():
Expand Down Expand Up @@ -51,10 +51,6 @@ def parse_args():
def main():
args = parse_args()

# register all modules in mmedit into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
Expand Down