diff --git a/src/super_gradients/training/transforms/utils.py b/src/super_gradients/training/transforms/utils.py index 1c3708e59e..9adec24ea9 100644 --- a/src/super_gradients/training/transforms/utils.py +++ b/src/super_gradients/training/transforms/utils.py @@ -1,4 +1,6 @@ -from typing import Tuple +import numbers +import typing +from typing import Tuple, Union from dataclasses import dataclass import cv2 @@ -106,22 +108,47 @@ def _get_bottom_right_padding_coordinates(input_shape: Tuple[int, int], output_s return PaddingCoordinates(top=0, bottom=pad_height, left=0, right=pad_width) -def _pad_image(image: np.ndarray, padding_coordinates: PaddingCoordinates, pad_value: int) -> np.ndarray: +def _pad_image(image: np.ndarray, padding_coordinates: PaddingCoordinates, pad_value: Union[int, Tuple[int, ...]]) -> np.ndarray: """Pad an image. :param image: Image to shift. (H, W, C) or (H, W). :param pad_h: Tuple of (padding_top, padding_bottom). :param pad_w: Tuple of (padding_left, padding_right). - :param pad_value: Padding value + :param pad_value: Padding value. Can be a single scalar (Same value for all channels) or a tuple of values. + In the latter case, the tuple length must be equal to the number of channels. :return: Image shifted according to padding coordinates. """ pad_h = (padding_coordinates.top, padding_coordinates.bottom) pad_w = (padding_coordinates.left, padding_coordinates.right) if len(image.shape) == 3: - return np.pad(image, (pad_h, pad_w, (0, 0)), "constant", constant_values=pad_value) + _, _, num_channels = image.shape + + if isinstance(pad_value, numbers.Number): + pad_value = tuple([pad_value] * num_channels) + else: + if isinstance(pad_value, typing.Sized) and len(pad_value) != num_channels: + raise ValueError(f"A pad_value tuple ({pad_value} length should be {num_channels} for an image with {num_channels} channels") + + pad_value = tuple(pad_value) + + constant_values = ((pad_value, pad_value), (pad_value, pad_value), (0, 0)) + padding_values = (pad_h, pad_w, (0, 0)) else: - return np.pad(image, (pad_h, pad_w), "constant", constant_values=pad_value) + if isinstance(pad_value, numbers.Number): + pass + elif isinstance(pad_value, typing.Sized): + if len(pad_value) != 1: + raise ValueError(f"A pad_value tuple ({pad_value} length should be 1 for a grayscale image") + else: + (pad_value,) = pad_value # Unpack to a single scalar + else: + raise ValueError(f"Unsupported pad_value type {type(pad_value)}") + + constant_values = pad_value + padding_values = (pad_h, pad_w) + + return np.pad(image, pad_width=padding_values, mode="constant", constant_values=constant_values) def _shift_bboxes(targets: np.array, shift_w: float, shift_h: float) -> np.array: diff --git a/tests/unit_tests/transforms_test.py b/tests/unit_tests/transforms_test.py index 281fc26860..d8a73dbade 100644 --- a/tests/unit_tests/transforms_test.py +++ b/tests/unit_tests/transforms_test.py @@ -3,6 +3,7 @@ import cv2 import matplotlib.pyplot as plt import numpy as np +from omegaconf import ListConfig from super_gradients.training.transforms import KeypointsMixup, KeypointsCompose from super_gradients.training.transforms.keypoint_transforms import ( @@ -244,7 +245,7 @@ def test_rescale_bboxes(self): rescaled_bboxes = _rescale_bboxes(targets=bboxes, scale_factors=(sy, sx)) np.testing.assert_array_equal(rescaled_bboxes, expected_bboxes) - def test_pad_image(self): + def test_pad_image_with_constant(self): image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8) padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60) pad_value = 0 @@ -258,6 +259,48 @@ def test_pad_image(self): self.assertTrue((shifted_image[:, : padding_coordinates.left, :] == pad_value).all()) self.assertTrue((shifted_image[:, -padding_coordinates.right :, :] == pad_value).all()) + def test_pad_image_with_tuple(self): + image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8) + padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60) + pad_value = (1, 2, 3) + shifted_image = _pad_image(image, padding_coordinates, pad_value) + + # Check if the shifted image has the correct shape + self.assertEqual(shifted_image.shape, (800, 600, 3)) + # Check if the padding values are correct + self.assertTrue((shifted_image[: padding_coordinates.top, :, :] == pad_value).all()) + self.assertTrue((shifted_image[-padding_coordinates.bottom :, :, :] == pad_value).all()) + self.assertTrue((shifted_image[:, : padding_coordinates.left, :] == pad_value).all()) + self.assertTrue((shifted_image[:, -padding_coordinates.right :, :] == pad_value).all()) + + def test_pad_image_with_listconfig(self): + image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8) + padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60) + pad_value = ListConfig([1, 2, 3]) + shifted_image = _pad_image(image, padding_coordinates, pad_value) + + # Check if the shifted image has the correct shape + self.assertEqual(shifted_image.shape, (800, 600, 3)) + # Check if the padding values are correct + self.assertTrue((shifted_image[: padding_coordinates.top, :, :] == pad_value).all()) + self.assertTrue((shifted_image[-padding_coordinates.bottom :, :, :] == pad_value).all()) + self.assertTrue((shifted_image[:, : padding_coordinates.left, :] == pad_value).all()) + self.assertTrue((shifted_image[:, -padding_coordinates.right :, :] == pad_value).all()) + + def test_pad_grayscale_image(self): + image = np.random.randint(0, 256, size=(640, 480), dtype=np.uint8) + padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60) + pad_value = 1 + shifted_image = _pad_image(image, padding_coordinates, pad_value) + + # Check if the shifted image has the correct shape + self.assertEqual(shifted_image.shape, (800, 600)) + # Check if the padding values are correct + self.assertTrue((shifted_image[: padding_coordinates.top, :] == pad_value).all()) + self.assertTrue((shifted_image[-padding_coordinates.bottom :, :] == pad_value).all()) + self.assertTrue((shifted_image[:, : padding_coordinates.left] == pad_value).all()) + self.assertTrue((shifted_image[:, -padding_coordinates.right :] == pad_value).all()) + def test_shift_bboxes(self): bboxes = np.array([[10, 20, 50, 60, 1], [30, 40, 80, 90, 2]], dtype=np.float32) shift_w, shift_h = 60, 80