diff --git a/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py b/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py index 27a32eb933..446a4af398 100644 --- a/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py +++ b/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py @@ -25,6 +25,7 @@ from super_gradients.common.registry.registry import register_model from super_gradients.common.object_names import Models from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.module_interfaces import HasPredict from super_gradients.training.utils.predict import ImagesPoseEstimationPrediction from super_gradients.training.models.sg_module import SgModule from super_gradients.training.models.arch_params_factory import get_arch_params @@ -294,7 +295,7 @@ def forward(self, x): @register_model(Models.DEKR_CUSTOM) -class DEKRPoseEstimationModel(SgModule): +class DEKRPoseEstimationModel(SgModule, HasPredict): """ Implementation of HRNet model from DEKR paper (https://arxiv.org/abs/2104.02300). diff --git a/src/super_gradients/training/transforms/keypoint_transforms.py b/src/super_gradients/training/transforms/keypoint_transforms.py index 508eab1ab2..20ceab29f5 100644 --- a/src/super_gradients/training/transforms/keypoint_transforms.py +++ b/src/super_gradients/training/transforms/keypoint_transforms.py @@ -95,7 +95,7 @@ def get_equivalent_preprocessing(self) -> List: ] def __repr__(self): - return self.__class__.__name__ + f"(permutation={self.permutation})" + return self.__class__.__name__ + "()" @register_transform(Transforms.KeypointsImageStandardize) diff --git a/tests/unit_tests/test_predict.py b/tests/unit_tests/test_predict.py index 248c63289f..0b8fc3dec9 100644 --- a/tests/unit_tests/test_predict.py +++ b/tests/unit_tests/test_predict.py @@ -23,6 +23,22 @@ def test_classification_models(self): predictions.show() predictions.save(output_folder=tmp_dirname) + def test_pose_estimation_models(self): + model = models.get(Models.DEKR_W32_NO_DC, pretrained_weights="coco_pose") + + with tempfile.TemporaryDirectory() as tmp_dirname: + predictions = model.predict(self.images) + predictions.show() + predictions.save(output_folder=tmp_dirname) + + def test_detection_models(self): + model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco") + + with tempfile.TemporaryDirectory() as tmp_dirname: + predictions = model.predict(self.images) + predictions.show() + predictions.save(output_folder=tmp_dirname) + if __name__ == "__main__": unittest.main()