Skip to content

Commit

Permalink
Fix DEKR's replace_head & improve __repr__ for keypoints transforms (#…
Browse files Browse the repository at this point in the history
…1227)

* YOLO-NAS Pose Estimation Experiment

* Added logging

* Remove test

* Tune recipe for L

* Tune recipe for S

* Tune optimizer params

* Lower LR

* Lower LR

* Added __repr__ for keypoint transforms to improve their printing in notebooks

* If OKS sigmas are not given explicitly, initialize with default 17 keypoints only if num_joints == 17, otherwise use default values and emit a warning

* Added recipe to train rescoring for yolo_nas_pose_l

* Added YOLO-NAS-POSE scores

* Added YOLO-NAS-POSE-M recipe

* Fine-tuning notebook for pose est

* Added lings to pretrained models

* Remove links to S & M models

* Update notebook

* Fix apply_sigmoid=True

* Increased eps to prevent divizion by zero

* Update notebook

* Adding predict

* Predict

* Predict

* Adding predict

* Adding predict

* Adding joint information to dataset configs

* Added makefile target recipe_accuracy_tests

* Remove temp files

* Rename variables for better clarity

* Move predict() related files to super_gradients.training.utils.predict

* Move predict() related files to super_gradients.training.utils.predict

* Rename file poses.py -> pose_estimation.py

* Rename joint_colors/joint_links -> edge_colors/edge_links

* Disable showing bounding box by default

* Allow passing edge & keypoints as None, in this case colors will be generated randomly

* Update docstrings

* Fix test

* Added unit tests to verify settings preprocesisng params from dataset works

* Added predict

* Added default prerprocessing settings for yolo-nas-pose

* Added default prerprocessing settings for yolo-nas-pose

* Added __repr__ to KeypointsImageToTensor

* _pad_image cannot work with pad_value that is tuple (r,g,b).
So we change the keypoint transforms defaults in config to use single scalar value

* _pad_image cannot work with pad_value that is tuple (r,g,b).
So we change the keypoint transforms defaults in config to use single scalar value

* _pad_image cannot work with pad_value that is tuple (r,g,b).
So we change the keypoint transforms defaults in config to use single scalar value

* Fix pad_value in keypoints transforms to accept single scalar value to make compatible with _pad_image

* Update signature of base YoloNasPose class (dropped arch_params)

* Simplify recipes to train YOLO-NAS-POSE

* Implement replace_head for dekr

* Implement replace_head for dekr

* Make more beautiful __repr__ implementation

* Change .format to string interpolation
  • Loading branch information
BloodAxe authored Jun 28, 2023
1 parent 41d455f commit ab2e792
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,14 @@ def __init__(
self.greater_component_is_better = dict((k, True) for k in self.stats_names)

if oks_sigmas is None:
oks_sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0
if num_joints == 17:
oks_sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0
else:
oks_sigmas = np.array([0.1] * num_joints)
logger.warning(
f"Using default OKS sigmas of `0.1` for a custom dataset with {num_joints} joints. "
f"To silence this warning, you may want to specify OKS sigmas explicitly as it has direct impact on the AP score."
)

if len(oks_sigmas) != num_joints:
raise ValueError(f"Length of oks_sigmas ({len(oks_sigmas)}) should be equal to num_joints {num_joints}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,22 +328,32 @@ def __init__(self, arch_params):
setattr(self, "stage{}".format(i + 2), stage)

# build head net
inp_channels = int(sum(self.stages_spec.NUM_CHANNELS[-1]))
config_heatmap = self.spec.HEAD_HEATMAP
config_offset = self.spec.HEAD_OFFSET
self.head_inp_channels = int(sum(self.stages_spec.NUM_CHANNELS[-1]))
self.config_heatmap = self.spec.HEAD_HEATMAP
self.config_offset = self.spec.HEAD_OFFSET
self.num_joints = arch_params.num_classes
self.num_offset = self.num_joints * 2
self.num_joints_with_center = self.num_joints + 1
self.offset_prekpt = config_offset["NUM_CHANNELS_PERKPT"]
self.offset_prekpt = self.config_offset["NUM_CHANNELS_PERKPT"]

offset_channels = self.num_joints * self.offset_prekpt
self.transition_heatmap = self._make_transition_for_head(inp_channels, config_heatmap["NUM_CHANNELS"])
self.transition_offset = self._make_transition_for_head(inp_channels, offset_channels)
self.head_heatmap = self._make_heatmap_head(config_heatmap)
self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(config_offset)
self.heatmap_activation = nn.Sigmoid() if config_heatmap["HEATMAP_APPLY_SIGMOID"] else nn.Identity()
self.transition_heatmap = self._make_transition_for_head(self.head_inp_channels, self.config_heatmap["NUM_CHANNELS"])
self.transition_offset = self._make_transition_for_head(self.head_inp_channels, offset_channels)
self.head_heatmap = self._make_heatmap_head(self.config_heatmap)
self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(self.config_offset)
self.heatmap_activation = nn.Sigmoid() if self.config_heatmap["HEATMAP_APPLY_SIGMOID"] else nn.Identity()
self.init_weights()

def replace_head(self, new_num_classes: int):
self.num_joints = new_num_classes
self.num_offset = new_num_classes * 2
self.num_joints_with_center = new_num_classes + 1

offset_channels = self.num_joints * self.offset_prekpt
self.head_heatmap = self._make_heatmap_head(self.config_heatmap)
self.transition_offset = self._make_transition_for_head(self.head_inp_channels, offset_channels)
self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(self.config_offset)

def _make_transition_for_head(self, inplanes: int, outplanes: int) -> nn.Module:
transition_layer = [nn.Conv2d(inplanes, outplanes, 1, 1, 0, bias=False), nn.BatchNorm2d(outplanes), nn.ReLU(True)]
return nn.Sequential(*transition_layer)
Expand Down
48 changes: 48 additions & 0 deletions src/super_gradients/training/transforms/keypoint_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def get_equivalent_preprocessing(self) -> List:
preprocessing += t.get_equivalent_preprocessing()
return preprocessing

def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += f"\t{repr(t)}"
format_string += "\n)"
return format_string


@register_transform(Transforms.KeypointsImageToTensor)
class KeypointsImageToTensor(KeypointTransform):
Expand All @@ -87,6 +94,9 @@ def get_equivalent_preprocessing(self) -> List:
{Processings.ImagePermute: {"permutation": (2, 0, 1)}},
]

def __repr__(self):
return self.__class__.__name__ + f"(permutation={self.permutation})"


@register_transform(Transforms.KeypointsImageStandardize)
class KeypointsImageStandardize(KeypointTransform):
Expand All @@ -107,6 +117,9 @@ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, area
def get_equivalent_preprocessing(self) -> List[Dict]:
return [{Processings.StandardizeImage: {"max_value": self.max_value}}]

def __repr__(self):
return self.__class__.__name__ + f"(max_value={self.max_value})"


@register_transform(Transforms.KeypointsImageNormalize)
class KeypointsImageNormalize(KeypointTransform):
Expand All @@ -122,6 +135,9 @@ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, area
image = (image - self.mean) / self.std
return image, mask, joints, areas, bboxes

def __repr__(self):
return self.__class__.__name__ + f"(mean={self.mean}, std={self.std})"

def get_equivalent_preprocessing(self) -> List:
return [{Processings.NormalizeImage: {"mean": self.mean, "std": self.std}}]

Expand All @@ -143,6 +159,9 @@ def __init__(self, flip_index: List[int], prob: float = 0.5):
self.flip_index = flip_index
self.prob = prob

def __repr__(self):
return self.__class__.__name__ + f"(flip_index={self.flip_index}, prob={self.prob})"

def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Optional[np.ndarray]):
if image.shape[:2] != mask.shape[:2]:
raise RuntimeError(f"Image shape ({image.shape[:2]}) does not match mask shape ({mask.shape[:2]}).")
Expand Down Expand Up @@ -218,6 +237,9 @@ def apply_to_bboxes(self, bboxes, rows):
def get_equivalent_preprocessing(self) -> List:
raise RuntimeError("KeypointsRandomHorizontalFlip does not have equivalent preprocessing.")

def __repr__(self):
return self.__class__.__name__ + f"(prob={self.prob})"


@register_transform(Transforms.KeypointsLongestMaxSize)
class KeypointsLongestMaxSize(KeypointTransform):
Expand Down Expand Up @@ -278,6 +300,13 @@ def apply_to_keypoints(cls, keypoints, scale):
def apply_to_bboxes(cls, bboxes, scale):
return bboxes * scale

def __repr__(self):
return (
self.__class__.__name__ + f"(max_height={self.max_height}, "
f"max_width={self.max_width}, "
f"interpolation={self.interpolation}, prob={self.prob})"
)

def get_equivalent_preprocessing(self) -> List:
return [{Processings.KeypointsLongestMaxSizeRescale: {"output_shape": (self.max_height, self.max_width)}}]

Expand Down Expand Up @@ -318,6 +347,14 @@ def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Opt

return image, mask, joints, areas, bboxes

def __repr__(self):
return (
self.__class__.__name__ + f"(min_height={self.min_height}, "
f"min_width={self.min_width}, "
f"image_pad_value={self.image_pad_value}, "
f"mask_pad_value={self.mask_pad_value})"
)

def get_equivalent_preprocessing(self) -> List:
return [{Processings.KeypointsBottomRightPadding: {"output_shape": (self.min_height, self.min_width), "pad_value": self.image_pad_value}}]

Expand Down Expand Up @@ -353,6 +390,17 @@ def __init__(
self.mask_pad_value = mask_pad_value
self.prob = prob

def __repr__(self):
return (
self.__class__.__name__ + f"(max_rotation={self.max_rotation}, "
f"min_scale={self.min_scale}, "
f"max_scale={self.max_scale}, "
f"max_translate={self.max_translate}, "
f"image_pad_value={self.image_pad_value}, "
f"mask_pad_value={self.mask_pad_value}, "
f"prob={self.prob})"
)

def _get_affine_matrix(self, img, angle, scale, dx, dy):
"""
Expand Down
7 changes: 7 additions & 0 deletions tests/unit_tests/replace_head_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def test_yolo_nas_replace_head(self):
(_, pred_scores), _ = model.forward(input)
self.assertEqual(pred_scores.size(2), 100)

def test_dekr_replace_head(self):
input = torch.randn(1, 3, 640, 640).to(self.device)
model = models.get(Models.DEKR_W32_NO_DC, num_classes=20, pretrained_weights="coco_pose").to(self.device).eval()
heatmap, offsets = model.forward(input)
self.assertEqual(heatmap.size(1), 20 + 1)
self.assertEqual(offsets.size(1), 20 * 2)

def tearDown(self) -> None:
if os.path.exists("~/.cache/torch/hub/"):
shutil.rmtree("~/.cache/torch/hub/")
Expand Down

0 comments on commit ab2e792

Please sign in to comment.