Skip to content

Commit

Permalink
compat mmpose v0.26 (#518)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon authored May 27, 2022
1 parent 0878b8f commit 32482e7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
19 changes: 12 additions & 7 deletions mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def create_input(self,
Returns:
tuple: (data, img), meta information for the input image and input.
"""
from mmpose.apis.inference import _box2cs
from mmpose.datasets.dataset_info import DatasetInfo
from mmpose.datasets.pipelines import Compose

Expand All @@ -160,17 +159,12 @@ def create_input(self,
image_size = input_shape
else:
image_size = np.array(cfg.data_cfg['image_size'])
for bbox in bboxes:
center, scale = _box2cs(cfg, bbox)

for bbox in bboxes:
# prepare data
data = {
'img':
imgs,
'center':
center,
'scale':
scale,
'bbox_score':
bbox[4] if len(bbox) == 5 else 1,
'bbox_id':
Expand All @@ -190,6 +184,17 @@ def create_input(self,
}
}

# for compatibility of mmpose
try:
# for mmpose<=v0.25.1
from mmpose.apis.inference import _box2cs
center, scale = _box2cs(cfg, bbox)
data['center'] = center
data['scale'] = scale
except ImportError:
# for mmpose>=v0.26.0
data['bbox'] = bbox

data = test_pipeline(data)
batch_data.append(data)

Expand Down
7 changes: 7 additions & 0 deletions tests/test_codebase/test_mmpose/data/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# model settings
import mmpose
from packaging import version

channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
Expand Down Expand Up @@ -47,6 +50,7 @@

test_pipeline = [
dict(type='LoadImageFromFile'),
# dict(type='TopDownGetBboxCenterScale'),
dict(type='TopDownAffine'),
dict(type='ToTensor'),
dict(
Expand All @@ -61,6 +65,9 @@
'flip_pairs'
]),
]
# compatible with mmpose >=v0.26.0
if version.parse(mmpose.__version__) >= version.parse('0.26.0'):
test_pipeline.insert(1, dict(type='TopDownGetBboxCenterScale'))

dataset_info = dict(
dataset_name='coco',
Expand Down
1 change: 0 additions & 1 deletion tests/test_codebase/test_mmpose/test_pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@


def test_create_input():
model_cfg = load_config(model_cfg_path)[0]
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=Backend.ONNXRUNTIME.value),
Expand Down

0 comments on commit 32482e7

Please sign in to comment.