From 137e8648a56db3b8a42fd356838b63be81ec45ef Mon Sep 17 00:00:00 2001 From: ailingzengzzz <32029490+ailingzengzzz@users.noreply.github.com> Date: Fri, 11 Mar 2022 20:31:39 +0800 Subject: [PATCH] [Feature] Add two traditional filters to smooth poses (#1127) * add savgol&gaus1d filters * Update __init__.py * Update test_temporal_filter.py * Update and rename gaus1d_filter.py to gauss1d_filter.py * Update __init__.py * pass pre-commit hooks * Add smoother * add Smoother * refactor filters with base filter * add unittest for Smoother * add temporal filter unittests * remove misadded file * fix bugs in smoother and filters update get_track_id and demos * fix unittest * fix smoother bug with empty input * fix smoothing in 3d video demo Co-authored-by: ly015 --- configs/_base_/filters/gausian_filter.py | 0 configs/_base_/filters/gaussian.py | 5 + configs/_base_/filters/one_euro.py | 5 + configs/_base_/filters/savizky_golay.py | 5 + demo/body3d_two_stage_video_demo.py | 32 ++- demo/bottom_up_pose_tracking_demo.py | 33 ++- .../top_down_pose_tracking_demo_with_mmdet.py | 37 ++- ...down_pose_tracking_demo_with_mmtracking.py | 21 ++ demo/webcam_demo.py | 37 ++- mmpose/apis/inference_tracking.py | 18 +- mmpose/core/post_processing/__init__.py | 6 +- .../core/post_processing/one_euro_filter.py | 17 +- mmpose/core/post_processing/smoother.py | 218 ++++++++++++++++++ .../temporal_filters/__init__.py | 9 + .../temporal_filters/builder.py | 9 + .../temporal_filters/filter.py | 35 +++ .../temporal_filters/gaussian_filter.py | 44 ++++ .../temporal_filters/one_euro_filter.py | 110 +++++++++ .../temporal_filters/savizky_golay_filter.py | 50 ++++ tests/test_apis/test_inference_tracking.py | 13 +- .../test_one_euro_filter_compatibility.py} | 0 tests/test_post_processing/test_smoother.py | 143 ++++++++++++ .../test_temporal_filter.py | 47 ++++ tools/webcam/webcam_apis/nodes/mmpose_node.py | 71 +++--- 24 files changed, 902 insertions(+), 63 deletions(-) delete mode 100644 configs/_base_/filters/gausian_filter.py create mode 100644 configs/_base_/filters/gaussian.py create mode 100644 configs/_base_/filters/one_euro.py create mode 100644 configs/_base_/filters/savizky_golay.py create mode 100644 mmpose/core/post_processing/smoother.py create mode 100644 mmpose/core/post_processing/temporal_filters/__init__.py create mode 100644 mmpose/core/post_processing/temporal_filters/builder.py create mode 100644 mmpose/core/post_processing/temporal_filters/filter.py create mode 100644 mmpose/core/post_processing/temporal_filters/gaussian_filter.py create mode 100644 mmpose/core/post_processing/temporal_filters/one_euro_filter.py create mode 100644 mmpose/core/post_processing/temporal_filters/savizky_golay_filter.py rename tests/{test_post_processing/test_filter.py => test_backward_compatibility/test_one_euro_filter_compatibility.py} (100%) create mode 100644 tests/test_post_processing/test_smoother.py create mode 100644 tests/test_post_processing/test_temporal_filter.py diff --git a/configs/_base_/filters/gausian_filter.py b/configs/_base_/filters/gausian_filter.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/configs/_base_/filters/gaussian.py b/configs/_base_/filters/gaussian.py new file mode 100644 index 00000000000..b855f4bde1e --- /dev/null +++ b/configs/_base_/filters/gaussian.py @@ -0,0 +1,5 @@ +filter_cfg = dict( + type='GaussianFilter', + window_size=11, + sigma=4.0, +) diff --git a/configs/_base_/filters/one_euro.py b/configs/_base_/filters/one_euro.py new file mode 100644 index 00000000000..61f797efdf9 --- /dev/null +++ b/configs/_base_/filters/one_euro.py @@ -0,0 +1,5 @@ +filter_cfg = dict( + type='OneEuroFilter', + min_cutoff=0.004, + beta=0.7, +) diff --git a/configs/_base_/filters/savizky_golay.py b/configs/_base_/filters/savizky_golay.py new file mode 100644 index 00000000000..40302b00446 --- /dev/null +++ b/configs/_base_/filters/savizky_golay.py @@ -0,0 +1,5 @@ +filter_cfg = dict( + type='SavizkyGolayFilter', + window_size=11, + polyorder=2, +) diff --git a/demo/body3d_two_stage_video_demo.py b/demo/body3d_two_stage_video_demo.py index 5f47f62aeb8..26161228d8e 100644 --- a/demo/body3d_two_stage_video_demo.py +++ b/demo/body3d_two_stage_video_demo.py @@ -12,6 +12,7 @@ inference_pose_lifter_model, inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_3d_pose_result) +from mmpose.core import Smoother try: from mmdet.apis import inference_detector, init_detector @@ -123,10 +124,6 @@ def main(): '--use-oks-tracking', action='store_true', help='Using OKS tracking') parser.add_argument( '--tracking-thr', type=float, default=0.3, help='Tracking threshold') - parser.add_argument( - '--euro', - action='store_true', - help='Using One_Euro_Filter for smoothing') parser.add_argument( '--radius', type=int, @@ -137,6 +134,17 @@ def main(): type=int, default=2, help='Link thickness for visualization') + parser.add_argument( + '--smooth', + action='store_true', + help='Apply a temporal filter to smooth the pose estimation results. ' + 'See also --smooth-filter-cfg.') + parser.add_argument( + '--smooth-filter-cfg', + type=str, + default='configs/_base_/filters/one_euro.py', + help='Config file of the filter to smooth the pose estimation ' + 'results. See also --smooth.') assert has_mmdet, 'Please install mmdet to run the demo.' @@ -167,7 +175,7 @@ def main(): pose_det_results_list = [] next_id = 0 pose_det_results = [] - for frame in video: + for frame in mmcv.track_iter_progress(video): pose_det_results_last = pose_det_results # test a single image, the resulting box is (x1, y1, x2, y2) @@ -194,9 +202,7 @@ def main(): pose_det_results_last, next_id, use_oks=args.use_oks_tracking, - tracking_thr=args.tracking_thr, - use_one_euro=args.euro, - fps=video.fps) + tracking_thr=args.tracking_thr) pose_det_results_list.append(copy.deepcopy(pose_det_results)) @@ -237,6 +243,12 @@ def main(): else: data_cfg = pose_lift_model.cfg.data_cfg + # build pose smoother for temporal refinement + if args.smooth: + smoother = Smoother(filter_cfg=args.smooth_filter_cfg, keypoint_dim=3) + else: + smoother = None + num_instances = args.num_instances for i, pose_det_results in enumerate( mmcv.track_iter_progress(pose_det_results_list)): @@ -279,6 +291,10 @@ def main(): res['track_id'] = instance_id pose_lift_results_vis.append(res) + # Smoothing + if smoother: + pose_lift_results = smoother.smooth(pose_lift_results) + # Visualization if num_instances < 0: num_instances = len(pose_lift_results_vis) diff --git a/demo/bottom_up_pose_tracking_demo.py b/demo/bottom_up_pose_tracking_demo.py index b79e1f40de8..6cda0b3a325 100644 --- a/demo/bottom_up_pose_tracking_demo.py +++ b/demo/bottom_up_pose_tracking_demo.py @@ -7,6 +7,7 @@ from mmpose.apis import (get_track_id, inference_bottom_up_pose_model, init_pose_model, vis_pose_tracking_result) +from mmpose.core import Smoother from mmpose.datasets import DatasetInfo @@ -42,7 +43,19 @@ def main(): parser.add_argument( '--euro', action='store_true', - help='Using One_Euro_Filter for smoothing') + help='(Deprecated, please use --smooth and --smooth-filter-cfg) ' + 'Using One_Euro_Filter for smoothing.') + parser.add_argument( + '--smooth', + action='store_true', + help='Apply a temporal filter to smooth the pose estimation results. ' + 'See also --smooth-filter-cfg.') + parser.add_argument( + '--smooth-filter-cfg', + type=str, + default='configs/_base_/filters/one_euro.py', + help='Config file of the filter to smooth the pose estimation ' + 'results. See also --smooth.') parser.add_argument( '--radius', type=int, @@ -97,6 +110,20 @@ def main(): # optional return_heatmap = False + # build pose smoother for temporal refinement + if args.euro: + warnings.warn( + 'Argument --euro will be deprecated in the future. ' + 'Please use --smooth to enable temporal smoothing, and ' + '--smooth-filter-cfg to set the filter config.', + DeprecationWarning) + smoother = Smoother( + filter_cfg='configs/_base_/filters/one_euro.py', keypoint_dim=2) + elif args.smooth: + smoother = Smoother(filter_cfg=args.smooth_filter_cfg, keypoint_dim=2) + else: + smoother = None + # e.g. use ('backbone', ) to return backbone feature output_layer_names = None next_id = 0 @@ -126,6 +153,10 @@ def main(): use_one_euro=args.euro, fps=fps) + # post-process the pose results with smoother + if smoother: + pose_results = smoother.smooth(pose_results) + # show the results vis_img = vis_pose_tracking_result( pose_model, diff --git a/demo/top_down_pose_tracking_demo_with_mmdet.py b/demo/top_down_pose_tracking_demo_with_mmdet.py index 5ddcd934ee3..2b8d3f80eb0 100644 --- a/demo/top_down_pose_tracking_demo_with_mmdet.py +++ b/demo/top_down_pose_tracking_demo_with_mmdet.py @@ -8,6 +8,7 @@ from mmpose.apis import (get_track_id, inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_tracking_result) +from mmpose.core import Smoother from mmpose.datasets import DatasetInfo try: @@ -59,7 +60,19 @@ def main(): parser.add_argument( '--euro', action='store_true', - help='Using One_Euro_Filter for smoothing') + help='(Deprecated, please use --smooth and --smooth-filter-cfg) ' + 'Using One_Euro_Filter for smoothing.') + parser.add_argument( + '--smooth', + action='store_true', + help='Apply a temporal filter to smooth the pose estimation results. ' + 'See also --smooth-filter-cfg.') + parser.add_argument( + '--smooth-filter-cfg', + type=str, + default='configs/_base_/filters/one_euro.py', + help='Config file of the filter to smooth the pose estimation ' + 'results. See also --smooth.') parser.add_argument( '--radius', type=int, @@ -122,6 +135,20 @@ def main(): # e.g. use ('backbone', ) to return backbone feature output_layer_names = None + # build pose smoother for temporal refinement + if args.euro: + warnings.warn( + 'Argument --euro will be deprecated in the future. ' + 'Please use --smooth to enable temporal smoothing, and ' + '--smooth-filter-cfg to set the filter config.', + DeprecationWarning) + smoother = Smoother( + filter_cfg='configs/_base_/filters/one_euro.py', keypoint_dim=2) + elif args.smooth: + smoother = Smoother(filter_cfg=args.smooth_filter_cfg, keypoint_dim=2) + else: + smoother = None + next_id = 0 pose_results = [] while (cap.isOpened()): @@ -154,9 +181,11 @@ def main(): pose_results_last, next_id, use_oks=args.use_oks_tracking, - tracking_thr=args.tracking_thr, - use_one_euro=args.euro, - fps=fps) + tracking_thr=args.tracking_thr) + + # post-process the pose results with smoother + if smoother: + pose_results = smoother.smooth(pose_results) # show the results vis_img = vis_pose_tracking_result( diff --git a/demo/top_down_pose_tracking_demo_with_mmtracking.py b/demo/top_down_pose_tracking_demo_with_mmtracking.py index 9902e0674ec..9de6369cd78 100644 --- a/demo/top_down_pose_tracking_demo_with_mmtracking.py +++ b/demo/top_down_pose_tracking_demo_with_mmtracking.py @@ -7,6 +7,7 @@ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, vis_pose_tracking_result) +from mmpose.core import Smoother from mmpose.datasets import DatasetInfo try: @@ -78,6 +79,17 @@ def main(): type=int, default=1, help='Link thickness for visualization') + parser.add_argument( + '--smooth', + action='store_true', + help='Apply a temporal filter to smooth the pose estimation results. ' + 'See also --smooth-filter-cfg.') + parser.add_argument( + '--smooth-filter-cfg', + type=str, + default='configs/_base_/filters/one_euro.py', + help='Config file of the filter to smooth the pose estimation ' + 'results. See also --smooth.') assert has_mmtrack, 'Please install mmtrack to run the demo.' @@ -127,6 +139,12 @@ def main(): # e.g. use ('backbone', ) to return backbone feature output_layer_names = None + # build pose smoother for temporal refinement + if args.smooth: + smoother = Smoother(filter_cfg=args.smooth_filter_cfg, keypoint_dim=2) + else: + smoother = None + frame_id = 0 while (cap.isOpened()): flag, img = cap.read() @@ -151,6 +169,9 @@ def main(): return_heatmap=return_heatmap, outputs=output_layer_names) + if smoother: + pose_results = smoother.smooth(pose_results) + # show the results vis_img = vis_pose_tracking_result( pose_model, diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py index e3801a38d33..e390569a31e 100644 --- a/demo/webcam_demo.py +++ b/demo/webcam_demo.py @@ -10,7 +10,7 @@ from mmpose.apis import (get_track_id, inference_top_down_pose_model, init_pose_model, vis_pose_result) -from mmpose.core import apply_bugeye_effect, apply_sunglasses_effect +from mmpose.core import Smoother, apply_bugeye_effect, apply_sunglasses_effect from mmpose.utils import StopWatch try: @@ -152,6 +152,18 @@ def parse_args(): help='Enable synchronous mode that video I/O and inference will be ' 'temporally aligned. Note that this will reduce the display FPS.') + parser.add_argument( + '--smooth', + action='store_true', + help='Apply a temporal filter to smooth the pose estimation results. ' + 'See also --smooth-filter-cfg.') + parser.add_argument( + '--smooth-filter-cfg', + type=str, + default='configs/_base_/filters/one_euro.py', + help='Config file of the filter to smooth the pose estimation ' + 'results. See also --smooth.') + return parser.parse_args() @@ -265,8 +277,9 @@ def inference_pose(): ts_input, frame, t_info, mmdet_results = det_result_queue.popleft() pose_results_list = [] - for model_info, pose_history in zip(pose_model_list, - pose_history_list): + for model_info, pose_history, smoother in zip(pose_model_list, + pose_history_list, + pose_smoother_list): model_name = model_info['name'] pose_model = model_info['model'] cat_ids = model_info['cat_ids'] @@ -295,9 +308,10 @@ def inference_pose(): pose_results_last, next_id, use_oks=False, - tracking_thr=0.3, - use_one_euro=True, - fps=None) + tracking_thr=0.3) + + if smoother: + pose_results = smoother.smooth(pose_results) pose_results_list.append(pose_results) @@ -497,6 +511,7 @@ def main(): global pose_result_queue, pose_result_queue_mutex global det_model, pose_model_list, pose_history_list global event_exit, event_inference_done + global pose_smoother_list args = parse_args() @@ -540,6 +555,16 @@ def main(): for _ in range(len(pose_model_list)): pose_history_list.append({'pose_results_last': [], 'next_id': 0}) + # build pose smoother for temporal refinement + pose_smoother_list = [] + for _ in range(len(pose_model_list)): + if args.smooth: + smoother = Smoother( + filter_cfg=args.smooth_filter_cfg, keypoint_dim=2) + else: + smoother = None + pose_smoother_list.append(smoother) + # frame buffer if args.buffer_size > 0: buffer_size = args.buffer_size diff --git a/mmpose/apis/inference_tracking.py b/mmpose/apis/inference_tracking.py index 9494fbaa75c..2c85d77c468 100644 --- a/mmpose/apis/inference_tracking.py +++ b/mmpose/apis/inference_tracking.py @@ -177,8 +177,9 @@ def get_track_id(results, Args: results (list[dict]): The bbox & pose results of the current frame (bbox_result, pose_result). - results_last (list[dict]): The bbox & pose & track_id info of the - last frame (bbox_result, pose_result, track_id). + results_last (list[dict], optional): The bbox & pose & track_id info + of the last frame (bbox_result, pose_result, track_id). None is + equivalent to an empty result list. Default: None next_id (int): The track id for the new person instance. min_keypoints (int): Minimum number of keypoints recognized as person. default: 3. @@ -194,6 +195,18 @@ def get_track_id(results, current frame (bbox_result, pose_result, track_id). - next_id (int): The track id for the new person instance. """ + if use_one_euro: + warnings.warn( + 'In the future, get_track_id() will no longer perform ' + 'temporal refinement and the arguments `use_one_euro` and ' + '`fps` will be deprecated. This part of function has been ' + 'migrated to Smoother (mmpose.core.Smoother). See ' + 'demo/top_down_pose_trackign_demo_with_mmdet.py for an ' + 'example.', DeprecationWarning) + + if results_last is None: + results_last = [] + results = _get_area(results) if use_oks: @@ -216,6 +229,7 @@ def get_track_id(results, result['track_id'] = -1 else: result['track_id'] = track_id + if use_one_euro: result['keypoints'] = _temporal_refine( result, match_result, fps=fps) diff --git a/mmpose/core/post_processing/__init__.py b/mmpose/core/post_processing/__init__.py index 1ee6858d953..8076b799b9e 100644 --- a/mmpose/core/post_processing/__init__.py +++ b/mmpose/core/post_processing/__init__.py @@ -1,14 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. + from .nms import oks_iou, oks_nms, soft_oks_nms from .one_euro_filter import OneEuroFilter from .post_transforms import (affine_transform, flip_back, fliplr_joints, fliplr_regression, get_affine_transform, get_warp_matrix, rotate_point, transform_preds, warp_affine_joints) +from .smoother import Smoother __all__ = [ 'oks_nms', 'soft_oks_nms', 'affine_transform', 'rotate_point', 'flip_back', 'fliplr_joints', 'fliplr_regression', 'transform_preds', - 'get_affine_transform', 'get_warp_matrix', 'warp_affine_joints', - 'OneEuroFilter', 'oks_iou' + 'get_affine_transform', 'get_warp_matrix', 'warp_affine_joints', 'oks_iou', + 'OneEuroFilter', 'Smoother' ] diff --git a/mmpose/core/post_processing/one_euro_filter.py b/mmpose/core/post_processing/one_euro_filter.py index 01ffa5fda9b..325466522db 100644 --- a/mmpose/core/post_processing/one_euro_filter.py +++ b/mmpose/core/post_processing/one_euro_filter.py @@ -2,6 +2,7 @@ # Adapted from https://github.com/HoBeom/OneEuroFilter-Numpy # Original licence: Copyright (c) HoBeom Jeon, under the MIT License. # ------------------------------------------------------------------------------ +import warnings from time import time import numpy as np @@ -35,6 +36,13 @@ def __init__(self, d_cutoff (float): Input data FPS fps (float): Video FPS for video inference """ + warnings.warn( + 'OneEuroFilter from ' + '`mmpose/core/post_processing/one_euro_filter.py` will ' + 'be deprecated in the future. Please use Smoother' + '(`mmpose/core/post_processing/smoother.py`) with ' + 'OneEuroFilter (`mmpose/core/post_processing/temporal_' + 'filters/one_euro_filter.py`).', DeprecationWarning) # The parameters. self.data_shape = x0.shape @@ -50,10 +58,13 @@ def __init__(self, # Using in realtime inference self.t_e = None self.skip_frame_factor = d_cutoff + self.fps = d_cutoff else: # fps using video inference self.realtime = False - self.d_cutoff = np.full(x0.shape, float(fps)) + self.fps = float(fps) + self.d_cutoff = np.full(x0.shape, self.fps) + self.t_prev = time() def __call__(self, x, t_e=1.0): @@ -81,13 +92,13 @@ def __call__(self, x, t_e=1.0): mask = np.ma.masked_where(x <= 0, x) # The filtered derivative of the signal. - a_d = smoothing_factor(t_e, self.d_cutoff) + a_d = smoothing_factor(t_e / self.fps, self.d_cutoff) dx = (x - self.x_prev) / t_e dx_hat = exponential_smoothing(a_d, dx, self.dx_prev) # The filtered signal. cutoff = self.min_cutoff + self.beta * np.abs(dx_hat) - a = smoothing_factor(t_e, cutoff) + a = smoothing_factor(t_e / self.fps, cutoff) x_hat = exponential_smoothing(a, x, self.x_prev) # missing keypoints remove diff --git a/mmpose/core/post_processing/smoother.py b/mmpose/core/post_processing/smoother.py new file mode 100644 index 00000000000..61a68509ec7 --- /dev/null +++ b/mmpose/core/post_processing/smoother.py @@ -0,0 +1,218 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from typing import Dict, Union + +import numpy as np +from mmcv import Config, is_seq_of + +from mmpose.core.post_processing.temporal_filters import build_filter + + +class Smoother(): + """Smoother to apply temporal smoothing on pose estimation results with a + filter. + + Note: + T: The temporal length of the pose sequence + K: The keypoint number of each target + C: The keypoint coordinate dimension + + Args: + filter_cfg (dict | str): The filter config. See example config files in + `configs/_base_/filters/` for details. Alternatively a config file + path can be accepted and the config will be loaded. + keypoint_dim (int): The keypoint coordinate dimension, which is + also indicated as C. Default: 2 + keypoint_key (str): The dict key of the keypoints in the pose results. + Default: 'keypoints' + Example: + >>> import numpy as np + >>> # Build dummy pose result + >>> results = [] + >>> for t in range(10): + >>> results_t = [] + >>> for track_id in range(2): + >>> result = { + >>> 'track_id': track_id, + >>> 'keypoints': np.random.rand(17, 3) + >>> } + >>> results_t.append(result) + >>> results.append(results_t) + >>> # Example 1: Smooth multi-frame pose results offline. + >>> filter_cfg = dict(type='GaussianFilter', window_size=3) + >>> smoother = Smoother(filter_cfg, keypoint_dim=2) + >>> smoothed_results = smoother.smooth(results) + >>> # Example 2: Smooth pose results online frame-by-frame + >>> filter_cfg = dict(type='GaussianFilter', window_size=3) + >>> smoother = Smoother(filter_cfg, keypoint_dim=2) + >>> for result_t in results: + >>> smoothed_result_t = smoother.smooth(result_t) + """ + + def __init__(self, + filter_cfg: Union[Dict, str], + keypoint_dim: int = 2, + keypoint_key: str = 'keypoints'): + if isinstance(filter_cfg, str): + filter_cfg = Config.fromfile(filter_cfg).filter_cfg + self.filter_cfg = filter_cfg + self.keypoint_dim = keypoint_dim + self.key = keypoint_key + self.padding_size = build_filter(filter_cfg).window_size - 1 + self.history = {} + + def _collate_pose(self, results): + """Collate the pose results to pose sequences. + + Args: + results (list[list[dict]]): The pose results of multiple frames. + + Returns: + dict[str, np.ndarray]: A dict of collated pose sequences, where + the key is the track_id (in untracked scenario, the target index + will be used as the track_id), and the value is the pose sequence + in an array of shape [T, K, C] + """ + + if self._has_track_id(results): + # If the results have track_id, use it as the target indicator + results = [{res['track_id']: res + for res in results_t} for results_t in results] + track_ids = results[0].keys() + + for t, results_t in enumerate(results[1:]): + if results_t.keys() != track_ids: + raise ValueError(f'Inconsistent track ids in frame {t+1}') + + collated = { + id: np.stack([ + results_t[id][self.key][:, :self.keypoint_dim] + for results_t in results + ]) + for id in track_ids + } + else: + # If the results don't have track_id, use the target index + # as the target indicator + n_target = len(results[0]) + for t, results_t in enumerate(results[1:]): + if len(results_t) != n_target: + raise ValueError( + f'Inconsistent target number in frame {t+1}: ' + f'{len(results_t)} vs {n_target}') + + collated = { + id: np.stack([ + results_t[id][self.key][:, :self.keypoint_dim] + for results_t in results + ]) + for id in range(n_target) + } + + return collated + + def _scatter_pose(self, results, poses): + """Scatter the smoothed pose sequences and use them to update the pose + results. + + Args: + results (list[list[dict]]): The original pose results + poses (dict[str, np.ndarray]): The smoothed pose sequences + + Returns: + list[list[dict]]: The updated pose results + """ + updated_results = [] + for t, results_t in enumerate(results): + updated_results_t = [] + if self._has_track_id(results): + id2result = ((result['track_id'], result) + for result in results_t) + else: + id2result = enumerate(results_t) + + for track_id, result in id2result: + result = copy.deepcopy(result) + result[self.key][:, :self.keypoint_dim] = poses[track_id][t] + updated_results_t.append(result) + + updated_results.append(updated_results_t) + return updated_results + + @staticmethod + def _has_track_id(results): + """Check if the pose results contain track_id.""" + return 'track_id' in results[0][0] + + def smooth(self, results): + """Apply temporal smoothing on pose estimation sequences. + + Args: + results (list[dict] | list[list[dict]]): The pose results of a + single frame (non-nested list) or multiple frames (nested + list). The result of each target is a dict, which should + contains: + + - track_id (optional, Any): The track ID of the target + - keypoints (np.ndarray): The keypoint coordinates in [K, C] + + Returns: + (list[dict] | list[list[dict]]): Temporal smoothed pose results, + which has the same data structure as the input's. + """ + + # Check if input is empty + if not (results) or not (results[0]): + warnings.warn('Smoother received empty result.') + return results + + # Check input is single frame or sequence + if is_seq_of(results, dict): + single_frame = True + results = [results] + else: + assert is_seq_of(results, list) + single_frame = False + + # Get temporal length of input + T = len(results) + + # Collate the input results to pose sequences + poses = self._collate_pose(results) + + # Smooth the pose sequence of each target + smoothed_poses = {} + update_history = {} + for track_id, pose in poses.items(): + if track_id in self.history: + # For tracked target, get its filter and pose history + pose_history, pose_filter = self.history[track_id] + if self.padding_size > 0: + # Pad the pose sequence with pose history + pose = np.concatenate((pose_history, pose), axis=0) + else: + # For new target, build a new filter + pose_filter = build_filter(self.filter_cfg) + + # Update the history information + if self.padding_size > 0: + pose_history = pose[-self.padding_size:].copy() + else: + pose_history = None + update_history[track_id] = (pose_history, pose_filter) + + # Smooth the pose sequence with the filter + smoothed_pose = pose_filter(pose) + smoothed_poses[track_id] = smoothed_pose[-T:] + + self.history = update_history + + # Scatter the pose sequences back to the format of results + smoothed_results = self._scatter_pose(results, smoothed_poses) + + # If the input is single frame, remove the nested list to keep the + # output structure consistent with the input's + if single_frame: + smoothed_results = smoothed_results[0] + return smoothed_results diff --git a/mmpose/core/post_processing/temporal_filters/__init__.py b/mmpose/core/post_processing/temporal_filters/__init__.py new file mode 100644 index 00000000000..f5d8a26d82e --- /dev/null +++ b/mmpose/core/post_processing/temporal_filters/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import build_filter +from .gaussian_filter import GaussianFilter +from .one_euro_filter import OneEuroFilter +from .savizky_golay_filter import SavizkyGolayFilter + +__all__ = [ + 'build_filter', 'GaussianFilter', 'OneEuroFilter', 'SavizkyGolayFilter' +] diff --git a/mmpose/core/post_processing/temporal_filters/builder.py b/mmpose/core/post_processing/temporal_filters/builder.py new file mode 100644 index 00000000000..adb914c5222 --- /dev/null +++ b/mmpose/core/post_processing/temporal_filters/builder.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.utils import Registry + +FILTERS = Registry('filters') + + +def build_filter(cfg): + """Build filters function.""" + return FILTERS.build(cfg) diff --git a/mmpose/core/post_processing/temporal_filters/filter.py b/mmpose/core/post_processing/temporal_filters/filter.py new file mode 100644 index 00000000000..2789f560d1d --- /dev/null +++ b/mmpose/core/post_processing/temporal_filters/filter.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class TemporalFilter(metaclass=ABCMeta): + """Base class of temporal filter. + + A subclass should implement the method __call__(). + + Parameters: + window_size (int): the size of the sliding window. + """ + + def __init__(self, window_size=1): + self._window_size = window_size + + @property + def window_size(self): + return self._window_size + + @abstractmethod + def __call__(self, x): + """Apply filter to a pose sequence. + + Note: + T: The temporal length of the pose sequence + K: The keypoint number of each target + C: The keypoint coordinate dimension + + Args: + x (np.ndarray): input pose sequence in shape [T, K, C] + + Returns: + np.ndarray: Smoothed pose sequence in shape [T, K, C] + """ diff --git a/mmpose/core/post_processing/temporal_filters/gaussian_filter.py b/mmpose/core/post_processing/temporal_filters/gaussian_filter.py new file mode 100644 index 00000000000..b737cdb15ae --- /dev/null +++ b/mmpose/core/post_processing/temporal_filters/gaussian_filter.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from scipy.ndimage.filters import gaussian_filter1d +from scipy.signal import medfilt + +from .builder import FILTERS +from .filter import TemporalFilter + + +@FILTERS.register_module(name=['GaussianFilter', 'gaussian']) +class GaussianFilter(TemporalFilter): + """Apply median filter and then gaussian filter. + + Adapted from: + https://github.com/akanazawa/human_dynamics/blob/mas + ter/src/util/smooth_bbox.py. + + Args: + window_size (int): The size of the filter window (i.e., the number + of coefficients). window_length must be a positive odd integer. + Default: 11 + sigma (float): Sigma for gaussian smoothing. Default: 4.0 + """ + + def __init__(self, window_size: int = 11, sigma: float = 4.0): + super().__init__(window_size) + assert window_size % 2 == 1, ( + 'The window size of GaussianFilter should' + f'be odd, but got {window_size}') + self.sigma = sigma + + def __call__(self, x: np.ndarray): + + assert x.ndim == 3, ('Input should be an array with shape [T, K, C]' + f', but got invalid shape {x.shape}') + + T = x.shape[0] + if T < self.window_size: + pad_width = [(self.window_size - T, 0), (0, 0), (0, 0)] + x = np.pad(x, pad_width, mode='edge') + smoothed = medfilt(x, (self.window_size, 1, 1)) + + smoothed = gaussian_filter1d(smoothed, self.sigma, axis=0) + return smoothed[-T:] diff --git a/mmpose/core/post_processing/temporal_filters/one_euro_filter.py b/mmpose/core/post_processing/temporal_filters/one_euro_filter.py new file mode 100644 index 00000000000..938375f8845 --- /dev/null +++ b/mmpose/core/post_processing/temporal_filters/one_euro_filter.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/HoBeom/OneEuroFilter-Numpy +# Original licence: Copyright (c) HoBeom Jeon, under the MIT License. +# ------------------------------------------------------------------------------ +import math + +import numpy as np + +from .builder import FILTERS +from .filter import TemporalFilter + + +def smoothing_factor(t_e, cutoff): + r = 2 * math.pi * cutoff * t_e + return r / (r + 1) + + +def exponential_smoothing(a, x, x_prev): + return a * x + (1 - a) * x_prev + + +class OneEuro: + + def __init__(self, t0, x0, dx0, min_cutoff, beta, d_cutoff=1.0): + super(OneEuro, self).__init__() + """Initialize the one euro filter.""" + # The parameters. + self.min_cutoff = float(min_cutoff) + self.beta = float(beta) + self.d_cutoff = float(d_cutoff) + # Previous values. + self.x_prev = x0 + self.dx_prev = dx0 + self.t_prev = t0 + + def __call__(self, x, t=None): + """Compute the filtered signal.""" + + if t is None: + # Assume input is feed frame by frame if not specified + t = self.t_prev + 1 + + t_e = t - self.t_prev + + # The filtered derivative of the signal. + a_d = smoothing_factor(t_e, self.d_cutoff) # [k, c] + dx = (x - self.x_prev) / t_e + dx_hat = exponential_smoothing(a_d, dx, self.dx_prev) + + # The filtered signal. + cutoff = self.min_cutoff + self.beta * np.abs(dx_hat) + a = smoothing_factor(t_e, cutoff) + x_hat = exponential_smoothing(a, x, self.x_prev) + # Memorize the previous values. + self.x_prev = x_hat + self.dx_prev = dx_hat + self.t_prev = t + return x_hat + + +@FILTERS.register_module(name=['OneEuroFilter', 'oneeuro']) +class OneEuroFilter(TemporalFilter): + """Oneeuro filter, source code: https://github.com/mkocabas/VIBE/blob/c0 + c3f77d587351c806e901221a9dc05d1ffade4b/lib/utils/smooth_pose.py. + + Args: + min_cutoff (float, optional): Decreasing the minimum cutoff frequency + decreases slow speed jitter + beta (float, optional): Increasing the speed coefficient(beta) + decreases speed lag. + """ + + def __init__(self, min_cutoff=0.004, beta=0.7): + # OneEuroFilter has Markov Property and maintains status variables + # within the class, thus has a windows_size of 1 + super().__init__(window_size=1) + self.min_cutoff = min_cutoff + self.beta = beta + self._one_euro = None + + def __call__(self, x: np.ndarray): + assert x.ndim == 3, ('Input should be an array with shape [T, K, C]' + f', but got invalid shape {x.shape}') + + pred_pose_hat = x.copy() + + if self._one_euro is None: + # The filter is invoked for the first time + # Initialize the filter + self._one_euro = OneEuro( + np.zeros_like(x[0]), + x[0], + dx0=0.0, + min_cutoff=self.min_cutoff, + beta=self.beta, + ) + t0 = 1 + else: + # The filter has been invoked + t0 = 0 + + for t, pose in enumerate(x): + if t < t0: + # If the filter is invoked for the first time + # set pred_pose_hat[0] = x[0] + continue + pose = self._one_euro(pose) + pred_pose_hat[t] = pose + + return pred_pose_hat diff --git a/mmpose/core/post_processing/temporal_filters/savizky_golay_filter.py b/mmpose/core/post_processing/temporal_filters/savizky_golay_filter.py new file mode 100644 index 00000000000..18e0528f6ce --- /dev/null +++ b/mmpose/core/post_processing/temporal_filters/savizky_golay_filter.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from scipy.signal import savgol_filter + +from .builder import FILTERS +from .filter import TemporalFilter + + +@FILTERS.register_module(name=['SavizkyGolayFilter', 'savgol']) +class SavizkyGolayFilter(TemporalFilter): + """Savizky-Golay filter. + + Adapted from: + https://docs.scipy.org/doc/scipy/reference/generated/ + scipy.signal.savgol_filter.html. + + Args: + window_size (int): The size of the filter window (i.e., the number + of coefficients). window_length must be a positive odd integer. + Default: 11 + polyorder (int): The order of the polynomial used to fit the samples. + polyorder must be less than window_size. + """ + + def __init__(self, window_size: int = 11, polyorder: int = 2): + super().__init__(window_size) + + # 1-D Savitzky-Golay filter + assert polyorder > 0, ( + f'Got invalid parameter polyorder={polyorder}. Polyorder ' + 'should be positive.') + assert polyorder < window_size, ( + f'Got invalid parameters polyorder={polyorder} and ' + f'window_size={window_size}. Polyorder should be less than ' + 'window_size.') + self.polyorder = polyorder + + def __call__(self, x: np.ndarray): + + assert x.ndim == 3, ('Input should be an array with shape [T, K, C]' + f', but got invalid shape {x.shape}') + + T = x.shape[0] + if T < self.window_size: + pad_width = [(self.window_size - T, 0), (0, 0), (0, 0)] + x = np.pad(x, pad_width, mode='edge') + + smoothed = savgol_filter(x, self.window_size, self.polyorder, axis=0) + + return smoothed[-T:] diff --git a/tests/test_apis/test_inference_tracking.py b/tests/test_apis/test_inference_tracking.py index 1ef62b771ae..90f7c673835 100644 --- a/tests/test_apis/test_inference_tracking.py +++ b/tests/test_apis/test_inference_tracking.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest + from mmpose.apis import (get_track_id, inference_bottom_up_pose_model, inference_top_down_pose_model, init_pose_model, vis_pose_tracking_result) @@ -152,6 +154,11 @@ def test_bottom_up_pose_tracking_demo(): pose_results, pose_results_last, next_id=next_id, use_oks=True) pose_results_last = pose_results - # one_euro - pose_results, next_id = get_track_id( - pose_results, pose_results_last, next_id=next_id, use_one_euro=True) + + # one_euro (will be deprecated) + with pytest.deprecated_call(): + pose_results, next_id = get_track_id( + pose_results, + pose_results_last, + next_id=next_id, + use_one_euro=True) diff --git a/tests/test_post_processing/test_filter.py b/tests/test_backward_compatibility/test_one_euro_filter_compatibility.py similarity index 100% rename from tests/test_post_processing/test_filter.py rename to tests/test_backward_compatibility/test_one_euro_filter_compatibility.py diff --git a/tests/test_post_processing/test_smoother.py b/tests/test_post_processing/test_smoother.py new file mode 100644 index 00000000000..65ab97bcd8f --- /dev/null +++ b/tests/test_post_processing/test_smoother.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union +from unittest import TestCase + +import numpy as np +from mmcv import is_list_of + +from mmpose.core.post_processing.smoother import Smoother + + +class TestSmoother(TestCase): + + def build_smoother(self): + smoother = Smoother( + 'configs/_base_/filters/gaussian.py', keypoint_dim=2) + return smoother + + def build_pose_results(self, + num_target: Union[int, List[int]], + num_frame: int = -1, + has_track_id: bool = True): + keypoint_shape = (17, 2) + results = [] + + if isinstance(num_target, list): + num_frame = len(num_target) + else: + assert num_frame >= 0 + num_target = [num_target] * num_frame + + for n in num_target: + results_t = [] + for idx in range(n): + result = dict(keypoints=np.random.rand(*keypoint_shape)) + if has_track_id: + result['track_id'] = str(idx) + results_t.append(result) + results.append(results_t) + return results + + def test_corner_cases(self): + # Test empty input + smoother = self.build_smoother() + results = [] + with self.assertWarnsRegex(UserWarning, + 'Smoother received empty result.'): + _ = smoother.smooth(results) + + # Test inconsistent tracked poses + smoother = self.build_smoother() + results = self.build_pose_results(num_target=[1, 2], has_track_id=True) + with self.assertRaisesRegex(ValueError, 'Inconsistent track ids'): + _ = smoother.smooth(results) + + # Test inconsistent untracked poses + smoother = self.build_smoother() + results = self.build_pose_results( + num_target=[1, 2], has_track_id=False) + with self.assertRaisesRegex(ValueError, 'Inconsistent target number'): + _ = smoother.smooth(results) + + def test_smooth_online_with_trackid(self): + smoother = self.build_smoother() + num_target = [2] * 10 + [3] * 10 + results = self.build_pose_results( + num_target=num_target, has_track_id=True) + for results_t in results: + smoothed_results_t = smoother.smooth(results_t) + + # Sort by track_id + results_t.sort(key=lambda x: x['track_id']) + smoothed_results_t.sort(key=lambda x: x['track_id']) + + # Check the output is non-nested list + self.assertTrue(is_list_of(smoothed_results_t, dict)) + # Check the target number in the frame is correct + self.assertEqual(len(smoothed_results_t), len(results_t)) + + for result, smoothed_result in zip(results_t, smoothed_results_t): + # Check the target_id is correct + self.assertEqual(result['track_id'], + smoothed_result['track_id']) + # Check the pose shape is correct + self.assertEqual(result['keypoints'].shape, + smoothed_result['keypoints'].shape) + + def test_smooth_online_wo_trackid(self): + smoother = self.build_smoother() + num_target = [2] * 10 + [3] * 10 + results = self.build_pose_results( + num_target=num_target, has_track_id=False) + for results_t in results: + smoothed_results_t = smoother.smooth(results_t) + + # Check the output is non-nested list + self.assertTrue(is_list_of(smoothed_results_t, dict)) + # Check the target number in the frame is correct + self.assertEqual(len(smoothed_results_t), len(results_t)) + + for result, smoothed_result in zip(results_t, smoothed_results_t): + # Check the pose shape is correct + self.assertEqual(result['keypoints'].shape, + smoothed_result['keypoints'].shape) + + def test_smooth_offline_with_trackid(self): + smoother = self.build_smoother() + results = self.build_pose_results( + num_target=2, num_frame=20, has_track_id=True) + smoothed_results = smoother.smooth(results) + for results_t, smoothed_results_t in zip(results, smoothed_results): + # Sort by track_id + results_t.sort(key=lambda x: x['track_id']) + smoothed_results_t.sort(key=lambda x: x['track_id']) + + # Check the output is non-nested list + self.assertTrue(is_list_of(smoothed_results_t, dict)) + # Check the target number in the frame is correct + self.assertEqual(len(smoothed_results_t), len(results_t)) + + for result, smoothed_result in zip(results_t, smoothed_results_t): + # Check the target_id is correct + self.assertEqual(result['track_id'], + smoothed_result['track_id']) + # Check the pose shape is correct + self.assertEqual(result['keypoints'].shape, + smoothed_result['keypoints'].shape) + + def test_smooth_offline_wo_trackid(self): + smoother = self.build_smoother() + results = self.build_pose_results( + num_target=2, num_frame=20, has_track_id=False) + smoothed_results = smoother.smooth(results) + + for results_t, smoothed_results_t in zip(results, smoothed_results): + # Check the output is non-nested list + self.assertTrue(is_list_of(smoothed_results_t, dict)) + # Check the target number in the frame is correct + self.assertEqual(len(smoothed_results_t), len(results_t)) + + for result, smoothed_result in zip(results_t, smoothed_results_t): + # Check the pose shape is correct + self.assertEqual(result['keypoints'].shape, + smoothed_result['keypoints'].shape) diff --git a/tests/test_post_processing/test_temporal_filter.py b/tests/test_post_processing/test_temporal_filter.py new file mode 100644 index 00000000000..1afbb6d5724 --- /dev/null +++ b/tests/test_post_processing/test_temporal_filter.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from unittest import TestCase + +import numpy as np +from mmcv import Config + +from mmpose.core.post_processing.temporal_filters import build_filter + + +class TestTemporalFilter(TestCase): + cfg_folder = 'configs/_base_/filters' + + def get_filter_input(self, + num_frame: int, + num_keypoint: int = 17, + keypoint_dim: int = 2): + return np.random.rand(num_frame, num_keypoint, + keypoint_dim).astype(np.float32) + + def get_filter_configs(self): + cfg_files = os.listdir(self.cfg_folder) + for cfg_file in cfg_files: + cfg = Config.fromfile(osp.join(self.cfg_folder, cfg_file)) + assert 'filter_cfg' in cfg + yield cfg.filter_cfg + + def test_temporal_filter(self): + for filter_cfg in self.get_filter_configs(): + with self.subTest(msg=f'Test {filter_cfg.type}'): + filter = build_filter(filter_cfg) + + # Test input with single frame + x = self.get_filter_input(num_frame=1) + y = filter(x) + self.assertTrue(isinstance(y, np.ndarray)) + self.assertEqual(x.shape, y.shape) + + # Test input with length > window_size + window_size = filter.window_size + x = self.get_filter_input(num_frame=window_size + 1) + y = filter(x) + self.assertTrue(isinstance(y, np.ndarray)) + self.assertEqual(x.shape, y.shape) + + # Test invalid diff --git a/tools/webcam/webcam_apis/nodes/mmpose_node.py b/tools/webcam/webcam_apis/nodes/mmpose_node.py index 167d7413ea4..89b9100a7d2 100644 --- a/tools/webcam/webcam_apis/nodes/mmpose_node.py +++ b/tools/webcam/webcam_apis/nodes/mmpose_node.py @@ -1,29 +1,39 @@ # Copyright (c) OpenMMLab. All rights reserved. -import time +from dataclasses import dataclass from typing import Dict, List, Optional, Union from mmpose.apis import (get_track_id, inference_top_down_pose_model, init_pose_model) +from mmpose.core import Smoother from ..utils import Message from .builder import NODES from .node import Node +@dataclass +class TrackInfo: + next_id: int = 0 + last_pose_preds: List = None + + @NODES.register_module() class TopDownPoseEstimatorNode(Node): - def __init__(self, - name: str, - model_config: str, - model_checkpoint: str, - input_buffer: str, - output_buffer: Union[str, List[str]], - enable_key: Optional[Union[str, int]] = None, - enable: bool = True, - device: str = 'cuda:0', - cls_ids: Optional[List] = None, - cls_names: Optional[List] = None, - bbox_thr: float = 0.5): + def __init__( + self, + name: str, + model_config: str, + model_checkpoint: str, + input_buffer: str, + output_buffer: Union[str, List[str]], + enable_key: Optional[Union[str, int]] = None, + enable: bool = True, + device: str = 'cuda:0', + cls_ids: Optional[List] = None, + cls_names: Optional[List] = None, + bbox_thr: float = 0.5, + smooth: bool = False, + smooth_filter_cfg: str = 'configs/_base_/filters/one_euro.py'): super().__init__(name=name, enable_key=enable_key, enable=enable) # Init model @@ -35,6 +45,10 @@ def __init__(self, self.cls_names = cls_names self.bbox_thr = bbox_thr + if smooth: + self.smoother = Smoother(smooth_filter_cfg, keypoint_dim=2) + else: + self.smoother = None # Init model self.model = init_pose_model( self.model_config, @@ -42,11 +56,7 @@ def __init__(self, device=self.device.lower()) # Store history for pose tracking - self.track_info = { - 'next_id': 0, - 'last_pose_preds': [], - 'last_time': None - } + self.track_info = TrackInfo() # Register buffers self.register_input_buffer(input_buffer, 'input', essential=True) @@ -91,26 +101,19 @@ def process(self, input_msgs: Dict[str, Message]) -> Message: format='xyxy') # Pose tracking - current_time = time.time() - if self.track_info['last_time'] is None: - fps = None - elif self.track_info['last_time'] >= current_time: - fps = None - else: - fps = 1.0 / (current_time - self.track_info['last_time']) - pose_preds, next_id = get_track_id( pose_preds, - self.track_info['last_pose_preds'], - self.track_info['next_id'], + self.track_info.last_pose_preds, + self.track_info.next_id, use_oks=False, - tracking_thr=0.3, - use_one_euro=True, - fps=fps) + tracking_thr=0.3) + + self.track_info.next_id = next_id + self.track_info.last_pose_preds = pose_preds.copy() - self.track_info['next_id'] = next_id - self.track_info['last_pose_preds'] = pose_preds.copy() - self.track_info['last_time'] = current_time + # Pose smoothing + if self.smoother: + pose_preds = self.smoother.smooth(pose_preds) pose_result = { 'preds': pose_preds,