From dd35080cab08597a4c810699de84983c45ea2d6b Mon Sep 17 00:00:00 2001 From: Eugene Date: Sat, 4 Nov 2023 19:48:53 +0200 Subject: [PATCH] Added missing rgb2bgr conversion --- .../coco_pose_estimation_dataset.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/super_gradients/training/datasets/pose_estimation_datasets/coco_pose_estimation_dataset.py b/src/super_gradients/training/datasets/pose_estimation_datasets/coco_pose_estimation_dataset.py index f2cc200c0a..15ef1fc679 100644 --- a/src/super_gradients/training/datasets/pose_estimation_datasets/coco_pose_estimation_dataset.py +++ b/src/super_gradients/training/datasets/pose_estimation_datasets/coco_pose_estimation_dataset.py @@ -10,7 +10,7 @@ from super_gradients.common.decorators.factory_decorator import resolve_param from super_gradients.common.factories.transforms_factory import TransformsFactory from super_gradients.common.factories.type_factory import TypeFactory -from super_gradients.common.object_names import Datasets +from super_gradients.common.object_names import Datasets, Processings from super_gradients.common.registry.registry import register_dataset from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy, xyxy_to_xywh from super_gradients.training.datasets.pose_estimation_datasets.abstract_pose_estimation_dataset import AbstractPoseEstimationDataset @@ -206,3 +206,20 @@ def _get_crowd_mask(self, anno, img_info) -> np.ndarray: m += mask return (m < 0.5).astype(np.float32) + + def get_dataset_preprocessing_params(self) -> dict: + """ + This method returns a dictionary of parameters describing preprocessing steps to be applied to the dataset. + :return: + """ + rgb_to_bgr = {Processings.ReverseImageChannels: {}} + image_to_tensor = {Processings.ImagePermute: {"permutation": (2, 0, 1)}} + pipeline = [rgb_to_bgr] + self.transforms.get_equivalent_preprocessing() + [image_to_tensor] + params = dict( + conf=0.05, + image_processor={Processings.ComposeProcessing: {"processings": pipeline}}, + edge_links=self.edge_links, + edge_colors=self.edge_colors, + keypoint_colors=self.keypoint_colors, + ) + return params