diff --git a/src/super_gradients/training/metrics/pose_estimation_metrics.py b/src/super_gradients/training/metrics/pose_estimation_metrics.py index 1868e99128..66860e213f 100644 --- a/src/super_gradients/training/metrics/pose_estimation_metrics.py +++ b/src/super_gradients/training/metrics/pose_estimation_metrics.py @@ -103,7 +103,14 @@ def __init__( self.greater_component_is_better = dict((k, True) for k in self.stats_names) if oks_sigmas is None: - oks_sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0 + if num_joints == 17: + oks_sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0 + else: + oks_sigmas = np.array([0.1] * num_joints) + logger.warning( + f"Using default OKS sigmas of `0.1` for a custom dataset with {num_joints} joints. " + f"To silence this warning, you may want to specify OKS sigmas explicitly as it has direct impact on the AP score." + ) if len(oks_sigmas) != num_joints: raise ValueError(f"Length of oks_sigmas ({len(oks_sigmas)}) should be equal to num_joints {num_joints}") diff --git a/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py b/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py index edcb4b9f5c..c627f121e9 100644 --- a/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py +++ b/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py @@ -328,22 +328,32 @@ def __init__(self, arch_params): setattr(self, "stage{}".format(i + 2), stage) # build head net - inp_channels = int(sum(self.stages_spec.NUM_CHANNELS[-1])) - config_heatmap = self.spec.HEAD_HEATMAP - config_offset = self.spec.HEAD_OFFSET + self.head_inp_channels = int(sum(self.stages_spec.NUM_CHANNELS[-1])) + self.config_heatmap = self.spec.HEAD_HEATMAP + self.config_offset = self.spec.HEAD_OFFSET self.num_joints = arch_params.num_classes self.num_offset = self.num_joints * 2 self.num_joints_with_center = self.num_joints + 1 - self.offset_prekpt = config_offset["NUM_CHANNELS_PERKPT"] + self.offset_prekpt = self.config_offset["NUM_CHANNELS_PERKPT"] offset_channels = self.num_joints * self.offset_prekpt - self.transition_heatmap = self._make_transition_for_head(inp_channels, config_heatmap["NUM_CHANNELS"]) - self.transition_offset = self._make_transition_for_head(inp_channels, offset_channels) - self.head_heatmap = self._make_heatmap_head(config_heatmap) - self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(config_offset) - self.heatmap_activation = nn.Sigmoid() if config_heatmap["HEATMAP_APPLY_SIGMOID"] else nn.Identity() + self.transition_heatmap = self._make_transition_for_head(self.head_inp_channels, self.config_heatmap["NUM_CHANNELS"]) + self.transition_offset = self._make_transition_for_head(self.head_inp_channels, offset_channels) + self.head_heatmap = self._make_heatmap_head(self.config_heatmap) + self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(self.config_offset) + self.heatmap_activation = nn.Sigmoid() if self.config_heatmap["HEATMAP_APPLY_SIGMOID"] else nn.Identity() self.init_weights() + def replace_head(self, new_num_classes: int): + self.num_joints = new_num_classes + self.num_offset = new_num_classes * 2 + self.num_joints_with_center = new_num_classes + 1 + + offset_channels = self.num_joints * self.offset_prekpt + self.head_heatmap = self._make_heatmap_head(self.config_heatmap) + self.transition_offset = self._make_transition_for_head(self.head_inp_channels, offset_channels) + self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(self.config_offset) + def _make_transition_for_head(self, inplanes: int, outplanes: int) -> nn.Module: transition_layer = [nn.Conv2d(inplanes, outplanes, 1, 1, 0, bias=False), nn.BatchNorm2d(outplanes), nn.ReLU(True)] return nn.Sequential(*transition_layer) diff --git a/src/super_gradients/training/transforms/keypoint_transforms.py b/src/super_gradients/training/transforms/keypoint_transforms.py index b805d6315f..508eab1ab2 100644 --- a/src/super_gradients/training/transforms/keypoint_transforms.py +++ b/src/super_gradients/training/transforms/keypoint_transforms.py @@ -71,6 +71,13 @@ def get_equivalent_preprocessing(self) -> List: preprocessing += t.get_equivalent_preprocessing() return preprocessing + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += f"\t{repr(t)}" + format_string += "\n)" + return format_string + @register_transform(Transforms.KeypointsImageToTensor) class KeypointsImageToTensor(KeypointTransform): @@ -87,6 +94,9 @@ def get_equivalent_preprocessing(self) -> List: {Processings.ImagePermute: {"permutation": (2, 0, 1)}}, ] + def __repr__(self): + return self.__class__.__name__ + f"(permutation={self.permutation})" + @register_transform(Transforms.KeypointsImageStandardize) class KeypointsImageStandardize(KeypointTransform): @@ -107,6 +117,9 @@ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, area def get_equivalent_preprocessing(self) -> List[Dict]: return [{Processings.StandardizeImage: {"max_value": self.max_value}}] + def __repr__(self): + return self.__class__.__name__ + f"(max_value={self.max_value})" + @register_transform(Transforms.KeypointsImageNormalize) class KeypointsImageNormalize(KeypointTransform): @@ -122,6 +135,9 @@ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, area image = (image - self.mean) / self.std return image, mask, joints, areas, bboxes + def __repr__(self): + return self.__class__.__name__ + f"(mean={self.mean}, std={self.std})" + def get_equivalent_preprocessing(self) -> List: return [{Processings.NormalizeImage: {"mean": self.mean, "std": self.std}}] @@ -143,6 +159,9 @@ def __init__(self, flip_index: List[int], prob: float = 0.5): self.flip_index = flip_index self.prob = prob + def __repr__(self): + return self.__class__.__name__ + f"(flip_index={self.flip_index}, prob={self.prob})" + def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Optional[np.ndarray]): if image.shape[:2] != mask.shape[:2]: raise RuntimeError(f"Image shape ({image.shape[:2]}) does not match mask shape ({mask.shape[:2]}).") @@ -218,6 +237,9 @@ def apply_to_bboxes(self, bboxes, rows): def get_equivalent_preprocessing(self) -> List: raise RuntimeError("KeypointsRandomHorizontalFlip does not have equivalent preprocessing.") + def __repr__(self): + return self.__class__.__name__ + f"(prob={self.prob})" + @register_transform(Transforms.KeypointsLongestMaxSize) class KeypointsLongestMaxSize(KeypointTransform): @@ -278,6 +300,13 @@ def apply_to_keypoints(cls, keypoints, scale): def apply_to_bboxes(cls, bboxes, scale): return bboxes * scale + def __repr__(self): + return ( + self.__class__.__name__ + f"(max_height={self.max_height}, " + f"max_width={self.max_width}, " + f"interpolation={self.interpolation}, prob={self.prob})" + ) + def get_equivalent_preprocessing(self) -> List: return [{Processings.KeypointsLongestMaxSizeRescale: {"output_shape": (self.max_height, self.max_width)}}] @@ -318,6 +347,14 @@ def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Opt return image, mask, joints, areas, bboxes + def __repr__(self): + return ( + self.__class__.__name__ + f"(min_height={self.min_height}, " + f"min_width={self.min_width}, " + f"image_pad_value={self.image_pad_value}, " + f"mask_pad_value={self.mask_pad_value})" + ) + def get_equivalent_preprocessing(self) -> List: return [{Processings.KeypointsBottomRightPadding: {"output_shape": (self.min_height, self.min_width), "pad_value": self.image_pad_value}}] @@ -353,6 +390,17 @@ def __init__( self.mask_pad_value = mask_pad_value self.prob = prob + def __repr__(self): + return ( + self.__class__.__name__ + f"(max_rotation={self.max_rotation}, " + f"min_scale={self.min_scale}, " + f"max_scale={self.max_scale}, " + f"max_translate={self.max_translate}, " + f"image_pad_value={self.image_pad_value}, " + f"mask_pad_value={self.mask_pad_value}, " + f"prob={self.prob})" + ) + def _get_affine_matrix(self, img, angle, scale, dx, dy): """ diff --git a/tests/unit_tests/replace_head_test.py b/tests/unit_tests/replace_head_test.py index 4e32cd6136..753dfaa860 100644 --- a/tests/unit_tests/replace_head_test.py +++ b/tests/unit_tests/replace_head_test.py @@ -30,6 +30,13 @@ def test_yolo_nas_replace_head(self): (_, pred_scores), _ = model.forward(input) self.assertEqual(pred_scores.size(2), 100) + def test_dekr_replace_head(self): + input = torch.randn(1, 3, 640, 640).to(self.device) + model = models.get(Models.DEKR_W32_NO_DC, num_classes=20, pretrained_weights="coco_pose").to(self.device).eval() + heatmap, offsets = model.forward(input) + self.assertEqual(heatmap.size(1), 20 + 1) + self.assertEqual(offsets.size(1), 20 * 2) + def tearDown(self) -> None: if os.path.exists("~/.cache/torch/hub/"): shutil.rmtree("~/.cache/torch/hub/")