Skip to content

Commit

Permalink
Merge 99ea1a1 into abe09d3
Browse files Browse the repository at this point in the history
  • Loading branch information
Tau-J authored Jul 27, 2023
2 parents abe09d3 + 99ea1a1 commit f9617be
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 12 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ We provided a series of tutorials about the basic usage of MMPose for new users:
Results and models are available in the **README.md** of each method's config directory.
A summary can be found in the [Model Zoo](https://mmpose.readthedocs.io/en/latest/model_zoo.html) page.

<details close>
<details open>
<summary><b>Supported algorithms:</b></summary>

- [x] [DeepPose](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#deeppose-cvpr-2014) (CVPR'2014)
Expand All @@ -240,7 +240,7 @@ A summary can be found in the [Model Zoo](https://mmpose.readthedocs.io/en/lates

</details>

<details close>
<details open>
<summary><b>Supported techniques:</b></summary>

- [x] [FPN](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/techniques.html#fpn-cvpr-2017) (CVPR'2017)
Expand All @@ -255,7 +255,7 @@ A summary can be found in the [Model Zoo](https://mmpose.readthedocs.io/en/lates

</details>

<details close>
<details open>
<summary><b>Supported datasets:</b></summary>

- [x] [AFLW](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#aflw-iccvw-2011) \[[homepage](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/)\] (ICCVW'2011)
Expand Down Expand Up @@ -294,7 +294,7 @@ A summary can be found in the [Model Zoo](https://mmpose.readthedocs.io/en/lates

</details>

<details close>
<details open>
<summary><b>Supported backbones:</b></summary>

- [x] [AlexNet](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/backbones.html#alexnet-neurips-2012) (NeurIPS'2012)
Expand Down
8 changes: 4 additions & 4 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ MMPose v1.0.0 是一个重大更新,包括了大量的 API 和配置文件的
各个模型的结果和设置都可以在对应的 config(配置)目录下的 **README.md** 中查看。
整体的概况也可也在 [模型库](https://mmpose.readthedocs.io/zh_CN/latest/model_zoo.html) 页面中查看。

<details close>
<details open>
<summary><b>支持的算法</b></summary>

- [x] [DeepPose](https://mmpose.readthedocs.io/zh_CN/latest/model_zoo_papers/algorithms.html#deeppose-cvpr-2014) (CVPR'2014)
Expand All @@ -238,7 +238,7 @@ MMPose v1.0.0 是一个重大更新,包括了大量的 API 和配置文件的

</details>

<details close>
<details open>
<summary><b>支持的技术</b></summary>

- [x] [FPN](https://mmpose.readthedocs.io/zh_CN/latest/model_zoo_papers/techniques.html#fpn-cvpr-2017) (CVPR'2017)
Expand All @@ -253,7 +253,7 @@ MMPose v1.0.0 是一个重大更新,包括了大量的 API 和配置文件的

</details>

<details close>
<details open>
<summary><b>支持的数据集</b></summary>

- [x] [AFLW](https://mmpose.readthedocs.io/zh_CN/latest/model_zoo_papers/datasets.html#aflw-iccvw-2011) \[[主页](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/)\] (ICCVW'2011)
Expand Down Expand Up @@ -292,7 +292,7 @@ MMPose v1.0.0 是一个重大更新,包括了大量的 API 和配置文件的

</details>

<details close>
<details open>
<summary><b>支持的骨干网络</b></summary>

- [x] [AlexNet](https://mmpose.readthedocs.io/zh_CN/latest/model_zoo_papers/backbones.html#alexnet-neurips-2012) (NeurIPS'2012)
Expand Down
48 changes: 46 additions & 2 deletions demo/body3d_pose_lifter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,45 @@ def process_one_image(args, detector, frame, frame_idx, pose_estimator,
visualize_frame, visualizer):
"""Visualize detected and predicted keypoints of one image.
Pipeline of this function:
frame
|
V
+-----------------+
| detector |
+-----------------+
| det_result
V
+-----------------+
| pose_estimator |
+-----------------+
| pose_est_results
V
+--------------------------------------------+
| convert 2d kpts into pose-lifting format |
+--------------------------------------------+
| pose_est_results_list
V
+-----------------------+
| extract_pose_sequence |
+-----------------------+
| pose_seq_2d
V
+-------------+
| pose_lifter |
+-------------+
| pose_lift_results
V
+-----------------+
| post-processing |
+-----------------+
| pred_3d_data_samples
V
+------------+
| visualizer |
+------------+
Args:
args (Argument): Custom command-line arguments.
detector (mmdet.BaseDetector): The mmdet detector.
Expand Down Expand Up @@ -170,10 +209,13 @@ def process_one_image(args, detector, frame, frame_idx, pose_estimator,
"""
pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset

# First stage: conduct 2D pose detection in a Topdown manner
# use detector to obtain person bounding boxes
det_result = inference_detector(detector, frame)
pred_instance = det_result.pred_instances.cpu().numpy()

# First stage: 2D pose detection
# filter out the person instances with category and bbox threshold
# e.g. 0 for person in COCO
bboxes = pred_instance.bboxes
bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
pred_instance.scores > args.bbox_thr)]
Expand All @@ -190,6 +232,8 @@ def process_one_image(args, detector, frame, frame_idx, pose_estimator,
pose_det_dataset = pose_estimator.cfg.test_dataloader.dataset
pose_est_results_converted = []

# convert 2d pose estimation results into the format for pose-lifting
# such as changing the keypoint order, flipping the keypoint, etc.
for i, data_sample in enumerate(pose_est_results):
pred_instances = data_sample.pred_instances.cpu().numpy()
keypoints = pred_instances.keypoints
Expand Down Expand Up @@ -256,7 +300,7 @@ def process_one_image(args, detector, frame, frame_idx, pose_estimator,
seq_len=pose_lift_dataset.get('seq_len', 1),
step=pose_lift_dataset.get('seq_step', 1))

# 2D-to-3D pose lifting
# conduct 2D-to-3D pose lifting
norm_pose_2d = not args.disable_norm_pose_2d
pose_lift_results = inference_pose_lifter_model(
pose_lifter,
Expand Down
61 changes: 60 additions & 1 deletion docs/en/advanced_guides/implement_new_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Finally, please remember to import your new prediction head in `[__init__.py](ht

### Head with Keypoints Visibility Prediction

Many models predict keypoint visibility based on confidence in coordinate predictions. However, this approach is suboptimal. Our [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) wrapper enables heads to directly predict keypoint visibility from ground truth training data, improving reliability. To add visibility prediction, wrap your head module with VisPredictHead in the config file.
Many models predict keypoint visibility based on confidence in coordinate predictions. However, this approach is suboptimal. Our [VisPredictHead](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) wrapper enables heads to directly predict keypoint visibility from ground truth training data, improving reliability. To add visibility prediction, wrap your head module with VisPredictHead in the config file.

```python
model=dict(
Expand All @@ -103,3 +103,62 @@ model=dict(
...
)
```

To implement such a head module wrapper, we only need to inherit [BaseHead](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/base_head.py), then pass the pose head configuration in `__init__()` and instantiate it through `MODELS.build()`. As shown below:

```python
@MODELS.register_module()
class VisPredictHead(BaseHead):
"""VisPredictHead must be used together with other heads. It can predict
keypoints coordinates of and their visibility simultaneously. In the
current version, it only supports top-down approaches.
Args:
pose_cfg (Config): Config to construct keypoints prediction head
loss (Config): Config for visibility loss. Defaults to use
:class:`BCELoss`
use_sigmoid (bool): Whether to use sigmoid activation function
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
"""

def __init__(self,
pose_cfg: ConfigType,
loss: ConfigType = dict(
type='BCELoss', use_target_weight=False,
use_sigmoid=True),
init_cfg: OptConfigType = None):

if init_cfg is None:
init_cfg = self.default_init_cfg

super().__init__(init_cfg)

self.in_channels = pose_cfg['in_channels']
if pose_cfg.get('num_joints', None) is not None:
self.out_channels = pose_cfg['num_joints']
elif pose_cfg.get('out_channels', None) is not None:
self.out_channels = pose_cfg['out_channels']
else:
raise ValueError('VisPredictHead requires \'num_joints\' or'
' \'out_channels\' in the pose_cfg.')

self.loss_module = MODELS.build(loss)

self.pose_head = MODELS.build(pose_cfg)
self.pose_cfg = pose_cfg

self.use_sigmoid = loss.get('use_sigmoid', False)

modules = [
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(self.in_channels, self.out_channels)
]
if self.use_sigmoid:
modules.append(nn.Sigmoid())

self.vis_head = nn.Sequential(*modules)
```

Then you can implement other parts of the code as a normal head.
8 changes: 8 additions & 0 deletions docs/en/guide_to_framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -684,3 +684,11 @@ def loss(self,

return losses
```

```{note}
If you wish to learn more about the implementation of Model, like:
- Head with Keypoints Visibility Prediction
- Pose Lifting Models
please refer to [Advanced Guides - Implement New Model](./advanced_guides/implement_new_models.md) for more details.
```
30 changes: 30 additions & 0 deletions docs/src/papers/algorithms/motionbert.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# MotionBERT: Unified Pretraining for Human Motion Analysis

<!-- [BACKBONE] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2210.06551">MotionBERT (ICCV'2023)</a></summary>

```bibtex
@misc{Zhu_Ma_Liu_Liu_Wu_Wang_2022,
title={Learning Human Motion Representations: A Unified Perspective},
author={Zhu, Wentao and Ma, Xiaoxuan and Liu, Zhaoyang and Liu, Libin and Wu, Wayne and Wang, Yizhou},
year={2022},
month={Oct},
language={en-US}
}
```

</details>

## Abstract

<!-- [ABSTRACT] -->

We present MotionBERT, a unified pretraining framework, to tackle different sub-tasks of human motion analysis including 3D pose estimation, skeleton-based action recognition, and mesh recovery. The proposed framework is capable of utilizing all kinds of human motion data resources, including motion capture data and in-the-wild videos. During pretraining, the pretext task requires the motion encoder to recover the underlying 3D motion from noisy partial 2D observations. The pretrained motion representation thus acquires geometric, kinematic, and physical knowledge about human motion and therefore can be easily transferred to multiple downstream tasks. We implement the motion encoder with a novel Dual-stream Spatio-temporal Transformer (DSTformer) neural network. It could capture long-range spatio-temporal relationships among the skeletal joints comprehensively and adaptively, exemplified by the lowest 3D pose estimation error so far when trained from scratch. More importantly, the proposed framework achieves state-of-the-art performance on all three downstream tasks by simply finetuning the pretrained motion encoder with 1-2 linear layers, which demonstrates the versatility of the learned motion representations.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/open-mmlab/mmpose/assets/13503330/877d47ee-b821-476c-a805-f39ca656913c">
</div>
59 changes: 58 additions & 1 deletion docs/zh_cn/advanced_guides/implement_new_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class YourNewHead(BaseHead):

### 关键点可见性预测头部

许多模型都是通过对关键点坐标预测的置信度来判断关键点的可见性的。然而,这种解决方案并非最优。我们提供了一个叫做 [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) 的头部模块包装器,使得头部模块能够直接预测关键点的可见性。这个包装器是用训练数据中关键点可见性真值来训练的。因此,其预测会更加可靠。用户可以通过修改配置文件来对自己的头部模块加上这个包装器。下面是一个例子:
许多模型都是通过对关键点坐标预测的置信度来判断关键点的可见性的。然而,这种解决方案并非最优。我们提供了一个叫做 [VisPredictHead](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) 的头部模块包装器,使得头部模块能够直接预测关键点的可见性。这个包装器是用训练数据中关键点可见性真值来训练的。因此,其预测会更加可靠。用户可以通过修改配置文件来对自己的头部模块加上这个包装器。下面是一个例子:

```python
model=dict(
Expand All @@ -102,3 +102,60 @@ model=dict(
...
)
```

要实现这样一个预测头部模块包装器,我们只需要像定义正常的预测头部一样,继承 [BaseHead](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/base_head.py),然后在 `__init__()` 中传入关键点定位的头部配置,并通过 `MODELS.build()` 进行实例化。如下所示:

```python
@MODELS.register_module()
class VisPredictHead(BaseHead):
"""VisPredictHead must be used together with other heads. It can predict
keypoints coordinates of and their visibility simultaneously. In the
current version, it only supports top-down approaches.
Args:
pose_cfg (Config): Config to construct keypoints prediction head
loss (Config): Config for visibility loss. Defaults to use
:class:`BCELoss`
use_sigmoid (bool): Whether to use sigmoid activation function
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
"""

def __init__(self,
pose_cfg: ConfigType,
loss: ConfigType = dict(
type='BCELoss', use_target_weight=False,
use_sigmoid=True),
init_cfg: OptConfigType = None):

if init_cfg is None:
init_cfg = self.default_init_cfg

super().__init__(init_cfg)

self.in_channels = pose_cfg['in_channels']
if pose_cfg.get('num_joints', None) is not None:
self.out_channels = pose_cfg['num_joints']
elif pose_cfg.get('out_channels', None) is not None:
self.out_channels = pose_cfg['out_channels']
else:
raise ValueError('VisPredictHead requires \'num_joints\' or'
' \'out_channels\' in the pose_cfg.')

self.loss_module = MODELS.build(loss)

self.pose_head = MODELS.build(pose_cfg)
self.pose_cfg = pose_cfg

self.use_sigmoid = loss.get('use_sigmoid', False)

modules = [
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(self.in_channels, self.out_channels)
]
if self.use_sigmoid:
modules.append(nn.Sigmoid())

self.vis_head = nn.Sequential(*modules)
```
8 changes: 8 additions & 0 deletions docs/zh_cn/guide_to_framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,11 @@ def loss(self,

return losses
```

```{note}
如果你想了解更多模型实现的内容,如:
- 支持关键点可见性预测的头部
- 2D-to-3D 模型实现
请前往 [【进阶教程 - 实现新模型】](./advanced_guides/implement_new_models.md)
```
15 changes: 15 additions & 0 deletions projects/rtmpose/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,21 @@ Feel free to join our community group for more help:

</details>

<details open>
<summary><b>Human-Art</b></summary>

- Details see [Human-Art](https://github.com/IDEA-Research/HumanArt)
- <img src="https://github.com/open-mmlab/mmpose/assets/13503330/685bc610-dd9e-4e6f-9c41-dbc8220584f4" height="300px">

| Config | Input Size | AP<sup><br>(Human-Art GT) | Params<sup><br>(M) | FLOPS<sup><br>(G) | ORT-Latency<sup><br>(ms)<sup><br>(i7-11700) | TRT-FP16-Latency<sup><br>(ms)<sup><br>(GTX 1660Ti) | ncnn-FP16-Latency<sup><br>(ms)<sup><br>(Snapdragon 865) | Download |
| :-----------------------------------------------------------------------------: | :--------: | :-----------------------: | :----------------: | :---------------: | :-----------------------------------------: | :------------------------------------------------: | :-----------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------: |
| [RTMPose-t\*](./rtmpose/body_2d_keypoint/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 65.5 | 3.34 | 0.36 | 3.20 | 1.06 | 9.02 | [Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth) |
| [RTMPose-s\*](./rtmpose/body_2d_keypoint/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 69.8 | 5.47 | 0.68 | 4.48 | 1.39 | 13.89 | [Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth) |
| [RTMPose-m\*](./rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 72.8 | 13.59 | 1.93 | 11.06 | 2.29 | 26.44 | [Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth) |
| [RTMPose-l\*](./rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-256x192.py) | 256x192 | 75.3 | 27.66 | 4.16 | 18.85 | 3.46 | 45.37 | [Model](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth) |

</details>

#### 26 Keypoints

- Keypoints are defined as [Halpe26](https://github.com/Fang-Haoshu/Halpe-FullBody/). For details please refer to the [meta info](/configs/_base_/datasets/halpe26.py).
Expand Down
Loading

0 comments on commit f9617be

Please sign in to comment.