Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Avoid scope switching when using mmdet inference interface #2039

Merged
merged 3 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demo/topdown_demo_with_mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import mmcv
import mmengine
import numpy as np
from mmengine.registry import init_default_scope

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import adapt_mmdet_pipeline

try:
from mmdet.apis import inference_detector, init_detector
Expand All @@ -28,7 +28,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
"""Visualize predicted keypoints (and heatmaps) of one image."""

# predict bbox
init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
det_result = inference_detector(detector, img_path)
pred_instance = det_result.pred_instances.cpu().numpy()
bboxes = np.concatenate(
Expand Down Expand Up @@ -147,6 +146,7 @@ def main():
# build detector
detector = init_detector(
args.det_config, args.det_checkpoint, device=args.device)
detector.cfg = adapt_mmdet_pipeline(detector.cfg)

# build pose estimator
pose_estimator = init_pose_estimator(
Expand Down
4 changes: 2 additions & 2 deletions mmpose/apis/webcam/nodes/model_nodes/detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Dict, List, Optional, Union

import numpy as np
from mmengine.registry import init_default_scope

from mmpose.utils import adapt_mmdet_pipeline
from ...utils import get_config_path
from ..node import Node
from ..registry import NODES
Expand Down Expand Up @@ -92,6 +92,7 @@ def __init__(self,
# Init model
self.model = init_detector(
self.model_config, self.model_checkpoint, device=self.device)
self.model.cfg = adapt_mmdet_pipeline(self.model.cfg)

# Register buffers
self.register_input_buffer(input_buffer, 'input', trigger=True)
Expand All @@ -109,7 +110,6 @@ def process(self, input_msgs):

img = input_msg.get_image()

init_default_scope(self.model.cfg.get('default_scope', 'mmdet'))
preds = inference_detector(self.model, img)
objects = self._post_process(preds)
input_msg.update_objects(objects)
Expand Down
4 changes: 3 additions & 1 deletion mmpose/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .camera import SimpleCamera, SimpleCameraTorch
from .collect_env import collect_env
from .config_utils import adapt_mmdet_pipeline
from .logger import get_root_logger
from .setup_env import register_all_modules, setup_multi_processes
from .timer import StopWatch

__all__ = [
'get_root_logger', 'collect_env', 'StopWatch', 'setup_multi_processes',
'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch'
'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch',
'adapt_mmdet_pipeline'
]
26 changes: 26 additions & 0 deletions mmpose/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmpose.utils.typing import ConfigDict


def adapt_mmdet_pipeline(cfg: ConfigDict) -> ConfigDict:
"""Converts pipeline types in MMDetection's test dataloader to use the
'mmdet' namespace.

Args:
cfg (ConfigDict): Configuration dictionary for MMDetection.

Returns:
ConfigDict: Configuration dictionary with updated pipeline types.
"""
# use lazy import to avoid hard dependence on mmdet
from mmdet.datasets import transforms

if 'test_dataloader' not in cfg:
return cfg

pipeline = cfg.test_dataloader.dataset.pipeline
for trans in pipeline:
if trans['type'] in dir(transforms):
trans['type'] = 'mmdet.' + trans['type']

return cfg