From d87c99b99252a0550be6fbee68f604cba1776d6b Mon Sep 17 00:00:00 2001 From: Tau-J <674106399@qq.com> Date: Mon, 6 Mar 2023 11:25:21 +0800 Subject: [PATCH] copy metainfo in get_data_info --- mmpose/datasets/dataset_wrappers.py | 18 +++++++++--------- .../datasets/base/base_coco_style_dataset.py | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py index 3836100ed2..28eeac9945 100644 --- a/mmpose/datasets/dataset_wrappers.py +++ b/mmpose/datasets/dataset_wrappers.py @@ -86,15 +86,6 @@ def prepare_data(self, idx: int) -> Any: data_info = self.get_data_info(idx) - # Add metainfo items that are required in the pipeline and the model - metainfo_keys = [ - 'upper_body_ids', 'lower_body_ids', 'flip_pairs', - 'dataset_keypoint_weights', 'flip_indices' - ] - - for key in metainfo_keys: - data_info[key] = deepcopy(self._metainfo[key]) - return self.pipeline(data_info) def get_data_info(self, idx: int) -> dict: @@ -109,6 +100,15 @@ def get_data_info(self, idx: int) -> dict: # Get data sample processed by ``subset.pipeline`` data_info = self.datasets[subset_idx][sample_idx] + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + 'upper_body_ids', 'lower_body_ids', 'flip_pairs', + 'dataset_keypoint_weights', 'flip_indices' + ] + + for key in metainfo_keys: + data_info[key] = deepcopy(self._metainfo[key]) + return data_info def full_init(self): diff --git a/mmpose/datasets/datasets/base/base_coco_style_dataset.py b/mmpose/datasets/datasets/base/base_coco_style_dataset.py index b301863536..a1c3bc2f5c 100644 --- a/mmpose/datasets/datasets/base/base_coco_style_dataset.py +++ b/mmpose/datasets/datasets/base/base_coco_style_dataset.py @@ -147,6 +147,19 @@ def prepare_data(self, idx) -> Any: """ data_info = self.get_data_info(idx) + return self.pipeline(data_info) + + def get_data_info(self, idx: int) -> dict: + """Get data info by index. + + Args: + idx (int): Index of data info. + + Returns: + dict: Data info. + """ + data_info = super().get_data_info(idx) + # Add metainfo items that are required in the pipeline and the model metainfo_keys = [ 'upper_body_ids', 'lower_body_ids', 'flip_pairs', @@ -160,7 +173,7 @@ def prepare_data(self, idx) -> Any: data_info[key] = deepcopy(self._metainfo[key]) - return self.pipeline(data_info) + return data_info def load_data_list(self) -> List[dict]: """Load data list from COCO annotation file or person detection result