-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
488 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
## 2D Pose Tracking Demo | ||
|
||
### 2D Top-Down Video Human Pose Tracking Demo | ||
|
||
We provide a video demo to illustrate the pose tracking results. | ||
|
||
Assume that you have already installed [mmdet](https://github.com/open-mmlab/mmdetection). | ||
|
||
```shell | ||
python demo/top_down_video_demo_with_mmdet.py \ | ||
${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \ | ||
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \ | ||
--video-path ${VIDEO_FILE} \ | ||
--output-video-root ${OUTPUT_VIDEO_ROOT} \ | ||
[--show --device ${GPU_ID}] \ | ||
[--bbox-thr ${BBOX_SCORE_THR} --kpt-thr ${KPT_SCORE_THR} --iou-thr ${IOU_SCORE_THR}] | ||
``` | ||
|
||
Examples: | ||
|
||
```shell | ||
python demo/top_down_pose_tracking_demo_with_mmdet.py \ | ||
demo/mmdetection_cfg/faster_rcnn_r50_fpn_1x_coco.py \ | ||
http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \ | ||
configs/top_down/resnet/coco/res50_coco_256x192.py \ | ||
https://download.openmmlab.com/mmpose/top_down/resnet/res50_coco_256x192-ec54d7f3_20200709.pth \ | ||
--video-path demo/demo_video.mp4 \ | ||
--out-video-root vis_results | ||
``` | ||
|
||
### Speed Up Inference | ||
|
||
Some tips to speed up MMPose inference: | ||
|
||
For top-down 2D human pose models, try to edit the config file. For example, | ||
|
||
1. set `flip_test=False` in [topdown-res50](/configs/top_down/resnet/coco/res50_coco_256x192.py#L51). | ||
2. set `unbiased_decoding=False` in [topdown-res50](/configs/top_down/resnet/coco/res50_coco_256x192.py#L54). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import os | ||
from argparse import ArgumentParser | ||
|
||
import cv2 | ||
from mmdet.apis import inference_detector, init_detector | ||
|
||
from mmpose.apis import (get_track_id, inference_top_down_pose_model, | ||
init_pose_model, vis_pose_tracking_result) | ||
|
||
|
||
def process_mmdet_results(mmdet_results, cat_id=0): | ||
"""Process mmdet results, and return a list of bboxes. | ||
:param mmdet_results: | ||
:param cat_id: category id (default: 0 for human) | ||
:return: a list of detected bounding boxes | ||
""" | ||
if isinstance(mmdet_results, tuple): | ||
det_results = mmdet_results[0] | ||
else: | ||
det_results = mmdet_results | ||
return det_results[cat_id] | ||
|
||
|
||
def main(): | ||
"""Visualize the demo images. | ||
Using mmdet to detect the human. | ||
""" | ||
parser = ArgumentParser() | ||
parser.add_argument('det_config', help='Config file for detection') | ||
parser.add_argument('det_checkpoint', help='Checkpoint file for detection') | ||
parser.add_argument('pose_config', help='Config file for pose') | ||
parser.add_argument('pose_checkpoint', help='Checkpoint file for pose') | ||
parser.add_argument('--video-path', type=str, help='Video path') | ||
parser.add_argument( | ||
'--show', | ||
action='store_true', | ||
default=False, | ||
help='whether to show visualizations.') | ||
parser.add_argument( | ||
'--out-video-root', | ||
default='', | ||
help='Root of the output video file. ' | ||
'Default not saving the visualization video.') | ||
parser.add_argument( | ||
'--device', default='cuda:0', help='Device used for inference') | ||
parser.add_argument( | ||
'--bbox-thr', | ||
type=float, | ||
default=0.3, | ||
help='Bounding box score threshold') | ||
parser.add_argument( | ||
'--kpt-thr', type=float, default=0.3, help='Keypoint score threshold') | ||
parser.add_argument( | ||
'--iou-thr', type=float, default=0.3, help='IoU score threshold') | ||
|
||
args = parser.parse_args() | ||
|
||
assert args.show or (args.out_video_root != '') | ||
assert args.det_config is not None | ||
assert args.det_checkpoint is not None | ||
|
||
det_model = init_detector( | ||
args.det_config, args.det_checkpoint, device=args.device) | ||
# build the pose model from a config file and a checkpoint file | ||
pose_model = init_pose_model( | ||
args.pose_config, args.pose_checkpoint, device=args.device) | ||
|
||
dataset = pose_model.cfg.data['test']['type'] | ||
|
||
cap = cv2.VideoCapture(args.video_path) | ||
|
||
if args.out_video_root == '': | ||
save_out_video = False | ||
else: | ||
os.makedirs(args.out_video_root, exist_ok=True) | ||
save_out_video = True | ||
|
||
if save_out_video: | ||
fps = cap.get(cv2.CAP_PROP_FPS) | ||
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), | ||
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) | ||
fourcc = cv2.VideoWriter_fourcc(*'mp4v') | ||
videoWriter = cv2.VideoWriter( | ||
os.path.join(args.out_video_root, | ||
f'vis_{os.path.basename(args.video_path)}'), fourcc, | ||
fps, size) | ||
|
||
# optional | ||
return_heatmap = False | ||
|
||
# e.g. use ('backbone', ) to return backbone feature | ||
output_layer_names = None | ||
|
||
next_id = 0 | ||
pose_results = [] | ||
while (cap.isOpened()): | ||
pose_results_last = pose_results | ||
|
||
flag, img = cap.read() | ||
if not flag: | ||
break | ||
# test a single image, the resulting box is (x1, y1, x2, y2) | ||
mmdet_results = inference_detector(det_model, img) | ||
|
||
# keep the person class bounding boxes. | ||
person_bboxes = process_mmdet_results(mmdet_results) | ||
|
||
# test a single image, with a list of bboxes. | ||
pose_results, returned_outputs = inference_top_down_pose_model( | ||
pose_model, | ||
img, | ||
person_bboxes, | ||
bbox_thr=args.bbox_thr, | ||
format='xyxy', | ||
dataset=dataset, | ||
return_heatmap=return_heatmap, | ||
outputs=output_layer_names) | ||
|
||
# get track id for each person instance | ||
pose_results, next_id = get_track_id( | ||
pose_results, pose_results_last, next_id, iou_thr=args.iou_thr) | ||
|
||
# show the results | ||
vis_img = vis_pose_tracking_result( | ||
pose_model, | ||
img, | ||
pose_results, | ||
dataset=dataset, | ||
kpt_score_thr=args.kpt_thr, | ||
show=False) | ||
|
||
if args.show: | ||
cv2.imshow('Image', vis_img) | ||
|
||
if save_out_video: | ||
videoWriter.write(vis_img) | ||
|
||
if cv2.waitKey(1) & 0xFF == ord('q'): | ||
break | ||
|
||
cap.release() | ||
if save_out_video: | ||
videoWriter.release() | ||
cv2.destroyAllWindows() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
from .inference import (inference_bottom_up_pose_model, | ||
inference_top_down_pose_model, init_pose_model, | ||
vis_pose_result) | ||
from .inference_tracking import get_track_id, vis_pose_tracking_result | ||
from .test import multi_gpu_test, single_gpu_test | ||
from .train import train_model | ||
|
||
__all__ = [ | ||
'train_model', 'init_pose_model', 'inference_top_down_pose_model', | ||
'inference_bottom_up_pose_model', 'multi_gpu_test', 'single_gpu_test', | ||
'vis_pose_result' | ||
'vis_pose_result', 'get_track_id', 'vis_pose_tracking_result' | ||
] |
Oops, something went wrong.