Skip to content

Commit

Permalink
add pose tracking demo (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
luminxu authored Dec 31, 2020
1 parent 964bafd commit 5382dff
Show file tree
Hide file tree
Showing 6 changed files with 488 additions and 1 deletion.
38 changes: 38 additions & 0 deletions demo/2d_pose_tracking_demo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
## 2D Pose Tracking Demo

### 2D Top-Down Video Human Pose Tracking Demo

We provide a video demo to illustrate the pose tracking results.

Assume that you have already installed [mmdet](https://github.com/open-mmlab/mmdetection).

```shell
python demo/top_down_pose_tracking_demo_with_mmdet.py \
${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--video-path ${VIDEO_FILE} \
--out-video-root ${OUTPUT_VIDEO_ROOT} \
[--show --device ${GPU_ID}] \
[--bbox-thr ${BBOX_SCORE_THR} --kpt-thr ${KPT_SCORE_THR} --iou-thr ${IOU_SCORE_THR}]
```

Examples:

```shell
python demo/top_down_pose_tracking_demo_with_mmdet.py \
demo/mmdetection_cfg/faster_rcnn_r50_fpn_1x_coco.py \
http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
configs/top_down/resnet/coco/res50_coco_256x192.py \
https://download.openmmlab.com/mmpose/top_down/resnet/res50_coco_256x192-ec54d7f3_20200709.pth \
--video-path demo/demo_video.mp4 \
--out-video-root vis_results
```

### Speed Up Inference

Some tips to speed up MMPose inference:

For top-down 2D human pose models, try to edit the config file. For example,

1. set `flip_test=False` in [topdown-res50](/configs/top_down/resnet/coco/res50_coco_256x192.py#L51).
2. set `unbiased_decoding=False` in [topdown-res50](/configs/top_down/resnet/coco/res50_coco_256x192.py#L54).
1 change: 1 addition & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ This page provides tutorials about running demos.

- [2D human pose demo](2d_human_pose_demo.md)
- [2D hand demo](2d_hand_demo.md)
- [2D pose tracking demo](2d_pose_tracking_demo.md)

<!-- TOC -->
150 changes: 150 additions & 0 deletions demo/top_down_pose_tracking_demo_with_mmdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import os
from argparse import ArgumentParser

import cv2
from mmdet.apis import inference_detector, init_detector

from mmpose.apis import (get_track_id, inference_top_down_pose_model,
init_pose_model, vis_pose_tracking_result)


def process_mmdet_results(mmdet_results, cat_id=0):
    """Select the bounding boxes of a single category from mmdet output.

    :param mmdet_results: raw detector output; when it is a tuple only its
        first element is used, otherwise it is used as is
    :param cat_id: category id (default: 0 for human)
    :return: the detected bounding boxes stored at index ``cat_id``
    """
    per_category = (
        mmdet_results[0] if isinstance(mmdet_results, tuple) else mmdet_results)
    return per_category[cat_id]


def main():
    """Run the top-down pose tracking demo on a video.

    For every frame read from ``--video-path`` the script runs the mmdet
    detector, keeps the boxes of category 0, runs the top-down pose model on
    them and calls ``get_track_id`` with the previous frame's results to
    attach track ids. The visualized frame is shown in a window (``--show``)
    and/or appended to ``<out-video-root>/vis_<video file name>``.

    Exits with an argparse usage error when ``--video-path`` is missing or
    when neither ``--show`` nor ``--out-video-root`` is given.

    :raises IOError: if the input video cannot be opened
    """
    parser = ArgumentParser()
    parser.add_argument('det_config', help='Config file for detection')
    parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
    parser.add_argument('pose_config', help='Config file for pose')
    parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
    parser.add_argument('--video-path', type=str, help='Video path')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show visualizations.')
    parser.add_argument(
        '--out-video-root',
        default='',
        help='Root of the output video file. '
        'Default not saving the visualization video.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.3,
        help='Bounding box score threshold')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--iou-thr', type=float, default=0.3, help='IoU score threshold')

    args = parser.parse_args()

    # Validate via parser.error rather than assert: asserts disappear under
    # `python -O`, and a usage message is friendlier than a traceback.
    # The positional config/checkpoint arguments are enforced by argparse.
    if not args.video_path:
        parser.error('--video-path is required')
    save_out_video = args.out_video_root != ''
    if not (args.show or save_out_video):
        parser.error('nothing to do: pass --show and/or --out-video-root')

    det_model = init_detector(
        args.det_config, args.det_checkpoint, device=args.device)
    # build the pose model from a config file and a checkpoint file
    pose_model = init_pose_model(
        args.pose_config, args.pose_checkpoint, device=args.device)

    dataset = pose_model.cfg.data['test']['type']

    cap = cv2.VideoCapture(args.video_path)
    video_writer = None
    try:
        # Fail loudly instead of silently producing no frames.
        if not cap.isOpened():
            raise IOError(f'Failed to open video: {args.video_path}')

        if save_out_video:
            os.makedirs(args.out_video_root, exist_ok=True)
            fps = cap.get(cv2.CAP_PROP_FPS)
            size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(
                os.path.join(args.out_video_root,
                             f'vis_{os.path.basename(args.video_path)}'),
                fourcc, fps, size)

        # optional
        return_heatmap = False

        # e.g. use ('backbone', ) to return backbone feature
        output_layer_names = None

        next_id = 0
        pose_results = []
        while cap.isOpened():
            # results of the previous frame are the tracking reference
            pose_results_last = pose_results

            flag, img = cap.read()
            if not flag:
                break
            # test a single image, the resulting box is (x1, y1, x2, y2)
            mmdet_results = inference_detector(det_model, img)

            # keep the person class bounding boxes.
            person_bboxes = process_mmdet_results(mmdet_results)

            # test a single image, with a list of bboxes.
            # the second return value (layer outputs) is not used here
            pose_results, _ = inference_top_down_pose_model(
                pose_model,
                img,
                person_bboxes,
                bbox_thr=args.bbox_thr,
                format='xyxy',
                dataset=dataset,
                return_heatmap=return_heatmap,
                outputs=output_layer_names)

            # get track id for each person instance
            pose_results, next_id = get_track_id(
                pose_results, pose_results_last, next_id, iou_thr=args.iou_thr)

            # render the results; display/saving is handled below
            vis_img = vis_pose_tracking_result(
                pose_model,
                img,
                pose_results,
                dataset=dataset,
                kpt_score_thr=args.kpt_thr,
                show=False)

            if video_writer is not None:
                video_writer.write(vis_img)

            if args.show:
                cv2.imshow('Image', vis_img)
                # Only poll the keyboard when a window exists; GUI calls are
                # presumably unavailable on headless OpenCV builds.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release resources even if inference raised mid-video.
        cap.release()
        if video_writer is not None:
            video_writer.release()
        if args.show:
            cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
3 changes: 2 additions & 1 deletion mmpose/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .inference import (inference_bottom_up_pose_model,
inference_top_down_pose_model, init_pose_model,
vis_pose_result)
from .inference_tracking import get_track_id, vis_pose_tracking_result
from .test import multi_gpu_test, single_gpu_test
from .train import train_model

__all__ = [
'train_model', 'init_pose_model', 'inference_top_down_pose_model',
'inference_bottom_up_pose_model', 'multi_gpu_test', 'single_gpu_test',
'vis_pose_result'
'vis_pose_result', 'get_track_id', 'vis_pose_tracking_result'
]
Loading

0 comments on commit 5382dff

Please sign in to comment.