Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support auto import modules from registry. #1731

Merged
merged 3 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/get_started/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,6 @@ MMOCR has different version requirements on MMEngine, MMCV and MMDetection at ea

| MMOCR | MMEngine | MMCV | MMDetection |
| -------------- | --------------------------- | -------------------------- | --------------------------- |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
2 changes: 1 addition & 1 deletion docs/zh_cn/get_started/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,6 @@ docker run --gpus all --shm-size=8g -it -v {实际数据目录}:/mmocr/data mmoc

| MMOCR | MMEngine | MMCV | MMDetection |
| -------------- | --------------------------- | -------------------------- | --------------------------- |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
2 changes: 1 addition & 1 deletion mmocr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
f'Please install mmengine>={mmengine_minimum_version}, ' \
f'<{mmengine_maximum_version}.'

mmdet_minimum_version = '3.0.0rc0'
mmdet_minimum_version = '3.0.0rc5'
mmdet_maximum_version = '3.1.0'
mmdet_version = digit_version(mmdet.__version__)

Expand Down
5 changes: 3 additions & 2 deletions mmocr/apis/inferencers/base_mmocr_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
from torch import Tensor

from mmocr.utils import ConfigType, register_all_modules
from mmocr.utils import ConfigType

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(self,
# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0
register_all_modules()
init_default_scope(scope)
super().__init__(
model=model, weights=weights, device=device, scope=scope)

Expand Down
94 changes: 73 additions & 21 deletions mmocr/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,51 +32,103 @@
from mmengine.registry import Registry

# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
RUNNERS = Registry(
'runner',
parent=MMENGINE_RUNNERS,
# TODO: update the location when mmocr has its own runner
locations=['mmocr.engine'])
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
'runner constructor',
parent=MMENGINE_RUNNER_CONSTRUCTORS,
# TODO: update the location when mmocr has its own runner constructor
locations=['mmocr.engine'])
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
LOOPS = Registry(
'loop',
parent=MMENGINE_LOOPS,
# TODO: update the location when mmocr has its own loop
locations=['mmocr.engine'])
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
HOOKS = Registry(
'hook', parent=MMENGINE_HOOKS, locations=['mmocr.engine.hooks'])

# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
DATASETS = Registry(
'dataset', parent=MMENGINE_DATASETS, locations=['mmocr.datasets'])
DATA_SAMPLERS = Registry(
'data sampler',
parent=MMENGINE_DATA_SAMPLERS,
locations=['mmocr.datasets.samplers'])
TRANSFORMS = Registry(
'transform',
parent=MMENGINE_TRANSFORMS,
locations=['mmocr.datasets.transforms'])

# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmocr.models'])
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
MODEL_WRAPPERS = Registry(
'model wrapper',
parent=MMENGINE_MODEL_WRAPPERS,
locations=['mmocr.models'])
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
'weight initializer',
parent=MMENGINE_WEIGHT_INITIALIZERS,
locations=['mmocr.models'])

# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
OPTIMIZERS = Registry(
'optimizer',
parent=MMENGINE_OPTIMIZERS,
# TODO: update the location when mmocr has its own optimizer
locations=['mmocr.engine'])
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
OPTIM_WRAPPERS = Registry(
'optimizer wrapper',
parent=MMENGINE_OPTIM_WRAPPERS,
# TODO: update the location when mmocr has its own optimizer wrapper
locations=['mmocr.engine'])
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
'optimizer constructor',
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
# TODO: update the location when mmocr has its own optimizer constructor
locations=['mmocr.engine'])
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)

'parameter scheduler',
parent=MMENGINE_PARAM_SCHEDULERS,
# TODO: update the location when mmocr has its own parameter scheduler
locations=['mmocr.engine'])
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
METRICS = Registry(
'metric', parent=MMENGINE_METRICS, locations=['mmocr.evaluation.metrics'])
# manage evaluator
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
EVALUATOR = Registry(
'evaluator',
parent=MMENGINE_EVALUATOR,
locations=['mmocr.evaluation.evaluator'])

# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
TASK_UTILS = Registry(
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmocr.models'])

# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
VISUALIZERS = Registry(
'visualizer',
parent=MMENGINE_VISUALIZERS,
locations=['mmocr.visualization'])
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
VISBACKENDS = Registry(
'visualizer backend',
parent=MMENGINE_VISBACKENDS,
locations=['mmocr.visualization'])

# manage logprocessor
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
LOG_PROCESSORS = Registry(
'logger processor',
parent=MMENGINE_LOG_PROCESSORS,
# TODO: update the location when mmocr has its own log processor
locations=['mmocr.engine'])
6 changes: 3 additions & 3 deletions requirements/mminstall.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mmcv>==2.0.0rc1,<2.1.0
mmdet>=3.0.0rc0,<3.1.0
mmengine>= 0.1.0, <1.0.0
mmcv>==2.0.0rc4,<2.1.0
mmdet>=3.0.0rc5,<3.1.0
mmengine>= 0.5.0, <1.0.0
5 changes: 3 additions & 2 deletions tests/test_datasets/test_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from unittest import TestCase
from unittest.mock import MagicMock

from mmengine.registry import init_default_scope

from mmocr.datasets import ConcatDataset, OCRDataset
from mmocr.registry import TRANSFORMS
from mmocr.utils import register_all_modules


class TestConcatDataset(TestCase):
Expand All @@ -22,7 +23,7 @@ def __call__(self, *args, **kwargs):

def setUp(self):

register_all_modules()
init_default_scope('mmocr')
dataset = OCRDataset

# create dataset_a
Expand Down
5 changes: 3 additions & 2 deletions tests/test_models/test_textdet/test_detectors/test_drrg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.registry import init_default_scope

from mmocr.registry import MODELS
from mmocr.testing.data import create_dummy_textdet_inputs
from mmocr.utils import register_all_modules


class TestDRRG(unittest.TestCase):

def setUp(self):
cfg_path = 'textdet/drrg/drrg_resnet50_fpn-unet_1200e_ctw1500.py'
self.model_cfg = self._get_detector_cfg(cfg_path)
register_all_modules()
cfg = self._get_config_module(cfg_path)
init_default_scope(cfg.get('default_scope', 'mmocr'))
self.model = MODELS.build(self.model_cfg)
self.inputs = create_dummy_textdet_inputs(input_shape=(1, 3, 224, 224))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from mmdet.structures import DetDataSample
from mmdet.testing import demo_mm_inputs
from mmengine.config import Config
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import register_all_modules


class TestMMDetWrapper(unittest.TestCase):

def setUp(self):
register_all_modules()
init_default_scope('mmocr')
model_cfg_fcos = dict(
type='MMDetWrapper',
cfg=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from unittest import TestCase

import torch
from mmengine.registry import init_default_scope

from mmocr.models.textrecog.backbones import ResNet
from mmocr.utils import register_all_modules


class TestResNet(TestCase):

def setUp(self) -> None:
self.img = torch.rand(1, 3, 32, 100)
register_all_modules()
init_default_scope('mmocr')

def test_resnet45_aster(self):
resnet45_aster = ResNet(
Expand Down
5 changes: 2 additions & 3 deletions tools/analysis_tools/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import numpy as np
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer

from mmocr.registry import DATASETS, VISUALIZERS
from mmocr.utils import register_all_modules


# TODO: Support for printing the change in key of results
Expand Down Expand Up @@ -331,8 +331,7 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# register all modules in mmyolo into the registries
register_all_modules()
init_default_scope(cfg.get('default_scope', 'mmocr'))

dataset_cfg, visualizer_cfg = obtain_dataset_cfg(cfg, args.phase,
args.mode, args.task)
Expand Down
5 changes: 2 additions & 3 deletions tools/analysis_tools/get_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
from mmengine import Config
from mmengine.registry import init_default_scope

from mmocr.registry import MODELS
from mmocr.utils import register_all_modules

register_all_modules()


def parse_args():
Expand Down Expand Up @@ -38,6 +36,7 @@ def main():
input_shape = (1, 3, h, w)

cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmocr'))
model = MODELS.build(cfg.model)

flops = FlopCountAnalysis(model, torch.ones(input_shape))
Expand Down
6 changes: 2 additions & 4 deletions tools/analysis_tools/offline_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import mmengine
from mmengine.config import Config, DictAction
from mmengine.evaluator import Evaluator

from mmocr.utils import register_all_modules
from mmengine.registry import init_default_scope


def parse_args():
Expand All @@ -33,10 +32,9 @@ def parse_args():
def main():
args = parse_args()

register_all_modules()

# load config
cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmocr'))
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

Expand Down
2 changes: 0 additions & 2 deletions tools/dataset_converters/prepare_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings

from mmocr.datasets.preparers import DatasetPreparer
from mmocr.utils import register_all_modules


def parse_args():
Expand Down Expand Up @@ -39,7 +38,6 @@ def parse_args():

def main():
args = parse_args()
register_all_modules()
for dataset in args.datasets:
if not osp.isdir(osp.join(args.dataset_zoo_path, dataset)):
warnings.warn(f'{dataset} is not supported yet. Please check '
Expand Down
6 changes: 0 additions & 6 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from mmocr.utils import register_all_modules


def parse_args():
parser = argparse.ArgumentParser(description='Test (and eval) a model')
Expand Down Expand Up @@ -80,10 +78,6 @@ def trigger_visualization_hook(cfg, args):
def main():
args = parse_args()

# register all modules in mmocr into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
Expand Down
6 changes: 0 additions & 6 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from mmocr.utils import register_all_modules


def parse_args():
parser = argparse.ArgumentParser(description='Train a model')
Expand Down Expand Up @@ -54,10 +52,6 @@ def parse_args():
def main():
args = parse_args()

# register all modules in mmdet into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
Expand Down