From 82c03df75ba7cf0dfed8208a255571420efd62c6 Mon Sep 17 00:00:00 2001 From: Louis-Dupont <35190946+Louis-Dupont@users.noreply.github.com> Date: Mon, 3 Apr 2023 12:45:25 +0300 Subject: [PATCH] Feature/sg 747 add preprocessing (#804) * wip * move to imageprocessors * wip * add back changes * making it work fully for yolox and almost for ppyoloe * minor change * working for det * cleaning * clean * undo * replace empty with none * add _get_shift_params * minor doc change * replace pydantic with dataclasses and fix typing * add docstrings * doc improvment and use get_shift_params in transforms * add tests * improve comment * rename * add option to keep ratio in rescale * make functions private * remove DetectionPaddedRescale * fix doc * add fixes * improve _get_center_padding_params output * minor fix * add empty bbox test for rescale_bboxes * finalizing _DetectionPadding, DetectionCenterPadding and DetectionBottomRightPadding * remove _pad_to_side * split rescale into 2 classes * minor addition * Add DetectionPrediction object * simplify DetectionPrediction class * add round and don't rescale if no change required --------- Co-authored-by: Eugene Khvedchenya --- .../arch_params/yolox_s_arch_params.yaml | 2 +- .../tiny_imagenet_dataset_params.yaml | 2 +- .../training/models/predictions.py | 41 ++++ .../training/transforms/processing.py | 204 ++++++++++++++++++ .../training/transforms/transforms.py | 167 ++++---------- .../training/transforms/utils.py | 144 +++++++++++++ .../training/utils/detection_utils.py | 12 +- tests/unit_tests/transforms_test.py | 172 +++++++++++++++ 8 files changed, 607 insertions(+), 137 deletions(-) create mode 100644 src/super_gradients/training/models/predictions.py create mode 100644 src/super_gradients/training/transforms/processing.py create mode 100644 src/super_gradients/training/transforms/utils.py diff --git a/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml b/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml index eaebbcabed..d2bde90300 100644 --- a/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml +++ b/src/super_gradients/recipes/arch_params/yolox_s_arch_params.yaml @@ -9,4 +9,4 @@ anchors: yolo_type: 'yoloX' depth_mult_factor: 0.33 -width_mult_factor: 0.5 \ No newline at end of file +width_mult_factor: 0.5 diff --git a/src/super_gradients/recipes/dataset_params/tiny_imagenet_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/tiny_imagenet_dataset_params.yaml index 68b54cca29..4c7d6120e8 100644 --- a/src/super_gradients/recipes/dataset_params/tiny_imagenet_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/tiny_imagenet_dataset_params.yaml @@ -24,4 +24,4 @@ val_dataset_params: mean: [0.4802, 0.4481, 0.3975] std: [0.2770, 0.2691, 0.2821] -_convert_: all \ No newline at end of file +_convert_: all diff --git a/src/super_gradients/training/models/predictions.py b/src/super_gradients/training/models/predictions.py new file mode 100644 index 0000000000..e493ab0a9d --- /dev/null +++ b/src/super_gradients/training/models/predictions.py @@ -0,0 +1,41 @@ +from typing import Tuple +from abc import ABC +from dataclasses import dataclass + +import numpy as np + +from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory +from super_gradients.training.datasets.data_formats.bbox_formats import convert_bboxes + + +@dataclass +class Prediction(ABC): + pass + + +@dataclass +class DetectionPrediction(Prediction): + """Represents a detection prediction, with bboxes represented in xyxy format.""" + + bboxes_xyxy: np.ndarray + confidence: np.ndarray + labels: np.ndarray + + def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]): + """ + :param bboxes: BBoxes in the format specified by bbox_format + :param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...) + :param confidence: Confidence scores for each bounding box + :param labels: Labels for each bounding box. + :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format + """ + factory = BBoxFormatFactory() + self.bboxes_xyxy = convert_bboxes( + bboxes=bboxes, + image_shape=image_shape, + source_format=factory.get(bbox_format), + target_format=factory.get("xyxy"), + inplace=False, + ) + self.confidence = confidence + self.labels = labels diff --git a/src/super_gradients/training/transforms/processing.py b/src/super_gradients/training/transforms/processing.py new file mode 100644 index 0000000000..a74cb700a7 --- /dev/null +++ b/src/super_gradients/training/transforms/processing.py @@ -0,0 +1,204 @@ +from typing import Tuple, List, Union +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import numpy as np + +from super_gradients.training.models.predictions import Prediction, DetectionPrediction +from super_gradients.training.transforms.utils import ( + _rescale_image, + _rescale_bboxes, + _get_center_padding_coordinates, + _get_bottom_right_padding_coordinates, + _pad_image, + _shift_bboxes, + PaddingCoordinates, +) + + +@dataclass +class ProcessingMetadata(ABC): + """Metadata including information to postprocess a prediction.""" + + +@dataclass +class ComposeProcessingMetadata(ProcessingMetadata): + metadata_lst: List[Union[None, ProcessingMetadata]] + + +@dataclass +class DetectionPadToSizeMetadata(ProcessingMetadata): + padding_coordinates: PaddingCoordinates + + +@dataclass +class RescaleMetadata(ProcessingMetadata): + original_shape: Tuple[int, int] + scale_factor_h: float + scale_factor_w: float + + +class Processing(ABC): + """Interface for preprocessing and postprocessing methods that are + used to prepare images for a model and process the model's output. + + Subclasses should implement the `preprocess_image` and `postprocess_predictions` + methods according to the specific requirements of the model and task. + """ + + @abstractmethod + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, Union[None, ProcessingMetadata]]: + """Processing an image, before feeding it to the network. Expected to be in (H, W, C) or (H, W).""" + pass + + @abstractmethod + def postprocess_predictions(self, predictions: Prediction, metadata: Union[None, ProcessingMetadata]) -> Prediction: + """Postprocess the model output predictions.""" + pass + + +class ComposeProcessing(Processing): + """Compose a list of Processing objects into a single Processing object.""" + + def __init__(self, processings: List[Processing]): + self.processings = processings + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, ComposeProcessingMetadata]: + """Processing an image, before feeding it to the network.""" + processed_image, metadata_lst = image.copy(), [] + for processing in self.processings: + processed_image, metadata = processing.preprocess_image(image=processed_image) + metadata_lst.append(metadata) + return processed_image, ComposeProcessingMetadata(metadata_lst=metadata_lst) + + def postprocess_predictions(self, predictions: Prediction, metadata: ComposeProcessingMetadata) -> Prediction: + """Postprocess the model output predictions.""" + postprocessed_predictions = predictions + for processing, metadata in zip(self.processings[::-1], metadata.metadata_lst[::-1]): + postprocessed_predictions = processing.postprocess_predictions(postprocessed_predictions, metadata) + return postprocessed_predictions + + +class ImagePermute(Processing): + """Permute the image dimensions. + + :param permutation: Specify new order of dims. Default value (2, 0, 1) suitable for converting from HWC to CHW format. + """ + + def __init__(self, permutation: Tuple[int, int, int] = (2, 0, 1)): + self.permutation = permutation + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]: + processed_image = np.ascontiguousarray(image.transpose(*self.permutation)) + return processed_image, None + + def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction: + return predictions + + +class NormalizeImage(Processing): + """Normalize an image based on means and standard deviation. + + :param mean: Mean values for each channel. + :param std: Standard deviation values for each channel. + """ + + def __init__(self, mean: List[float], std: List[float]): + self.mean = np.array(mean).reshape((1, 1, -1)).astype(np.float32) + self.std = np.array(std).reshape((1, 1, -1)).astype(np.float32) + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]: + return (image - self.mean) / self.std, None + + def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction: + return predictions + + +class _DetectionPadding(Processing, ABC): + """Base class for detection padding methods. One should implement the `_get_padding_params` method to work with a custom padding method. + + Note: This transformation assume that dimensions of input image is equal or less than `output_shape`. + + :param output_shape: Output image shape (H, W) + :param pad_value: Padding value for image + """ + + def __init__(self, output_shape: Tuple[int, int], pad_value: int): + self.output_shape = output_shape + self.pad_value = pad_value + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]: + padding_coordinates = self._get_padding_params(input_shape=image.shape) + processed_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=self.pad_value) + return processed_image, DetectionPadToSizeMetadata(padding_coordinates=padding_coordinates) + + def postprocess_predictions(self, predictions: DetectionPrediction, metadata: DetectionPadToSizeMetadata) -> DetectionPrediction: + predictions.bboxes_xyxy = _shift_bboxes( + targets=predictions.bboxes_xyxy, + shift_h=-metadata.padding_coordinates.top, + shift_w=-metadata.padding_coordinates.left, + ) + return predictions + + @abstractmethod + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + pass + + +class DetectionCenterPadding(_DetectionPadding): + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + return _get_center_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape) + + +class DetectionBottomRightPadding(_DetectionPadding): + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + return _get_bottom_right_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape) + + +class _Rescale(Processing, ABC): + """Resize image to given image dimensions WITHOUT preserving aspect ratio. + + :param output_shape: (H, W) + """ + + def __init__(self, output_shape: Tuple[int, int]): + self.output_shape = output_shape + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]: + + scale_factor_h, scale_factor_w = self.output_shape[0] / image.shape[0], self.output_shape[1] / image.shape[1] + rescaled_image = _rescale_image(image, target_shape=self.output_shape) + + return rescaled_image, RescaleMetadata(original_shape=image.shape[:2], scale_factor_h=scale_factor_h, scale_factor_w=scale_factor_w) + + +class _LongestMaxSizeRescale(Processing, ABC): + """Resize image to given image dimensions WITH preserving aspect ratio. + + :param output_shape: (H, W) + """ + + def __init__(self, output_shape: Tuple[int, int]): + self.output_shape = output_shape + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]: + height, width = image.shape[:2] + scale_factor = min(self.output_shape[0] / height, self.output_shape[1] / width) + + if scale_factor != 1.0: + new_height, new_width = round(height * scale_factor), round(width * scale_factor) + image = _rescale_image(image, target_shape=(new_height, new_width)) + + return image, RescaleMetadata(original_shape=(height, width), scale_factor_h=scale_factor, scale_factor_w=scale_factor) + + +class DetectionRescale(_Rescale): + def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction: + predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w)) + return predictions + + +class DetectionLongestMaxSizeRescale(_LongestMaxSizeRescale): + def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction: + predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w)) + return predictions diff --git a/src/super_gradients/training/transforms/transforms.py b/src/super_gradients/training/transforms/transforms.py index 205b4f7487..288393f5ab 100644 --- a/src/super_gradients/training/transforms/transforms.py +++ b/src/super_gradients/training/transforms/transforms.py @@ -15,13 +15,22 @@ from super_gradients.common.registry.registry import register_transform from super_gradients.common.decorators.factory_decorator import resolve_param from super_gradients.common.factories.data_formats_factory import ConcatenatedTensorFormatFactory -from super_gradients.training.utils.detection_utils import get_mosaic_coordinate, adjust_box_anns, xyxy2cxcywh, cxcywh2xyxy, DetectionTargetsFormat +from super_gradients.training.utils.detection_utils import get_mosaic_coordinate, adjust_box_anns, DetectionTargetsFormat from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter from super_gradients.training.datasets.data_formats.formats import filter_on_bboxes, ConcatenatedTensorFormat from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL, LABEL_CXCYWH - -image_resample = Image.BILINEAR -mask_resample = Image.NEAREST +from super_gradients.training.transforms.utils import ( + _rescale_and_pad_to_size, + _rescale_image, + _rescale_bboxes, + _get_center_padding_coordinates, + _pad_image, + _shift_bboxes, + _rescale_xyxy_bboxes, +) + +IMAGE_RESAMPLE_MODE = Image.BILINEAR +MASK_RESAMPLE_MODE = Image.NEAREST logger = get_logger(__name__) @@ -43,8 +52,8 @@ def __init__(self, h, w): def __call__(self, sample): image = sample["image"] mask = sample["mask"] - sample["image"] = image.resize((self.w, self.h), image_resample) - sample["mask"] = mask.resize((self.w, self.h), mask_resample) + sample["image"] = image.resize((self.w, self.h), IMAGE_RESAMPLE_MODE) + sample["mask"] = mask.resize((self.w, self.h), MASK_RESAMPLE_MODE) return sample @@ -106,8 +115,8 @@ def __call__(self, sample: dict) -> dict: out_size = int(scale * w), int(scale * h) - image = image.resize(out_size, image_resample) - mask = mask.resize(out_size, mask_resample) + image = image.resize(out_size, IMAGE_RESAMPLE_MODE) + mask = mask.resize(out_size, MASK_RESAMPLE_MODE) sample["image"] = image sample["mask"] = mask @@ -149,8 +158,8 @@ def __call__(self, sample: dict) -> dict: scale = random.uniform(self.scales[0], self.scales[1]) out_size = int(scale * w), int(scale * h) - image = image.resize(out_size, image_resample) - mask = mask.resize(out_size, mask_resample) + image = image.resize(out_size, IMAGE_RESAMPLE_MODE) + mask = mask.resize(out_size, MASK_RESAMPLE_MODE) sample["image"] = image sample["mask"] = mask @@ -194,8 +203,8 @@ def __call__(self, sample: dict) -> dict: mask = sample["mask"] deg = random.uniform(self.min_deg, self.max_deg) - image = image.rotate(deg, resample=image_resample, fillcolor=self.fill_image) - mask = mask.rotate(deg, resample=mask_resample, fillcolor=self.fill_mask) + image = image.rotate(deg, resample=IMAGE_RESAMPLE_MODE, fillcolor=self.fill_image) + mask = mask.rotate(deg, resample=MASK_RESAMPLE_MODE, fillcolor=self.fill_mask) sample["image"] = image sample["mask"] = mask @@ -290,10 +299,9 @@ class SegPadShortToCropSize(SegmentationTransform): def __init__(self, crop_size: Union[float, Tuple, List], fill_mask: int = 0, fill_image: Union[int, Tuple, List] = 0): """ - :param crop_size: tuple of (width, height) for the final crop size, if is scalar size is a - square (crop_size, crop_size) - :param fill_mask: value to fill mask labels background. - :param fill_image: grey value to fill image padded background. + :param crop_size: Tuple of (width, height) for the final crop size, if is scalar size is a square (crop_size, crop_size) + :param fill_mask: Value to fill mask labels background. + :param fill_image: Grey value to fill image padded background. """ # CHECK IF CROP SIZE IS A ITERABLE OR SCALAR self.crop_size = crop_size @@ -731,46 +739,15 @@ def __init__(self, output_size: Tuple[int, int], pad_value: int): self.pad_value = pad_value def __call__(self, sample: dict) -> dict: - img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target") - img, shift_w, shift_h = self._apply_to_image(img, final_shape=self.output_size, pad_value=self.pad_value) - sample["image"] = img - sample["target"] = self._apply_to_bboxes(targets, shift_w, shift_h) + image, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target") + padding_coordinates = _get_center_padding_coordinates(input_shape=image.shape, output_shape=self.output_size) + + sample["image"] = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=self.pad_value) + sample["target"] = _shift_bboxes(targets=targets, shift_w=padding_coordinates.left, shift_h=padding_coordinates.top) if crowd_targets is not None: - sample["crowd_target"] = self._apply_to_bboxes(crowd_targets, shift_w, shift_h) + sample["crowd_target"] = _shift_bboxes(targets=crowd_targets, shift_w=padding_coordinates.left, shift_h=padding_coordinates.top) return sample - def _apply_to_bboxes(self, targets: np.array, shift_w: float, shift_h: float) -> np.array: - """Translate bboxes with respect to padding values. - - :param targets: Bboxes to transform of shape (N, 5). - Bboxes expected to have format [x1, y1, x2, y2, class_id, ...] - :param shift_w: shift width in pixels - :param shift_h: shift height in pixels - :return: Bboxes to transform of shape (N, 5) - Bboxes will have same format [x1, y1, x2, y2, class_id, ...] - """ - targets = targets.copy() if len(targets) > 0 else np.zeros((0, 5), dtype=np.float32) - boxes, labels = targets[:, :4], targets[:, 4:] - boxes[:, [0, 2]] += shift_w - boxes[:, [1, 3]] += shift_h - return np.concatenate((boxes, labels), 1) - - def _apply_to_image(self, image, final_shape: Tuple[int, int], pad_value: int): - """ - Pad image to final_shape. - :param image: - :param final_shape: Output image size (rows, cols). - :param pad_value: - :return: - """ - pad_h, pad_w = final_shape[0] - image.shape[0], final_shape[1] - image.shape[1] - shift_h, shift_w = pad_h // 2, pad_w // 2 - pad_h = (shift_h, pad_h - shift_h) - pad_w = (shift_w, pad_w - shift_w) - - image = np.pad(image, (pad_h, pad_w, (0, 0)), "constant", constant_values=pad_value) - return image, shift_w, shift_h - @register_transform(Transforms.DetectionPaddedRescale) class DetectionPaddedRescale(DetectionTransform): @@ -794,30 +771,14 @@ def __init__(self, input_dim: Tuple, swap: Tuple[int, ...] = (2, 0, 1), max_targ def __call__(self, sample: dict) -> dict: img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target") - img, r = rescale_and_pad_to_size(img, self.input_dim, self.swap, self.pad_value) + img, r = _rescale_and_pad_to_size(img, self.input_dim, self.swap, self.pad_value) sample["image"] = img - sample["target"] = self._rescale_target(targets, r) + sample["target"] = _rescale_xyxy_bboxes(targets, r) if crowd_targets is not None: - sample["crowd_target"] = self._rescale_target(crowd_targets, r) + sample["crowd_target"] = _rescale_xyxy_bboxes(crowd_targets, r) return sample - def _rescale_target(self, targets: np.array, r: float) -> np.array: - """SegRescale the target according to a coefficient used to rescale the image. - This is done to have images and targets at the same scale. - - :param targets: Targets to rescale, shape (batch_size, 6) - :param r: SegRescale coefficient that was applied to the image - - :return: Rescaled targets, shape (batch_size, 6) - """ - targets = targets.copy() if len(targets) > 0 else np.zeros((self.max_targets, 5), dtype=np.float32) - boxes, labels = targets[:, :4], targets[:, 4] - boxes = xyxy2cxcywh(boxes) - boxes *= r - boxes = cxcywh2xyxy(boxes) - return np.concatenate((boxes, labels[:, np.newaxis]), 1) - @register_transform(Transforms.DetectionHorizontalFlip) class DetectionHorizontalFlip(DetectionTransform): @@ -859,40 +820,16 @@ def __init__(self, output_shape: Tuple[int, int]): self.output_shape = output_shape def __call__(self, sample: dict) -> dict: - img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target") + image, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target") - img_resized, scale_factors = self._rescale_image(img) + sy, sx = (self.output_shape[0] / image.shape[0], self.output_shape[1] / image.shape[1]) - sample["image"] = img_resized - sample["target"] = self._rescale_target(targets, scale_factors) + sample["image"] = _rescale_image(image=image, target_shape=self.output_shape) + sample["target"] = _rescale_bboxes(targets, scale_factors=(sy, sx)) if crowd_targets is not None: - sample["crowd_target"] = self._rescale_target(crowd_targets, scale_factors) + sample["crowd_target"] = _rescale_bboxes(crowd_targets, scale_factors=(sy, sx)) return sample - def _rescale_image(self, image): - sy, sx = self.output_shape[0] / image.shape[0], self.output_shape[1] / image.shape[1] - resized_img = cv2.resize( - image, - dsize=(int(self.output_shape[1]), int(self.output_shape[0])), - interpolation=cv2.INTER_LINEAR, - ) - scale_factors = sy, sx - return resized_img, scale_factors - - def _rescale_target(self, targets: np.array, scale_factors: Tuple[float, float]) -> np.array: - """SegRescale the target according to a coefficient used to rescale the image. - This is done to have images and targets at the same scale. - - :param targets: Target XYXY bboxes to rescale, shape (num_boxes, 5) - :param r: SegRescale coefficient that was applied to the image - - :return: Rescaled targets, shape (num_boxes, 5) - """ - sy, sx = scale_factors - targets = targets.astype(np.float32, copy=True) if len(targets) > 0 else np.zeros((0, 5), dtype=np.float32) - targets[:, 0:4] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype) - return targets - @register_transform(Transforms.DetectionRandomRotate90) class DetectionRandomRotate90(DetectionTransform): @@ -1335,34 +1272,6 @@ def augment_hsv(img: np.array, hgain: float, sgain: float, vgain: float, bgr_cha img[..., bgr_channels] = cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR) # no return needed -def rescale_and_pad_to_size(img, input_size, swap=(2, 0, 1), pad_val=114): - """ - Rescales image according to minimum ratio between the target height /image height, target width / image width, - and pads the image to the target size. - - :param img: Image to be rescaled - :param input_size: Target size - :param swap: Axis's to be rearranged. - :return: rescaled image, ratio - """ - if len(img.shape) == 3: - padded_img = np.ones((input_size[0], input_size[1], img.shape[-1]), dtype=np.uint8) * pad_val - else: - padded_img = np.ones(input_size, dtype=np.uint8) * pad_val - - r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) - resized_img = cv2.resize( - img, - (int(img.shape[1] * r), int(img.shape[0] * r)), - interpolation=cv2.INTER_LINEAR, - ).astype(np.uint8) - padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img - - padded_img = padded_img.transpose(swap) - padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) - return padded_img, r - - @register_transform(Transforms.Standardize) class Standardize(torch.nn.Module): """ diff --git a/src/super_gradients/training/transforms/utils.py b/src/super_gradients/training/transforms/utils.py new file mode 100644 index 0000000000..7379569b93 --- /dev/null +++ b/src/super_gradients/training/transforms/utils.py @@ -0,0 +1,144 @@ +from typing import Tuple +from dataclasses import dataclass +import cv2 + +import numpy as np + +from super_gradients.training.utils.detection_utils import xyxy2cxcywh, cxcywh2xyxy + + +@dataclass +class PaddingCoordinates: + top: int + bottom: int + left: int + right: int + + +def _rescale_image(image: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray: + """Rescale image to target_shape, without preserving aspect ratio. + + :param image: Image to rescale. (H, W, C) or (H, W). + :param target_shape: Target shape to rescale to. + :return: Rescaled image. + """ + height, width = target_shape[:2] + return cv2.resize(image, dsize=(width, height), interpolation=cv2.INTER_LINEAR).astype(np.uint8) + + +def _rescale_bboxes(targets: np.array, scale_factors: Tuple[float, float]) -> np.array: + """Rescale bboxes to given scale factors, without preserving aspect ratio. + + :param targets: Targets to rescale (N, 4+), where target[:, :4] is the bounding box coordinates. + :param scale_factors: Tuple of (scale_factor_h, scale_factor_w) scale factors to rescale to. + :return: Rescaled targets. + """ + + targets = targets.astype(np.float32, copy=True) + + sy, sx = scale_factors + targets[:, :4] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype) + return targets + + +def _get_center_padding_coordinates(input_shape: Tuple[int, int], output_shape: Tuple[int, int]) -> PaddingCoordinates: + """Get parameters for padding an image to given output shape, in center mode. + + :param input_shape: Shape of the input image. + :param output_shape: Shape to resize to. + :return: Padding parameters. + """ + pad_height, pad_width = output_shape[0] - input_shape[0], output_shape[1] - input_shape[1] + + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + + return PaddingCoordinates(top=pad_top, bottom=pad_bottom, left=pad_left, right=pad_right) + + +def _get_bottom_right_padding_coordinates(input_shape: Tuple[int, int], output_shape: Tuple[int, int]) -> PaddingCoordinates: + """Get parameters for padding an image to given output shape, in bottom right mode + (i.e. image will be at top-left while bottom-right corner will be padded). + + :param input_shape: Shape of the input image. + :param output_shape: Shape to resize to. + :return: Padding parameters. + """ + pad_height, pad_width = output_shape[0] - input_shape[0], output_shape[1] - input_shape[1] + return PaddingCoordinates(top=0, bottom=pad_height, left=0, right=pad_width) + + +def _pad_image(image: np.ndarray, padding_coordinates: PaddingCoordinates, pad_value: int) -> np.ndarray: + """Pad an image. + + :param image: Image to shift. (H, W, C) or (H, W). + :param pad_h: Tuple of (padding_top, padding_bottom). + :param pad_w: Tuple of (padding_left, padding_right). + :param pad_value: Padding value + :return: Image shifted according to padding coordinates. + """ + pad_h = (padding_coordinates.top, padding_coordinates.bottom) + pad_w = (padding_coordinates.left, padding_coordinates.right) + if len(image.shape) == 3: + return np.pad(image, (pad_h, pad_w, (0, 0)), "constant", constant_values=pad_value) + else: + return np.pad(image, (pad_h, pad_w), "constant", constant_values=pad_value) + + +def _shift_bboxes(targets: np.array, shift_w: float, shift_h: float) -> np.array: + """Shift bboxes with respect to padding values. + + :param targets: Bboxes to transform of shape (N, 4+), in format [x1, y1, x2, y2, ...] + :param shift_w: shift width. + :param shift_h: shift height. + :return: Bboxes transformed of shape (N, 4+), in format [x1, y1, x2, y2, ...] + """ + boxes, labels = targets[:, :4], targets[:, 4:] + boxes[:, [0, 2]] += shift_w + boxes[:, [1, 3]] += shift_h + return np.concatenate((boxes, labels), 1) + + +def _rescale_xyxy_bboxes(targets: np.array, r: float) -> np.array: + """Scale targets to given scale factors. + + :param targets: Bboxes to transform of shape (N, 4+), in format [x1, y1, x2, y2, ...] + :param r: DetectionRescale coefficient that was applied to the image + :return: Rescaled Bboxes to transform of shape (N, 4+), in format [x1, y1, x2, y2, ...] + """ + targets = targets.copy() + boxes, targets = targets[:, :4], targets[:, 4:] + boxes = xyxy2cxcywh(boxes) + boxes *= r + boxes = cxcywh2xyxy(boxes) + return np.concatenate((boxes, targets), 1) + + +def _rescale_and_pad_to_size(image: np.ndarray, output_shape: Tuple[int, int], swap: Tuple[int] = (2, 0, 1), pad_val: int = 114) -> Tuple[np.ndarray, float]: + """ + Rescales image according to minimum ratio input height/width and output height/width rescaled_padded_image, + pads the image to the target shape and finally swap axis. + Note: Pads the image to corner, padding is not centered. + + :param image: Image to be rescaled. (H, W, C) or (H, W). + :param output_shape: Target Shape. + :param swap: Axis's to be rearranged. + :param pad_val: Value to use for padding. + :return: + - Rescaled image while preserving aspect ratio, padded to fit output_shape and with axis swapped. By default, (C, H, W). + - Minimum ratio between the input height/width and output height/width. + """ + r = min(output_shape[0] / image.shape[0], output_shape[1] / image.shape[1]) + rescale_shape = (int(image.shape[0] * r), int(image.shape[1] * r)) + + resized_image = _rescale_image(image=image, target_shape=rescale_shape) + + padding_coordinates = _get_bottom_right_padding_coordinates(input_shape=rescale_shape, output_shape=output_shape) + padded_image = _pad_image(image=resized_image, padding_coordinates=padding_coordinates, pad_value=pad_val) + + padded_image = padded_image.transpose(swap) + padded_image = np.ascontiguousarray(padded_image, dtype=np.float32) + return padded_image, r diff --git a/src/super_gradients/training/utils/detection_utils.py b/src/super_gradients/training/utils/detection_utils.py index b830bcae69..fd34996eac 100755 --- a/src/super_gradients/training/utils/detection_utils.py +++ b/src/super_gradients/training/utils/detection_utils.py @@ -59,9 +59,9 @@ def _set_batch_labels_index(labels_batch): return labels_batch -def convert_xywh_bbox_to_xyxy(input_bbox: torch.Tensor): +def convert_cxcywh_bbox_to_xyxy(input_bbox: torch.Tensor): """ - Converts bounding box format from [x, y, w, h] to [x1, y1, x2, y2] + Converts bounding box format from [cx, cy, w, h] to [x1, y1, x2, y2] :param input_bbox: input bbox either 2-dimensional (for all boxes of a single image) or 3-dimensional (for boxes of a batch of images) :return: Converted bbox in same dimensions as the original @@ -234,7 +234,7 @@ def box_area(box): def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label_per_box: bool = True, with_confidence: bool = False): """ Performs Non-Maximum Suppression (NMS) on inference results - :param prediction: raw model prediction + :param prediction: raw model prediction. Should be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...) :param conf_thres: below the confidence threshold - prediction are discarded :param iou_thres: IoU threshold for the nms algorithm :param multi_label_per_box: whether to use re-use each box with all possible labels @@ -257,7 +257,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label_p if with_confidence: pred[:, 5:] *= pred[:, 4:5] # multiply objectness score with class score - box = convert_xywh_bbox_to_xyxy(pred[:, :4]) # xywh to xyxy + box = convert_cxcywh_bbox_to_xyxy(pred[:, :4]) # cxcywh to xyxy # Detections matrix nx6 (xyxy, conf, cls) if multi_label_per_box: # try for all good confidence classes @@ -302,7 +302,7 @@ def matrix_non_max_suppression( pred[:, :, 4] *= class_conf # BOX (CENTER X, CENTER Y, WIDTH, HEIGHT) TO (X1, Y1, X2, Y2) - pred[:, :, :4] = convert_xywh_bbox_to_xyxy(pred[:, :, :4]) + pred[:, :, :4] = convert_cxcywh_bbox_to_xyxy(pred[:, :, :4]) # DETECTIONS ORDERED AS (x1y1x2y2, obj_conf, class_conf, class_pred) pred = torch.cat((pred[:, :, :5], class_pred.unsqueeze(2)), 2) @@ -822,7 +822,7 @@ def crowd_ioa(det_box: torch.Tensor, crowd_box: torch.Tensor) -> torch.Tensor: def compute_detection_matching( - output: torch.Tensor, + output: List[torch.Tensor], targets: torch.Tensor, height: int, width: int, diff --git a/tests/unit_tests/transforms_test.py b/tests/unit_tests/transforms_test.py index 85edf21ef0..b537eb4080 100644 --- a/tests/unit_tests/transforms_test.py +++ b/tests/unit_tests/transforms_test.py @@ -11,6 +11,18 @@ ) from super_gradients.training.transforms.transforms import DetectionImagePermute, DetectionPadToSize +from super_gradients.training.transforms.utils import ( + _rescale_image, + _rescale_bboxes, + _pad_image, + _shift_bboxes, + _rescale_and_pad_to_size, + _rescale_xyxy_bboxes, + _get_center_padding_coordinates, + _get_bottom_right_padding_coordinates, + PaddingCoordinates, +) + class TestTransforms(unittest.TestCase): def test_keypoints_random_affine(self): @@ -120,6 +132,166 @@ def test_detection_pad_to_size(self): self.assertEqual(output["image"].shape, (640, 640, 3)) np.testing.assert_array_equal(output["target"], expected_boxes) + def test_rescale_image(self): + image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8) + target_shape = (320, 240) + rescaled_image = _rescale_image(image, target_shape) + + # Check if the rescaled image has the correct target shape + self.assertEqual(rescaled_image.shape[:2], target_shape) + + def test_rescale_bboxes(self): + sy, sx = (2.0, 0.5) + + # Empty bboxes + bboxes = np.zeros((0, 4)) + expected_bboxes = np.zeros((0, 4)) + rescaled_bboxes = _rescale_bboxes(targets=bboxes, scale_factors=(sy, sx)) + np.testing.assert_array_equal(rescaled_bboxes, expected_bboxes) + + # Not empty bboxes + bboxes = np.array([[10, 20, 50, 60, 1], [30, 40, 80, 90, 2]], dtype=np.float32) + expected_bboxes = np.array([[5.0, 40.0, 25.0, 120.0, 1.0], [15.0, 80.0, 40.0, 180.0, 2.0]], dtype=np.float32) + rescaled_bboxes = _rescale_bboxes(targets=bboxes, scale_factors=(sy, sx)) + np.testing.assert_array_equal(rescaled_bboxes, expected_bboxes) + + def test_pad_image(self): + image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8) + padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60) + pad_value = 0 + shifted_image = _pad_image(image, padding_coordinates, pad_value) + + # Check if the shifted image has the correct shape + self.assertEqual(shifted_image.shape, (800, 600, 3)) + # Check if the padding values are correct + self.assertTrue((shifted_image[: padding_coordinates.top, :, :] == pad_value).all()) + self.assertTrue((shifted_image[-padding_coordinates.bottom :, :, :] == pad_value).all()) + self.assertTrue((shifted_image[:, : padding_coordinates.left, :] == pad_value).all()) + self.assertTrue((shifted_image[:, -padding_coordinates.right :, :] == pad_value).all()) + + def test_shift_bboxes(self): + bboxes = np.array([[10, 20, 50, 60, 1], [30, 40, 80, 90, 2]], dtype=np.float32) + shift_w, shift_h = 60, 80 + shifted_bboxes = _shift_bboxes(bboxes, shift_w, shift_h) + + # Check if the shifted bboxes have the correct values + expected_bboxes = np.array([[70, 100, 110, 140, 1], [90, 120, 140, 170, 2]], dtype=np.float32) + np.testing.assert_array_equal(shifted_bboxes, expected_bboxes) + + def test_rescale_xyxy_bboxes(self): + bboxes = np.array([[10, 20, 50, 60, 1], [30, 40, 80, 90, 2]], dtype=np.float32) + r = 0.5 + rescaled_bboxes = _rescale_xyxy_bboxes(bboxes, r) + + # Check if the rescaled bboxes have the correct values + expected_bboxes = np.array([[5.0, 10.0, 25.0, 30.0, 1.0], [15.0, 20.0, 40.0, 45.0, 2.0]], dtype=np.float32) + np.testing.assert_array_equal(rescaled_bboxes, expected_bboxes) + + def test_padding(self): + # Test Case 1: Padding needed + image = np.array([[1, 2], [3, 4]]) + padding_coordinates = PaddingCoordinates(top=0, left=0, bottom=1, right=2) + expected_padded_image = np.array([[1, 2, 114, 114], [3, 4, 114, 114], [114, 114, 114, 114]]) + + padded_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=114) + np.testing.assert_array_equal(padded_image, expected_padded_image) + + # Test Case 2: No padding needed + image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + padding_coordinates = PaddingCoordinates(top=0, left=0, bottom=0, right=0) + expected_padded_image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + padded_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=114) + np.testing.assert_array_equal(padded_image, expected_padded_image) + + # Test Case 3: Image with channel dimension + image = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) + padding_coordinates = PaddingCoordinates(top=0, left=0, bottom=1, right=2) + expected_padded_image = np.array( + [ + [[1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]], + [[7, 8, 9], [10, 11, 12], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + ], + ) + + padded_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=0) + np.testing.assert_array_equal(padded_image, expected_padded_image) + + def test_get_padding_coordinates(self): + # Test Case 1: Width padding required + image = np.zeros((640, 480)) + output_size = (640, 640) + expected_center_padding = PaddingCoordinates(top=0, bottom=0, left=80, right=80) + expected_bottom_right_padding = PaddingCoordinates(top=0, bottom=0, left=0, right=160) + + center_padding_coordinates = _get_center_padding_coordinates(input_shape=image.shape, output_shape=output_size) + bottom_right_padding_coordinates = _get_bottom_right_padding_coordinates(input_shape=image.shape, output_shape=output_size) + self.assertEqual(center_padding_coordinates, expected_center_padding) + self.assertEqual(bottom_right_padding_coordinates, expected_bottom_right_padding) + + # Test Case 2: Height padding required + image = np.zeros((480, 640)) + output_size = (640, 640) + expected_center_padding = PaddingCoordinates(top=80, bottom=80, left=0, right=0) + expected_bottom_right_padding = PaddingCoordinates(top=0, bottom=160, left=0, right=0) + + center_padding_coordinates = _get_center_padding_coordinates(input_shape=image.shape, output_shape=output_size) + bottom_right_padding_coordinates = _get_bottom_right_padding_coordinates(input_shape=image.shape, output_shape=output_size) + self.assertEqual(center_padding_coordinates, expected_center_padding) + self.assertEqual(bottom_right_padding_coordinates, expected_bottom_right_padding) + + # Test Case 3: Width and Height padding required + image = np.zeros((480, 640)) + output_size = (800, 800) + expected_center_padding = PaddingCoordinates(top=160, bottom=160, left=80, right=80) + expected_bottom_right_padding = PaddingCoordinates(top=0, bottom=320, left=0, right=160) + + center_padding_coordinates = _get_center_padding_coordinates(input_shape=image.shape, output_shape=output_size) + bottom_right_padding_coordinates = _get_bottom_right_padding_coordinates(input_shape=image.shape, output_shape=output_size) + self.assertEqual(center_padding_coordinates, expected_center_padding) + self.assertEqual(bottom_right_padding_coordinates, expected_bottom_right_padding) + + # Test Case 4: Image shape is bigger than output shape + image = np.zeros((800, 800)) + output_size = (640, 640) + expected_center_padding = PaddingCoordinates(top=-80, bottom=-80, left=-80, right=-80) + expected_bottom_right_padding = PaddingCoordinates(top=0, bottom=-160, left=0, right=-160) + + center_padding_coordinates = _get_center_padding_coordinates(input_shape=image.shape, output_shape=output_size) + bottom_right_padding_coordinates = _get_bottom_right_padding_coordinates(input_shape=image.shape, output_shape=output_size) + self.assertEqual(center_padding_coordinates, expected_center_padding) + self.assertEqual(bottom_right_padding_coordinates, expected_bottom_right_padding) + + # Test Case 5: Width and Height padding required with an image of 3 channels + image = np.zeros((480, 640, 3)) + output_size = (800, 800) + expected_center_padding = PaddingCoordinates(top=160, bottom=160, left=80, right=80) + expected_bottom_right_padding = PaddingCoordinates(top=0, bottom=320, left=0, right=160) + + center_padding_coordinates = _get_center_padding_coordinates(input_shape=image.shape, output_shape=output_size) + bottom_right_padding_coordinates = _get_bottom_right_padding_coordinates(input_shape=image.shape, output_shape=output_size) + self.assertEqual(center_padding_coordinates, expected_center_padding) + self.assertEqual(bottom_right_padding_coordinates, expected_bottom_right_padding) + + def test_rescale_and_pad_to_size(self): + image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8) + output_size = (800, 500) + pad_val = 114 + rescaled_padded_image, r = _rescale_and_pad_to_size(image, output_size, pad_val=pad_val) + + # Check if the rescaled and padded image has the correct shape + self.assertEqual(rescaled_padded_image.shape, (3, *output_size)) + + # Check if the image is rescaled with the correct ratio + resized_image_shape = (int(image.shape[0] * r), int(image.shape[1] * r)) + + # Check if the padding is correctly applied + padded_area = rescaled_padded_image[:, resized_image_shape[0] :, :] # Right padding area + self.assertTrue((padded_area == pad_val).all()) + padded_area = rescaled_padded_image[:, :, resized_image_shape[1] :] # Bottom padding area + self.assertTrue((padded_area == pad_val).all()) + if __name__ == "__main__": unittest.main()