predict() support for pose estimation models (#1142)
* Adding predict

* Predict

* Predict

* Adding predict

* Adding predict

* Adding joint information to dataset configs

* Added makefile target recipe_accuracy_tests

* Remove temp files

* Rename variables for better clarity

* Move predict() related files to super_gradients.training.utils.predict

* Move predict() related files to super_gradients.training.utils.predict

* Rename file poses.py -> pose_estimation.py

* Rename joint_colors/joint_links -> edge_colors/edge_links

* Disable showing bounding box by default

* Allow passing edge & keypoint colors as None; in this case colors will be generated randomly

* Update docstrings

* Fix test

* Added unit tests to verify that setting preprocessing params from the dataset works

* _pad_image cannot work with a pad_value that is a tuple (r, g, b),
  so the keypoint transform defaults in the configs now use a single scalar value

* Fix pad_value in keypoint transforms to accept a single scalar value, making it compatible with _pad_image
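
For orientation, here is a minimal usage sketch of the feature this commit adds, assuming the pose models expose the same predict()/show() interface as the existing detection models (the model name and pretrained weights below are assumptions based on the recipe names in this commit):

from super_gradients.training import models

# Hypothetical checkpoint choice; substitute a real pose model/weights pair.
model = models.get("dekr_w32_no_dc", pretrained_weights="coco_pose")

# predict() accepts an image path, URL or numpy array; conf mirrors the
# default (0.25) returned by get_dataset_preprocessing_params() below.
predictions = model.predict("path/to/image.jpg", conf=0.25)
predictions.show()  # draws keypoints and edge links with the configured colors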
BloodAxe authored Jun 11, 2023
1 parent 9e77e19 commit a43cfcd
Showing 28 changed files with 1,212 additions and 171 deletions.
8 changes: 8 additions & 0 deletions Makefile
@@ -6,3 +6,11 @@ integration_tests:

yolo_nas_integration_tests:
python -m unittest tests/integration_tests/yolo_nas_integration_test.py

recipe_accuracy_tests:
python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py
5 changes: 4 additions & 1 deletion src/super_gradients/common/object_names.py
@@ -110,6 +110,7 @@ class Transforms:
# Keypoints
KeypointsRandomAffineTransform = "KeypointsRandomAffineTransform"
KeypointsImageNormalize = "KeypointsImageNormalize"
KeypointsImageStandardize = "KeypointsImageStandardize"
KeypointsImageToTensor = "KeypointsImageToTensor"
KeypointTransform = "KeypointTransform"
KeypointsPadIfNeeded = "KeypointsPadIfNeeded"
@@ -413,8 +414,10 @@ class Processings:
DetectionCenterPadding = "DetectionCenterPadding"
DetectionLongestMaxSizeRescale = "DetectionLongestMaxSizeRescale"
DetectionBottomRightPadding = "DetectionBottomRightPadding"
ImagePermute = "ImagePermute"
DetectionRescale = "DetectionRescale"
KeypointsLongestMaxSizeRescale = "KeypointsLongestMaxSizeRescale"
KeypointsBottomRightPadding = "KeypointsBottomRightPadding"
ImagePermute = "ImagePermute"
ReverseImageChannels = "ReverseImageChannels"
NormalizeImage = "NormalizeImage"
ComposeProcessing = "ComposeProcessing"
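
These registry names are used to declare processing pipelines as nested dicts; a sketch mirroring the dataset code later in this diff (the kwargs of the keypoints processings are assumptions modeled on their detection counterparts, and the values are illustrative):

from super_gradients.common.object_names import Processings

pipeline = [
    Processings.ReverseImageChannels,
    {Processings.KeypointsLongestMaxSizeRescale: {"output_shape": (640, 640)}},
    {Processings.KeypointsBottomRightPadding: {"output_shape": (640, 640), "pad_value": 127}},
]
# Same structure as in get_dataset_preprocessing_params() further down:
image_processor = {Processings.ComposeProcessing: {"processings": pipeline}}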
@@ -3,7 +3,7 @@
except (ModuleNotFoundError, ImportError, NameError):
pass # no action or logging - this is normal in most cases

from super_gradients.training.models.prediction_results import ImageDetectionPrediction, ImagesDetectionPrediction
from super_gradients.training.utils.predict import ImageDetectionPrediction, ImagesDetectionPrediction


def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool):
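
The prediction result classes now live under super_gradients.training.utils.predict; downstream code only needs its imports updated:

# Before this commit:
# from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
# After this commit:
from super_gradients.training.utils.predict import ImageDetectionPrediction, ImagesDetectionPrediction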
@@ -1,7 +1,7 @@
num_classes: 17
hidden_channels: 256
num_layers: 2
joint_links:
edge_links:
- [ 0, 1 ]
- [ 0, 2 ]
- [ 1, 2 ]
@@ -3,10 +3,9 @@ num_joints: 17
# OKs sigma values taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L523
oks_sigmas: [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 1.007, 1.007, 0.087, 0.087, 0.089, 0.089]

flip_indexes_heatmap: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 17]
flip_indexes_offset: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15,]
flip_indexes: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15,]

joint_links:
edge_links:
- [0, 1]
- [0, 2]
- [1, 2]
@@ -27,6 +26,47 @@ joint_links:
- [13, 15]
- [14, 16]

edge_colors:
- [214, 39, 40] # Nose -> LeftEye
- [148, 103, 189] # Nose -> RightEye
- [44, 160, 44] # LeftEye -> RightEye
- [140, 86, 75] # LeftEye -> LeftEar
- [227, 119, 194] # RightEye -> RightEar
- [127, 127, 127] # LeftEar -> LeftShoulder
- [188, 189, 34] # RightEar -> RightShoulder
- [127, 127, 127] # Shoulders
- [188, 189, 34] # LeftShoulder -> LeftElbow
- [140, 86, 75] # LeftTorso
- [23, 190, 207] # RightShoulder -> RightElbow
- [227, 119, 194] # RightTorso
- [31, 119, 180] # LeftElbow -> LeftArm
- [255, 127, 14] # RightElbow -> RightArm
- [148, 103, 189] # Waist
- [255, 127, 14] # Left Hip -> Left Knee
- [214, 39, 40] # Right Hip -> Right Knee
- [31, 119, 180] # Left Knee -> Left Ankle
- [44, 160, 44] # Right Knee -> Right Ankle


keypoint_colors:
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]


train_dataset_params:
data_dir: /data/coco # root path to coco data
@@ -36,6 +76,10 @@ train_dataset_params:
include_empty_samples: False
min_instance_area: 64

edge_links: ${dataset_params.edge_links}
edge_colors: ${dataset_params.edge_colors}
keypoint_colors: ${dataset_params.keypoint_colors}

transforms:
- KeypointsLongestMaxSize:
max_height: 640
@@ -44,29 +88,31 @@ train_dataset_params:
- KeypointsPadIfNeeded:
min_height: 640
min_width: 640
image_pad_value: [ 127, 127, 127 ]
image_pad_value: 127
mask_pad_value: 1

- KeypointsRandomHorizontalFlip:
# Note these indexes are COCO-specific. If you're using a different dataset, you'll need to change these accordingly.
flip_index: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15 ]
flip_index: ${dataset_params.flip_indexes}
prob: 0.5

- KeypointsRandomAffineTransform:
max_rotation: 30
min_scale: 0.5
max_scale: 2
max_translate: 0.2
image_pad_value: [ 127, 127, 127 ]
image_pad_value: 127
mask_pad_value: 1
prob: 0.75

- KeypointsImageToTensor
- KeypointsImageStandardize:
max_value: 255

- KeypointsImageNormalize:
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]

- KeypointsImageToTensor

val_dataset_params:
data_dir: /data/coco/
@@ -75,6 +121,11 @@ val_dataset_params:
json_file: annotations/person_keypoints_val2017.json
include_empty_samples: True
min_instance_area: 128

edge_links: ${dataset_params.edge_links}
edge_colors: ${dataset_params.edge_colors}
keypoint_colors: ${dataset_params.keypoint_colors}

transforms:
- KeypointsLongestMaxSize:
max_height: 640
@@ -83,15 +134,17 @@ val_dataset_params:
- KeypointsPadIfNeeded:
min_height: 640
min_width: 640
image_pad_value: [ 127, 127, 127 ]
image_pad_value: 127
mask_pad_value: 1

- KeypointsImageToTensor
- KeypointsImageStandardize:
max_value: 255

- KeypointsImageNormalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]

- KeypointsImageToTensor

train_dataloader_params:
shuffle: True
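
For reference, the reworked tail of the pipeline above (standardize by max_value, normalize, then convert to tensor) is numerically equivalent to this self-contained sketch; it is conceptual code, not SG's transform implementation:

import numpy as np
import torch

def preprocess(image: np.ndarray) -> torch.Tensor:
    # KeypointsImageStandardize: scale uint8 pixels into [0, 1]
    x = image.astype(np.float32) / 255.0
    # KeypointsImageNormalize: per-channel mean/std (ImageNet stats from the config)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    x = (x - mean) / std
    # KeypointsImageToTensor: HWC -> CHW torch tensor
    return torch.from_numpy(x.transpose(2, 0, 1))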
@@ -3,7 +3,7 @@ num_joints: 17
# OKs sigma values taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L523
oks_sigmas: [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 1.007, 1.007, 0.087, 0.087, 0.089, 0.089]

joint_links:
edge_links:
- [0, 1]
- [0, 2]
- [1, 2]
@@ -115,7 +115,7 @@ def main(cfg: DictConfig) -> None:
)

# model = DEKRWrapper(model, apply_sigmoid=True).cuda().eval()
model = DEKRHorisontalFlipWrapper(model, cfg.dataset_params.flip_indexes_heatmap, cfg.dataset_params.flip_indexes_offset, apply_sigmoid=True).cuda().eval()
model = DEKRHorisontalFlipWrapper(model, cfg.dataset_params.flip_indexes, apply_sigmoid=True).cuda().eval()

post_prediction_callback = cfg.post_prediction_callback

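
For intuition on the unified flip_indexes: after a horizontal flip, left and right keypoints swap identities, so flip-TTA must un-flip the prediction spatially and permute its joint channels before averaging. A conceptual sketch of that step (not the actual DEKRHorisontalFlipWrapper code):

import torch

FLIP_INDEXES = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

def average_flip_tta(heatmaps: torch.Tensor, heatmaps_from_flipped: torch.Tensor) -> torch.Tensor:
    # Undo the spatial flip along width, then remap joint channels so that
    # e.g. the "left eye" channel again holds left-eye predictions.
    restored = torch.flip(heatmaps_from_flipped, dims=[-1])[:, FLIP_INDEXES, :, :]
    return 0.5 * (heatmaps + restored)

# Sanity check: a perfect prediction on the flipped image averages back to the original.
hm = torch.rand(1, 17, 4, 4)
hm_flipped = torch.flip(hm, dims=[-1])[:, FLIP_INDEXES, :, :]
assert torch.allclose(average_flip_tta(hm, hm_flipped), hm)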
@@ -1,14 +1,16 @@
import abc
from typing import Tuple, List, Mapping, Any, Dict
from typing import Tuple, List, Mapping, Any, Dict, Union

import numpy as np
import torch
from torch.utils.data.dataloader import default_collate, Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.object_names import Processings
from super_gradients.common.registry.registry import register_collate_function
from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator
from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointTransform
from super_gradients.training.utils.visualization.utils import generate_color_mapping

logger = get_logger(__name__)

@@ -24,18 +26,30 @@ def __init__(
target_generator: KeypointsTargetsGenerator,
transforms: List[KeypointTransform],
min_instance_area: float,
num_joints: int,
edge_links: Union[List[Tuple[int, int]], np.ndarray],
edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
):
"""
:param target_generator: Target generator that will be used to generate the targets for the model.
See DEKRTargetsGenerator for an example.
:param transforms: Transforms to be applied to the image & keypoints
:param min_instance_area: Minimum area of an instance to be included in the dataset
:param num_joints: Number of joints to be predicted
:param edge_links: Edge links between joints
:param edge_colors: Color of the edge links. If None, the color will be generated randomly.
:param keypoint_colors: Color of the keypoints. If None, the color will be generated randomly.
"""
super().__init__()
self.target_generator = target_generator
self.transforms = KeypointsCompose(transforms)
self.min_instance_area = min_instance_area
self.num_joints = num_joints
self.edge_links = edge_links
self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)

@abc.abstractmethod
def __len__(self) -> int:
@@ -95,6 +109,21 @@ def filter_joints(self, joints: np.ndarray, image: np.ndarray) -> np.ndarray:

return joints

def get_dataset_preprocessing_params(self):
"""
:return: Dictionary with `conf`, `image_processor`, `edge_links`, `edge_colors` and `keypoint_colors` entries for configuring the prediction pipeline.
"""
pipeline = self.transforms.get_equivalent_preprocessing()
params = dict(
conf=0.25,
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
edge_links=self.edge_links,
edge_colors=self.edge_colors,
keypoint_colors=self.keypoint_colors,
)
return params


@register_collate_function()
class KeypointsCollate:
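
When edge_colors or keypoint_colors is None, the constructor above falls back to generate_color_mapping to produce distinct colors; a minimal sketch of that pattern, assuming the import path used in this diff:

from super_gradients.training.utils.visualization.utils import generate_color_mapping

edge_links = [(0, 1), (0, 2), (1, 2)]
edge_colors = None

# Same fallback as in BaseKeypointsDataset.__init__ above:
edge_colors = edge_colors or generate_color_mapping(len(edge_links))
assert len(edge_colors) == len(edge_links)  # one (R, G, B) per edge link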
@@ -1,5 +1,5 @@
import os
from typing import Tuple, List, Mapping, Any
from typing import Tuple, List, Mapping, Any, Union

import cv2
import numpy as np
@@ -8,7 +8,7 @@
from torch import Tensor

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.object_names import Datasets
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.target_generator_factory import TargetGeneratorsFactory
@@ -37,6 +37,9 @@ def __init__(
target_generator,
transforms: List[KeypointTransform],
min_instance_area: float,
edge_links: Union[List[Tuple[int, int]], np.ndarray],
edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
):
"""
@@ -49,20 +52,32 @@ def __init__(
See DEKRTargetsGenerator for an example.
:param transforms: Transforms to be applied to the image & keypoints
:param min_instance_area: Minimum area of an instance to be included in the dataset
:param edge_links: Edge links between joints
:param edge_colors: Color of the edge links. If None, the color will be generated randomly.
:param keypoint_colors: Color of the keypoints. If None, the color will be generated randomly.
"""
super().__init__(transforms=transforms, target_generator=target_generator, min_instance_area=min_instance_area)
self.root = data_dir
self.images_dir = os.path.join(data_dir, images_dir)
self.json_file = os.path.join(data_dir, json_file)

coco = COCO(self.json_file)
json_file = os.path.join(data_dir, json_file)
coco = COCO(json_file)
if len(coco.dataset["categories"]) != 1:
raise ValueError("Dataset must contain exactly one category")

joints = coco.dataset["categories"][0]["keypoints"]
num_joints = len(joints)

super().__init__(
transforms=transforms,
target_generator=target_generator,
min_instance_area=min_instance_area,
num_joints=num_joints,
edge_links=edge_links,
edge_colors=edge_colors,
keypoint_colors=keypoint_colors,
)
self.root = data_dir
self.images_dir = os.path.join(data_dir, images_dir)
self.coco = coco
self.ids = list(self.coco.imgs.keys())
self.joints = coco.dataset["categories"][0]["keypoints"]
self.num_joints = len(self.joints)
self.joints = joints

if not include_empty_samples:
subset = [img_id for img_id in self.ids if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0]
@@ -190,3 +205,21 @@ def get_mask(self, anno, img_info) -> np.ndarray:
m += mask

return (m < 0.5).astype(np.float32)

def get_dataset_preprocessing_params(self):
"""
:return: Dictionary with `conf`, `image_processor` (with `ReverseImageChannels` prepended), `edge_links`, `edge_colors` and `keypoint_colors` entries for configuring the prediction pipeline.
"""
# Since we are using cv2.imread to read images, our model is in fact trained on BGR images.
# Our pipelines assume that input images are RGB, so we need to reverse the channels to get BGR
# to match the expected input of the model.
pipeline = [Processings.ReverseImageChannels] + self.transforms.get_equivalent_preprocessing()
params = dict(
conf=0.25,
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
edge_links=self.edge_links,
edge_colors=self.edge_colors,
keypoint_colors=self.keypoint_colors,
)
return params
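
The ReverseImageChannels step is prepended because cv2.imread yields BGR while the prediction pipeline feeds RGB images; channel reversal is its own inverse, as this self-contained check illustrates:

import numpy as np

def reverse_channels(image: np.ndarray) -> np.ndarray:
    # RGB -> BGR (or back): reversing the last axis is an involution.
    return np.ascontiguousarray(image[..., ::-1])

rgb = np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8)
bgr = reverse_channels(rgb)
assert np.array_equal(reverse_channels(bgr), rgb)
assert np.array_equal(bgr[..., 0], rgb[..., 2])  # blue channel is now first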
@@ -19,7 +19,7 @@
from super_gradients.training.utils.utils import HpmStruct, arch_params_deprecated
from super_gradients.training.models.sg_module import SgModule
import super_gradients.common.factories.detection_modules_factory as det_factory
from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
from super_gradients.training.utils.predict import ImagesDetectionPrediction
from super_gradients.training.pipelines.pipelines import DetectionPipeline
from super_gradients.training.processing.processing import Processing
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
@@ -16,7 +16,7 @@
from super_gradients.training.utils import HpmStruct
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.detection_models.pp_yolo_e.post_prediction_callback import PPYoloEPostPredictionCallback, DetectionPostPredictionCallback
from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
from super_gradients.training.utils.predict import ImagesDetectionPrediction
from super_gradients.training.pipelines.pipelines import DetectionPipeline
from super_gradients.training.processing.processing import Processing
from super_gradients.training.utils.media.image import ImageSource
@@ -14,7 +14,7 @@
from super_gradients.training.utils import torch_version_is_greater_or_equal
from super_gradients.training.utils.detection_utils import non_max_suppression, matrix_non_max_suppression, NMS_Type, DetectionPostPredictionCallback, Anchors
from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param
from super_gradients.training.models.prediction_results import ImagesDetectionPrediction
from super_gradients.training.utils.predict import ImagesDetectionPrediction
from super_gradients.training.pipelines.pipelines import DetectionPipeline
from super_gradients.training.processing.processing import Processing
from super_gradients.training.utils.media.image import ImageSource
