diff --git a/configs/_base_/datasets/cifar10_noaug.py b/configs/_base_/datasets/cifar10_noaug.py index 732433cea4..c135542120 100644 --- a/configs/_base_/datasets/cifar10_noaug.py +++ b/configs/_base_/datasets/cifar10_noaug.py @@ -1,5 +1,3 @@ -# custom_imports = dict( -# imports=['mmcls.datasets.transforms'], allow_failed_imports=False) cifar_pipeline = [dict(type='PackEditInputs')] cifar_dataset = dict( type='CIFAR10', diff --git a/configs/disco_diffusion/README.md b/configs/disco_diffusion/README.md index 4d538f63b4..8456276eda 100644 --- a/configs/disco_diffusion/README.md +++ b/configs/disco_diffusion/README.md @@ -60,10 +60,10 @@ Running the following codes, you can get a text-generated image. ```python from mmengine import Config, MODELS -from mmedit.utils import register_all_modules +from mmengine.registry import init_default_scope from torchvision.utils import save_image -register_all_modules() +init_default_scope('mmedit') disco = MODELS.build( Config.fromfile('configs/disco_diffusion/disco-baseline.py').model).cuda().eval() diff --git a/configs/disco_diffusion/tutorials.ipynb b/configs/disco_diffusion/tutorials.ipynb index 4c6e54d9b2..a730bf18e6 100644 --- a/configs/disco_diffusion/tutorials.ipynb +++ b/configs/disco_diffusion/tutorials.ipynb @@ -97,7 +97,7 @@ "source": [ "import torch\n", "from mmengine import Config, MODELS\n", - "from mmedit.utils import register_all_modules\n", + "from mmengine.registry import init_default_scope\n", "from mmedit.registry import MODULES\n", "from mmcv import tensor2imgs\n", "from matplotlib import pyplot as plt\n", @@ -105,7 +105,7 @@ "from torchvision.transforms import ToPILImage, Normalize, Compose\n", "from IPython.display import Image\n", "\n", - "register_all_modules()\n", + "init_default_scope('mmedit')\n", "\n", "\n", "def show_tensor(image_tensor, index=0):\n", @@ -1417,7 +1417,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 64-bit", + "display_name": "Python 3.8.13 ('mmedit2': conda)", "language": "python", "name": "python3" }, @@ -1436,7 +1436,7 @@ "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "ab5d8c58fa4ba0eabce645db4cbe37b4cabea4c937ae154ab72fe7cc84b68be8" + "hash": "5d897c5c52e082b514ee6f95b827618ca631a30c2bcba2887693dc2fed97e1f9" } } }, diff --git a/configs/stable_diffusion/README.md b/configs/stable_diffusion/README.md index 53274530dc..8d83643a0f 100644 --- a/configs/stable_diffusion/README.md +++ b/configs/stable_diffusion/README.md @@ -55,9 +55,9 @@ Running the following codes, you can get a text-generated image. from mmengine import MODELS, Config from torchvision import utils -from mmedit.utils import register_all_modules +from mmengine.registry import init_default_scope -register_all_modules() +init_default_scope('mmedit') config = 'configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py' StableDiffuser = MODELS.build(Config.fromfile(config).model) diff --git a/demo/gradio-demo.py b/demo/gradio-demo.py index f7c79a8b22..fd52d0465b 100644 --- a/demo/gradio-demo.py +++ b/demo/gradio-demo.py @@ -12,9 +12,9 @@ import numpy as np import torch import yaml +from mmengine.registry import init_default_scope from mmedit.apis.inferencers.inpainting_inferencer import InpaintingInferencer -from mmedit.utils import register_all_modules class InpaintingGradio: @@ -43,7 +43,7 @@ def __init__(self, extra_parameters: Dict = None, seed: int = 2022, **kwargs) -> None: - register_all_modules(init_default_scope=True) + init_default_scope('mmedit') InpaintingGradio.init_inference_supported_models_cfg() self.model_name = model_name self.model_setting = model_setting diff --git a/mmedit/__init__.py b/mmedit/__init__.py index 4fb5cc0c4c..bf9af38140 100644 --- a/mmedit/__init__.py +++ b/mmedit/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import mmcv +import mmengine from .version import __version__, version_info @@ -25,19 +26,19 @@ def digit_version(version_str): mmcv_max_version = digit_version(MMCV_MAX) mmcv_version = digit_version(mmcv.__version__) -# MMENGINE_MIN = '0.1.0' -# MMENGINE_MAX = '0.2.0' -# mmengine_min_version = digit_version(MMENGINE_MIN) -# mmengine_max_version = digit_version(MMENGINE_MAX) -# mmengine_version = digit_version(mmengine.__version__) +MMENGINE_MIN = '0.4.0' +MMENGINE_MAX = '1.0.0' +mmengine_min_version = digit_version(MMENGINE_MIN) +mmengine_max_version = digit_version(MMENGINE_MAX) +mmengine_version = digit_version(mmengine.__version__) assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ f'mmcv=={mmcv.__version__} is used but incompatible. ' \ f'Please install mmcv-full>={mmcv_min_version}, <{mmcv_max_version}.' -# assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \ -# f'mmengine=={mmengine.__version__} is used but incompatible. ' \ -# f'Please install mmengine>={mmengine_min_version}, ' \ -# f'<{mmengine_max_version}.' +assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \ + f'mmengine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_min_version}, ' \ + f'<{mmengine_max_version}.' __all__ = ['__version__', 'version_info'] diff --git a/mmedit/apis/inferencers/inference_functions.py b/mmedit/apis/inferencers/inference_functions.py index f6bc78656f..51343fc919 100644 --- a/mmedit/apis/inferencers/inference_functions.py +++ b/mmedit/apis/inferencers/inference_functions.py @@ -13,6 +13,7 @@ from mmengine.dataset import Compose from mmengine.dataset.utils import default_collate as collate from mmengine.fileio import get_file_backend +from mmengine.registry import init_default_scope from mmengine.runner import load_checkpoint from mmengine.runner import set_random_seed as set_random_seed_engine from mmengine.utils import ProgressBar @@ -20,7 +21,6 @@ from mmedit.models.base_models import BaseTranslationModel from mmedit.registry import MODELS -from mmedit.utils import register_all_modules VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi') FILE_CLIENT = get_file_backend(backend_args={'backend': 'local'}) @@ -81,7 +81,8 @@ def init_model(config, checkpoint=None, device='cuda:0'): # config.test_cfg.metrics = None delete_cfg(config.model, 'init_cfg') - register_all_modules() + init_default_scope(config.get('default_scope', 'mmedit')) + model = MODELS.build(config.model) if checkpoint is not None: diff --git a/mmedit/edit.py b/mmedit/edit.py index 5fa6198670..e0d9658f8a 100644 --- a/mmedit/edit.py +++ b/mmedit/edit.py @@ -6,10 +6,10 @@ import torch import yaml +from mmengine.registry import init_default_scope from mmedit.apis.inferencers import MMEditInferencer from mmedit.apis.inferencers.base_mmedit_inferencer import InputsType -from mmedit.utils import register_all_modules class MMEdit: @@ -84,7 +84,7 @@ def __init__(self, extra_parameters: Dict = None, seed: int = 2022, **kwargs) -> None: - register_all_modules(init_default_scope=True) + init_default_scope('mmedit') MMEdit.init_inference_supported_models_cfg() inferencer_kwargs = {} inferencer_kwargs.update( diff --git a/mmedit/registry.py b/mmedit/registry.py index 60564ca5da..ac6bf10de8 100644 --- a/mmedit/registry.py +++ b/mmedit/registry.py @@ -12,22 +12,31 @@ from mmengine.registry import Registry # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` -RUNNERS = Registry('runner', parent=registry.RUNNERS) +RUNNERS = Registry( + 'runner', parent=registry.RUNNERS, locations=['mmedit.engine.runner']) # manage runner constructors that define how to initialize runners RUNNER_CONSTRUCTORS = Registry( - 'runner constructor', parent=registry.RUNNER_CONSTRUCTORS) + 'runner constructor', + parent=registry.RUNNER_CONSTRUCTORS, + locations=['mmedit.engine.runner']) # manage all kinds of loops like `EpochBasedTrainLoop` -LOOPS = Registry('loop', parent=registry.LOOPS) +LOOPS = Registry( + 'loop', parent=registry.LOOPS, locations=['mmedit.engine.runner']) # manage all kinds of hooks like `CheckpointHook` -HOOKS = Registry('hook', parent=registry.HOOKS) +HOOKS = Registry( + 'hook', parent=registry.HOOKS, locations=['mmedit.engine.hooks']) # manage data-related modules -DATASETS = Registry('dataset', parent=registry.DATASETS) +DATASETS = Registry( + 'dataset', parent=registry.DATASETS, locations=['mmedit.datasets']) DATA_SAMPLERS = Registry('data sampler', parent=registry.DATA_SAMPLERS) -TRANSFORMS = Registry('transform', parent=registry.TRANSFORMS) +TRANSFORMS = Registry( + 'transform', + parent=registry.TRANSFORMS, + locations=['mmedit.datasets.transforms']) # manage all kinds of modules inheriting `nn.Module` -MODELS = Registry('model', parent=registry.MODELS) +MODELS = Registry('model', parent=registry.MODELS, locations=['mmedit.models']) MODULES = BACKBONES = COMPONENTS = LOSSES = MODELS # manage all kinds of model wrappers like 'MMDistributedDataParallel' MODEL_WRAPPERS = Registry('model_wrapper', parent=registry.MODEL_WRAPPERS) @@ -37,31 +46,45 @@ # manage all kinds of optimizers like `SGD` and `Adam` OPTIMIZERS = Registry('optimizer', parent=registry.OPTIMIZERS) +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry('optim_wrapper', parent=registry.OPTIM_WRAPPERS) # manage constructors that customize the optimization hyperparameters. OPTIM_WRAPPER_CONSTRUCTORS = Registry( 'optimizer wrapper constructor', - parent=registry.OPTIM_WRAPPER_CONSTRUCTORS) + parent=registry.OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmedit.engine.optimizers']) # manage all kinds of parameter schedulers like `MultiStepLR` PARAM_SCHEDULERS = Registry( - 'parameter scheduler', parent=registry.PARAM_SCHEDULERS) + 'parameter scheduler', + parent=registry.PARAM_SCHEDULERS, + locations=['mmedit.engine.schedulers']) # manage all kinds of metrics -METRICS = Registry('metric', parent=registry.METRICS) +METRICS = Registry( + 'metric', parent=registry.METRICS, locations=['mmedit.evaluation']) # manage all kinds of evaluators -EVALUATORS = Registry('evaluator', parent=registry.EVALUATOR) +EVALUATORS = Registry( + 'evaluator', parent=registry.EVALUATOR, locations=['mmedit.evaluation']) # manage task-specific modules like anchor generators and box coders TASK_UTILS = Registry('task util', parent=registry.TASK_UTILS) # manage visualizer -VISUALIZERS = Registry('visualizer', parent=registry.VISUALIZERS) +VISUALIZERS = Registry( + 'visualizer', + parent=registry.VISUALIZERS, + locations=['mmedit.visualization']) # manage visualizer backend -VISBACKENDS = Registry('vis_backend', parent=registry.VISBACKENDS) +VISBACKENDS = Registry( + 'vis_backend', + parent=registry.VISBACKENDS, + locations=['mmedit.visualization']) # manage logprocessor -LOG_PROCESSORS = Registry('log_processor', parent=registry.LOG_PROCESSORS) - -# manage optimizer wrapper -OPTIM_WRAPPERS = Registry('optim_wrapper', parent=registry.OPTIM_WRAPPERS) +LOG_PROCESSORS = Registry( + 'log_processor', + parent=registry.LOG_PROCESSORS, + locations=['mmedit.engine.runner']) # manage diffusion_schedulers -DIFFUSION_SCHEDULERS = Registry('diffusion scheduler') +DIFFUSION_SCHEDULERS = Registry( + 'diffusion scheduler', locations=['mmedit.models.editors']) diff --git a/projects/glide/configs/README.md b/projects/glide/configs/README.md index e9081da481..321fec7b9f 100644 --- a/projects/glide/configs/README.md +++ b/projects/glide/configs/README.md @@ -45,10 +45,10 @@ You can run glide as follows: ```python import torch from mmedit.apis import init_model -from mmedit.utils import register_all_modules +from mmengine.registry import init_default_scope from projects.glide.models import * -register_all_modules() +init_default_scope('mmedit') config = 'projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py' ckpt = 'https://download.openmmlab.com/mmediting/glide/glide_laion-64x64-02afff47.pth' diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt index 580af80e5e..992f383c8b 100644 --- a/requirements/mminstall.txt +++ b/requirements/mminstall.txt @@ -1,2 +1,2 @@ mmcv>=2.0.0rc1 -mmengine +mmengine>=0.4.0 diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index b1846af030..0f8685066b 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -3,9 +3,9 @@ import torch from mmengine import Config +from mmengine.registry import init_default_scope from mmedit.registry import MODELS -from mmedit.utils import register_all_modules try: from mmcv.cnn import get_model_complexity_info @@ -39,9 +39,10 @@ def main(): else: raise ValueError('invalid input shape') - register_all_modules() - cfg = Config.fromfile(args.config) + + init_default_scope(cfg.get('default_scope', 'mmedit')) + model = MODELS.build(cfg.model) if torch.cuda.is_available(): model.cuda() diff --git a/tools/model_converters/pytorch2onnx.py b/tools/model_converters/pytorch2onnx.py index f0dabb00ad..14a83d8a4e 100644 --- a/tools/model_converters/pytorch2onnx.py +++ b/tools/model_converters/pytorch2onnx.py @@ -11,11 +11,11 @@ from mmcv.onnx import register_extra_symbolics from mmengine import Config from mmengine.dataset import Compose +from mmengine.registry import init_default_scope from mmengine.runner import load_checkpoint from mmedit.apis import delete_cfg from mmedit.registry import MODELS -from mmedit.utils import register_all_modules def pytorch2onnx(model, @@ -190,6 +190,8 @@ def parse_args(): config = Config.fromfile(args.config) delete_cfg(config, key='init_cfg') + init_default_scope(config.get('default_scope', 'mmedit')) + # ONNX does not support spectral norm if model_type == 'mattor': if hasattr(config.model.backbone.encoder, 'with_spectral_norm'): @@ -197,8 +199,6 @@ def parse_args(): config.model.backbone.decoder.with_spectral_norm = False config.test_cfg.metrics = None - register_all_modules() - # build the model model = MODELS.build(config.model) checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') diff --git a/tools/test.py b/tools/test.py index f353544daa..1a167cafd7 100644 --- a/tools/test.py +++ b/tools/test.py @@ -8,7 +8,7 @@ from mmengine.hooks import Hook from mmengine.runner import Runner -from mmedit.utils import print_colored_log, register_all_modules +from mmedit.utils import print_colored_log # TODO: support fuse_conv_bn, visualization, and format_only @@ -45,10 +45,6 @@ def parse_args(): def main(): args = parse_args() - # register all modules in mmedit into the registries - # do not init the default scope here because it will be init in the runner - register_all_modules(init_default_scope=False) - # load config cfg = Config.fromfile(args.config) cfg.launcher = args.launcher diff --git a/tools/train.py b/tools/train.py index 3da472583a..37d6214ea4 100644 --- a/tools/train.py +++ b/tools/train.py @@ -7,7 +7,7 @@ from mmengine.config import Config, DictAction from mmengine.runner import Runner -from mmedit.utils import print_colored_log, register_all_modules +from mmedit.utils import print_colored_log def parse_args(): @@ -51,10 +51,6 @@ def parse_args(): def main(): args = parse_args() - # register all modules in mmedit into the registries - # do not init the default scope here because it will be init in the runner - register_all_modules(init_default_scope=False) - # load config cfg = Config.fromfile(args.config) cfg.launcher = args.launcher