Skip to content

Commit

Permalink
[Feature] Support auto import modules from registry (#1961)
Browse files Browse the repository at this point in the history
* update mmengine version & registry path

* replace register_all_modules

* refine dekr configs

* move init_default_scope to inference apis

* fix a filename bug in visualization_hook
  • Loading branch information
Ben-Louis authored Feb 13, 2023
1 parent d83c4ba commit 0b7abbc
Show file tree
Hide file tree
Showing 21 changed files with 78 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# codec settings
codec = dict(
type='RootDisplacement',
type='SPR',
input_size=(512, 512),
heatmap_size=(128, 128),
sigma=(4, 2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# codec settings
codec = dict(
type='RootDisplacement',
type='SPR',
input_size=(640, 640),
heatmap_size=(160, 160),
sigma=(4, 2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# codec settings
codec = dict(
type='RootDisplacement',
type='SPR',
input_size=(512, 512),
heatmap_size=(128, 128),
sigma=(4, 2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# codec settings
codec = dict(
type='RootDisplacement',
type='SPR',
input_size=(640, 640),
heatmap_size=(160, 160),
sigma=(4, 2),
Expand Down
19 changes: 5 additions & 14 deletions demo/MMPose_Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -622,18 +622,17 @@
"import mmcv\n",
"from mmcv import imread\n",
"import mmengine\n",
"from mmengine.registry import init_default_scope\n",
"import numpy as np\n",
"\n",
"from mmpose.apis import inference_topdown\n",
"from mmpose.apis import init_model as init_pose_estimator\n",
"from mmpose.evaluation.functional import nms\n",
"from mmpose.registry import VISUALIZERS\n",
"from mmpose.structures import merge_data_samples\n",
"from mmpose.utils import register_all_modules as register_mmpose_modules\n",
"\n",
"try:\n",
" from mmdet.apis import inference_detector, init_detector\n",
" from mmdet.utils import register_all_modules as register_mmdet_modules\n",
" has_mmdet = True\n",
"except (ImportError, ModuleNotFoundError):\n",
" has_mmdet = False\n",
Expand All @@ -656,7 +655,6 @@
"\n",
"\n",
"# build detector\n",
"register_mmdet_modules()\n",
"detector = init_detector(\n",
" det_config,\n",
" det_checkpoint,\n",
Expand All @@ -665,7 +663,6 @@
"\n",
"\n",
"# build pose estimator\n",
"register_mmpose_modules()\n",
"pose_estimator = init_pose_estimator(\n",
" pose_config,\n",
" pose_checkpoint,\n",
Expand Down Expand Up @@ -696,7 +693,7 @@
" \"\"\"Visualize predicted keypoints (and heatmaps) of one image.\"\"\"\n",
"\n",
" # predict bbox\n",
" register_mmdet_modules()\n",
" init_default_scope(detector.cfg.get('default_scope', 'mmdet'))\n",
" detect_result = inference_detector(detector, img_path)\n",
" pred_instance = detect_result.pred_instances.cpu().numpy()\n",
" bboxes = np.concatenate(\n",
Expand All @@ -706,7 +703,6 @@
" bboxes = bboxes[nms(bboxes, 0.3)][:, :4]\n",
"\n",
" # predict keypoints\n",
" register_mmpose_modules()\n",
" pose_results = inference_topdown(pose_estimator, img_path, bboxes)\n",
" data_samples = merge_data_samples(pose_results)\n",
"\n",
Expand Down Expand Up @@ -3476,11 +3472,6 @@
"source": [
"from mmengine.config import Config, DictAction\n",
"from mmengine.runner import Runner\n",
"from mmpose.utils import register_all_modules\n",
"\n",
"# register all modules in mmpose into the registries\n",
"# do not init the default scope here because it will be init in the runner\n",
"register_all_modules(init_default_scope=False)\n",
"\n",
"# set preprocess configs to model\n",
"cfg.model.setdefault('data_preprocessor', cfg.get('preprocess_cfg', {}))\n",
Expand Down Expand Up @@ -3517,7 +3508,7 @@
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.7.13 ('pt19cu113')",
"display_name": "dev2.0",
"language": "python",
"name": "python3"
},
Expand All @@ -3531,11 +3522,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
"version": "3.7.13 (default, Mar 29 2022, 02:18:16) \n[GCC 7.5.0]"
},
"vscode": {
"interpreter": {
"hash": "da739a86cd93f7808d44852bc442711db64702daf7deb8b8d6704b313da8028c"
"hash": "383ba00087b5a9caebf3648b758a31e474cc01be975489b58f119fa4bc17e1f8"
}
},
"widgets": {
Expand Down
4 changes: 0 additions & 4 deletions demo/bottomup_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from mmpose.apis import inference_bottomup, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import split_instances
from mmpose.utils import register_all_modules


def process_one_image(args, img_path, pose_estimator, visualizer,
Expand Down Expand Up @@ -99,9 +98,6 @@ def main():
args.pred_save_path = f'{args.output_root}/results_' \
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

# register all modules in mmpose into the registries
register_all_modules()

# build the model from a config file and a checkpoint file
if args.draw_heatmap:
cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True)))
Expand Down
4 changes: 0 additions & 4 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from mmpose.apis import inference_topdown, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.utils import register_all_modules


def parse_args():
Expand All @@ -28,9 +27,6 @@ def parse_args():
def main():
args = parse_args()

# register all modules in mmpose into the registries
register_all_modules()

# build the model from a config file and a checkpoint file
if args.draw_heatmap:
cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True)))
Expand Down
8 changes: 2 additions & 6 deletions demo/topdown_demo_with_mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
import mmcv
import mmengine
import numpy as np
from mmengine.registry import init_default_scope

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import register_all_modules as register_mmpose_modules

try:
from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules as register_mmdet_modules
has_mmdet = True
except (ImportError, ModuleNotFoundError):
has_mmdet = False
Expand All @@ -29,7 +28,7 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
"""Visualize predicted keypoints (and heatmaps) of one image."""

# predict bbox
register_mmdet_modules()
init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
det_result = inference_detector(detector, img_path)
pred_instance = det_result.pred_instances.cpu().numpy()
bboxes = np.concatenate(
Expand All @@ -39,7 +38,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
bboxes = bboxes[nms(bboxes, args.nms_thr), :4]

# predict keypoints
register_mmpose_modules()
pose_results = inference_topdown(pose_estimator, img_path, bboxes)
data_samples = merge_data_samples(pose_results)

Expand Down Expand Up @@ -146,12 +144,10 @@ def main():
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

# build detector
register_mmdet_modules()
detector = init_detector(
args.det_config, args.det_checkpoint, device=args.device)

# build pose estimator
register_mmpose_modules()
pose_estimator = init_pose_estimator(
args.pose_config,
args.pose_checkpoint,
Expand Down
2 changes: 0 additions & 2 deletions demo/topdown_face_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import register_all_modules as register_mmpose_modules

try:
import face_recognition
Expand Down Expand Up @@ -146,7 +145,6 @@ def main():
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

# build pose estimator
register_mmpose_modules()
pose_estimator = init_pose_estimator(
args.pose_config,
args.pose_checkpoint,
Expand Down
2 changes: 1 addition & 1 deletion mmpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.1.0'
mmengine_minimum_version = '0.4.0'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)

Expand Down
5 changes: 5 additions & 0 deletions mmpose/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.nn as nn
from mmengine.config import Config
from mmengine.dataset import Compose, pseudo_collate
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from PIL import Image

Expand Down Expand Up @@ -93,6 +94,9 @@ def init_model(config: Union[str, Path, Config],
config.model.backbone.init_cfg = None
config.model.train_cfg = None

# register all modules in mmpose into the registries
init_default_scope(config.get('default_scope', 'mmpose'))

model = build_pose_estimator(config.model)
# get dataset_meta in this priority: checkpoint > config > default (COCO)
dataset_meta = None
Expand Down Expand Up @@ -143,6 +147,7 @@ def inference_topdown(model: nn.Module,
``data_sample.pred_instances.keypoints`` and
``data_sample.pred_instances.keypoint_scores``.
"""
init_default_scope(model.cfg.get('default_scope', 'mmpose'))
pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

if bboxes is None:
Expand Down
5 changes: 2 additions & 3 deletions mmpose/apis/webcam/nodes/model_nodes/detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from typing import Dict, List, Optional, Union

import numpy as np
from mmengine.registry import init_default_scope

from ...utils import get_config_path
from ..node import Node
from ..registry import NODES

try:
from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules
has_mmdet = True
except (ImportError, ModuleNotFoundError):
has_mmdet = False
Expand Down Expand Up @@ -90,7 +90,6 @@ def __init__(self,
self.bbox_thr = bbox_thr

# Init model
register_all_modules()
self.model = init_detector(
self.model_config, self.model_checkpoint, device=self.device)

Expand All @@ -110,7 +109,7 @@ def process(self, input_msgs):

img = input_msg.get_image()

register_all_modules()
init_default_scope(self.model.cfg.get('default_scope', 'mmdet'))
preds = inference_detector(self.model, img)
objects = self._post_process(preds)
input_msg.update_objects(objects)
Expand Down
3 changes: 0 additions & 3 deletions mmpose/apis/webcam/nodes/model_nodes/pose_estimator_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np

from mmpose.apis import inference_topdown, init_model
from mmpose.utils import register_all_modules
from ...utils import get_config_path
from ..node import Node
from ..registry import NODES
Expand Down Expand Up @@ -91,7 +90,6 @@ def __init__(self,
self.bbox_thr = bbox_thr

# Init model
register_all_modules()
self.model = init_model(
self.model_config, self.model_checkpoint, device=self.device)

Expand Down Expand Up @@ -119,7 +117,6 @@ def process(self, input_msgs):
if len(objects) > 0:
# Inference pose
bboxes = np.stack([object['bbox'] for object in objects])
register_all_modules()
pose_results = inference_topdown(self.model, img, bboxes)

# Update objects
Expand Down
3 changes: 2 additions & 1 deletion mmpose/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,

out_file = None
if self.out_dir is not None:
out_file_name, postfix = os.path.basename(img_path).split('.')
out_file_name, postfix = os.path.basename(img_path).rsplit(
'.', 1)
index = len([
fname for fname in os.listdir(self.out_dir)
if fname.startswith(out_file_name)
Expand Down
Loading

0 comments on commit 0b7abbc

Please sign in to comment.