[Feature] Support hand3d inferencer (#2729)
Ben-Louis authored Sep 27, 2023
1 parent 18842d6 commit efe0c8d
Showing 13 changed files with 662 additions and 294 deletions.
@@ -8,6 +8,7 @@ Collections:
Models:
- Config: configs/hand_3d_keypoint/internet/interhand3d/internet_res50_4xb16-20e_interhand3d-256x256.py
In Collection: InterNet
Alias: hand3d
Metadata:
Architecture: &id001
- InterNet
14 changes: 14 additions & 0 deletions demo/docs/en/3d_hand_demo.md
@@ -36,3 +36,17 @@ python demo/hand3d_internet_demo.py \
--save-predictions \
--output-root vis_results
```

### 3D Hand Pose Estimation with Inferencer

The Inferencer provides a convenient interface for inference, allowing models to be specified by alias instead of by configuration file and checkpoint path. It supports various input formats, including image paths, video paths, image folder paths, and webcams. Below is an example command:

```shell
python demo/inferencer_demo.py tests/data/interhand2.6m/image29590.jpg --pose3d hand3d --vis-out-dir vis_results/hand3d
```

This command runs inference on the image and saves the visualization results to the `vis_results/hand3d` directory.

<img src="https://github.com/open-mmlab/mmpose/assets/26127467/29218285-aff6-455f-9763-39e8539eae61" alt="Image 1" height="300"/>

In addition, the Inferencer supports saving predicted poses. For more information, please refer to the [inferencer document](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html#inferencer-a-unified-inference-interface).
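
The same inference can also be run from the Python API rather than the CLI. Below is a minimal sketch, assuming the `MMPoseInferencer` interface described in the inferencer document linked above; the input and output paths mirror the CLI example.

```python
# Minimal sketch: hand3d inference via the Python API.
from mmpose.apis import MMPoseInferencer

inferencer = MMPoseInferencer(pose3d='hand3d')

# The inferencer returns a generator; each item holds the predictions
# (and, optionally, the rendered visualization) for one input.
result_generator = inferencer(
    'tests/data/interhand2.6m/image29590.jpg',
    vis_out_dir='vis_results/hand3d')
result = next(result_generator)
print(result['predictions'])
```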
5 changes: 4 additions & 1 deletion demo/hand3d_internet_demo.py
@@ -268,7 +268,10 @@ def main():
instance_info=pred_instances_list),
f,
indent='\t')
print(f'predictions have been saved at {args.pred_save_path}')
print_log(
f'predictions have been saved at {args.pred_save_path}',
logger='current',
level=logging.INFO)

if output_file is not None:
input_type = input_type.replace('webcam', 'video')
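The hunk above swaps a bare `print` for MMEngine's `print_log`. A standalone sketch of that call is shown below; the imports are included here for self-containment (whether the demo script gains them in this commit is not visible in the hunk), and the message path is illustrative.

```python
# Standalone sketch of the logging call used above; the file path in the
# message is a placeholder.
import logging

from mmengine.logging import print_log

print_log(
    'predictions have been saved at vis_results/results_sample.json',
    logger='current',  # route through the currently active MMEngine logger
    level=logging.INFO)
```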
7 changes: 4 additions & 3 deletions docs/en/user_guides/inference.md
@@ -279,9 +279,10 @@ The MMPose library has predefined aliases for several frequently used models. Th

The following table lists the available 3D model aliases and their corresponding configuration names:

| Alias | Configuration Name | Task | 3D Pose Estimator | 2D Pose Estimator | Detector |
| ------- | --------------------------------- | ------------------------ | ----------------- | ----------------- | -------- |
| human3d | vid_pl_motionbert_8xb32-120e_h36m | Human 3D pose estimation | MotionBert | RTMPose-m | RTMDet-m |
| Alias | Configuration Name | Task | 3D Pose Estimator | 2D Pose Estimator | Detector |
| ------- | -------------------------------------------- | ------------------------ | ----------------- | ----------------- | ----------- |
| human3d | vid_pl_motionbert_8xb32-120e_h36m | Human 3D pose estimation | MotionBert | RTMPose-m | RTMDet-m |
| hand3d | internet_res50_4xb16-20e_interhand3d-256x256 | Hand 3D pose estimation | InterNet | - | whole image |

In addition, users can utilize the CLI tool to display all available aliases with the following command:

7 changes: 4 additions & 3 deletions docs/zh_cn/user_guides/inference.md
@@ -267,9 +267,10 @@ MMPose provides a set of predefined aliases for commonly used models. When initializing [MMPoseIn

The following table lists the available 3D pose estimation model aliases and their corresponding configuration files:

| Alias   | Configuration File Name           | Task                     | 3D Pose Estimator | 2D Pose Estimator | Detector |
| ------- | --------------------------------- | ------------------------ | ----------------- | ----------------- | -------- |
| human3d | vid_pl_motionbert_8xb32-120e_h36m | 3D human pose estimation | MotionBert        | RTMPose-m         | RTMDet-m |
| Alias   | Configuration File Name                      | Task                       | 3D Pose Estimator | 2D Pose Estimator | Detector    |
| ------- | -------------------------------------------- | -------------------------- | ----------------- | ----------------- | ----------- |
| human3d | vid_pl_motionbert_8xb32-120e_h36m            | 3D human pose estimation   | MotionBert        | RTMPose-m         | RTMDet-m    |
| hand3d  | internet_res50_4xb16-20e_interhand3d-256x256 | 3D hand keypoint detection | InterNet          | -                 | whole image |

In addition, users can display all available aliases with the command-line tool, using the following command:

3 changes: 2 additions & 1 deletion mmpose/apis/inferencers/__init__.py
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hand3d_inferencer import Hand3DInferencer
from .mmpose_inferencer import MMPoseInferencer
from .pose2d_inferencer import Pose2DInferencer
from .pose3d_inferencer import Pose3DInferencer
from .utils import get_model_aliases

__all__ = [
'Pose2DInferencer', 'MMPoseInferencer', 'get_model_aliases',
'Pose3DInferencer'
'Pose3DInferencer', 'Hand3DInferencer'
]
217 changes: 183 additions & 34 deletions mmpose/apis/inferencers/base_mmpose_inferencer.py
@@ -4,7 +4,7 @@
import os
from collections import defaultdict
from typing import (Callable, Dict, Generator, Iterable, List, Optional,
Sequence, Union)
Sequence, Tuple, Union)

import cv2
import mmcv
@@ -15,15 +15,23 @@
from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
list_dir_or_file)
from mmengine.infer.infer import BaseInferencer
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.logging import print_log
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from mmengine.structures import InstanceData
from mmengine.utils import mkdir_or_exist

from mmpose.apis.inference import dataset_meta_from_config
from mmpose.registry import DATASETS
from mmpose.structures import PoseDataSample, split_instances
from .utils import default_det_models

try:
from mmdet.apis.det_inferencer import DetInferencer
has_mmdet = True
except (ImportError, ModuleNotFoundError):
has_mmdet = False

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
@@ -45,6 +53,44 @@ class BaseMMPoseInferencer(BaseInferencer):
}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

def _init_detector(
self,
det_model: Optional[Union[ModelType, str]] = None,
det_weights: Optional[str] = None,
det_cat_ids: Optional[Union[int, Tuple]] = None,
device: Optional[str] = None,
):
object_type = DATASETS.get(self.cfg.dataset_type).__module__.split(
'datasets.')[-1].split('.')[0].lower()

if det_model in ('whole_image', 'whole-image') or \
(det_model is None and
object_type not in default_det_models):
self.detector = None

else:
det_scope = 'mmdet'
if det_model is None:
det_info = default_det_models[object_type]
det_model, det_weights, det_cat_ids = det_info[
'model'], det_info['weights'], det_info['cat_ids']
elif os.path.exists(det_model):
det_cfg = Config.fromfile(det_model)
det_scope = det_cfg.default_scope

if has_mmdet:
self.detector = DetInferencer(
det_model, det_weights, device=device, scope=det_scope)
else:
raise RuntimeError(
'MMDetection (v3.0.0 or above) is required to build '
'inferencers for top-down pose estimation models.')

if isinstance(det_cat_ids, (tuple, list)):
self.det_cat_ids = det_cat_ids
else:
self.det_cat_ids = (det_cat_ids, )
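
When `det_model` resolves to whole-image mode, `self.detector` is set to `None` and the concrete inferencer's preprocessing typically falls back to using the full frame as the only bounding box. A hedged sketch of that fallback is shown below; the helper name is hypothetical and stands in for logic that lives elsewhere in the inferencer.

```python
import numpy as np


def whole_image_bbox(img: np.ndarray) -> np.ndarray:
    """Hypothetical helper: with no detector configured, use the full
    frame as a single bounding box in xyxy format."""
    h, w = img.shape[:2]
    return np.array([[0, 0, w, h]], dtype=np.float32)
```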

def _load_weights_to_model(self, model: nn.Module,
checkpoint: Optional[dict],
cfg: Optional[ConfigType]) -> None:
@@ -266,6 +312,101 @@ def preprocess(self,
# only supports inference with batch size 1
yield self.collate_fn(data_infos), [input]

def __call__(
self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
out_dir: Optional[str] = None,
**kwargs,
) -> dict:
"""Call the inferencer.
Args:
inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
out_dir (str, optional): directory to save visualization
results and predictions. Will be overridden if vis_out_dir or
pred_out_dir are given. Defaults to None
**kwargs: Key words arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``,
``visualize_kwargs`` and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
if out_dir is not None:
if 'vis_out_dir' not in kwargs:
kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
if 'pred_out_dir' not in kwargs:
kwargs['pred_out_dir'] = f'{out_dir}/predictions'

(
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)

self.update_model_visualizer_settings(**kwargs)

# preprocessing
if isinstance(inputs, str) and inputs.startswith('webcam'):
inputs = self._get_webcam_inputs(inputs)
batch_size = 1
if not visualize_kwargs.get('show', False):
print_log(
'The display mode is closed when using webcam '
'input. It will be turned on automatically.',
logger='current',
level=logging.WARNING)
visualize_kwargs['show'] = True
else:
inputs = self._inputs_to_list(inputs)

# check the compatibility between inputs/outputs
if not self._video_input and len(inputs) > 0:
vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
if vis_out_dir is not None:
_, file_extension = os.path.splitext(vis_out_dir)
assert not file_extension, f'the argument `vis_out_dir` ' \
f'should be a folder while the input contains multiple ' \
f'images, but got {vis_out_dir}'

if 'bbox_thr' in self.forward_kwargs:
forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
inputs = self.preprocess(
inputs, batch_size=batch_size, **preprocess_kwargs)

preds = []

for proc_inputs, ori_inputs in inputs:
preds = self.forward(proc_inputs, **forward_kwargs)

visualization = self.visualize(ori_inputs, preds,
**visualize_kwargs)
results = self.postprocess(
preds,
visualization,
return_datasamples=return_datasamples,
**postprocess_kwargs)
yield results

if self._video_input:
self._finalize_video_processing(
postprocess_kwargs.get('pred_out_dir', ''))

# In 3D Inferencers, some intermediate results (e.g. 2d keypoints)
# will be temporarily stored in `self._buffer`. It's essential to
# clear this information to prevent any interference with subsequent
# inferences.
if hasattr(self, '_buffer'):
self._buffer.clear()
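
As described in the docstring above, passing `out_dir` is shorthand for setting both output directories. A minimal usage sketch follows, assuming the `MMPoseInferencer` wrapper exposes this `__call__` behaviour unchanged.

```python
# `out_dir='outputs'` expands to vis_out_dir='outputs/visualizations'
# and pred_out_dir='outputs/predictions' unless either is given explicitly.
from mmpose.apis import MMPoseInferencer

inferencer = MMPoseInferencer(pose3d='hand3d')
for result in inferencer(
        'tests/data/interhand2.6m/image29590.jpg', out_dir='outputs'):
    print(result['predictions'])
```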

def visualize(self,
inputs: list,
preds: List[PoseDataSample],
@@ -349,44 +490,52 @@ def visualize(self,
results.append(visualization)

if vis_out_dir:
out_img = mmcv.rgb2bgr(visualization)
_, file_extension = os.path.splitext(vis_out_dir)
if file_extension:
dir_name = os.path.dirname(vis_out_dir)
file_name = os.path.basename(vis_out_dir)
else:
dir_name = vis_out_dir
file_name = None
mkdir_or_exist(dir_name)

if self._video_input:

if self.video_info['writer'] is None:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
if file_name is None:
file_name = os.path.basename(
self.video_info['name'])
out_file = join_path(dir_name, file_name)
self.video_info['output_file'] = out_file
self.video_info['writer'] = cv2.VideoWriter(
out_file, fourcc, self.video_info['fps'],
(visualization.shape[1], visualization.shape[0]))
self.video_info['writer'].write(out_img)

else:
file_name = file_name if file_name else img_name
out_file = join_path(dir_name, file_name)
mmcv.imwrite(out_img, out_file)
print_log(
f'the output image has been saved at {out_file}',
logger='current',
level=logging.INFO)
self.save_visualization(
visualization,
vis_out_dir,
img_name=img_name,
)

if return_vis:
return results
else:
return []

def save_visualization(self, visualization, vis_out_dir, img_name=None):
out_img = mmcv.rgb2bgr(visualization)
_, file_extension = os.path.splitext(vis_out_dir)
if file_extension:
dir_name = os.path.dirname(vis_out_dir)
file_name = os.path.basename(vis_out_dir)
else:
dir_name = vis_out_dir
file_name = None
mkdir_or_exist(dir_name)

if self._video_input:

if self.video_info['writer'] is None:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
if file_name is None:
file_name = os.path.basename(self.video_info['name'])
out_file = join_path(dir_name, file_name)
self.video_info['output_file'] = out_file
self.video_info['writer'] = cv2.VideoWriter(
out_file, fourcc, self.video_info['fps'],
(visualization.shape[1], visualization.shape[0]))
self.video_info['writer'].write(out_img)

else:
if file_name is None:
file_name = img_name if img_name else 'visualization.jpg'

out_file = join_path(dir_name, file_name)
mmcv.imwrite(out_img, out_file)
print_log(
f'the output image has been saved at {out_file}',
logger='current',
level=logging.INFO)
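
The path handling above treats `vis_out_dir` as an explicit output file when it carries an extension and as a directory otherwise. A standalone sketch of that branch is given below; the helper is hypothetical and simply mirrors the method above.

```python
import os


def split_vis_out_dir(vis_out_dir: str):
    """Hypothetical helper mirroring save_visualization's path handling:
    a path with an extension is an explicit output file, otherwise it is
    a directory and the file name is chosen later."""
    _, ext = os.path.splitext(vis_out_dir)
    if ext:
        return os.path.dirname(vis_out_dir), os.path.basename(vis_out_dir)
    return vis_out_dir, None


print(split_vis_out_dir('vis_results/hand3d'))      # ('vis_results/hand3d', None)
print(split_vis_out_dir('vis_results/output.mp4'))  # ('vis_results', 'output.mp4')
```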

def postprocess(
self,
preds: List[PoseDataSample],