From b89b214ea3efaf05fa03d749e67b72a989775ef6 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Fri, 20 Jan 2023 17:32:15 +0200
Subject: [PATCH 1/2] Added new detection transforms that are used in PPYolo.

---
Review note: the transforms.py hunk was amended by hand during review
(docstring fixes, corrected error message, contiguous output in
DetectionRGB2BGR). Hunk sizes and the diffstat were updated accordingly;
the blob hashes on the `index` lines were NOT regenerated.

 src/super_gradients/common/object_names.py    |   3 +
 .../training/transforms/all_transforms.py     |   6 +
 .../training/transforms/transforms.py         | 134 ++++++++++++++++++
 3 files changed, 143 insertions(+)

diff --git a/src/super_gradients/common/object_names.py b/src/super_gradients/common/object_names.py
index 097dd11dd3..e3e035d42d 100644
--- a/src/super_gradients/common/object_names.py
+++ b/src/super_gradients/common/object_names.py
@@ -48,7 +48,10 @@ class Transforms:
     DetectionRandomAffine = "DetectionRandomAffine"
     DetectionMixup = "DetectionMixup"
     DetectionHSV = "DetectionHSV"
+    DetectionRGB2BGR = "DetectionRGB2BGR"
+    DetectionRandomRotate90 = "DetectionRandomRotate90"
     DetectionHorizontalFlip = "DetectionHorizontalFlip"
+    DetectionRescale = "DetectionRescale"
     DetectionPaddedRescale = "DetectionPaddedRescale"
     DetectionTargetsFormatTransform = "DetectionTargetsFormatTransform"
     RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation"
diff --git a/src/super_gradients/training/transforms/all_transforms.py b/src/super_gradients/training/transforms/all_transforms.py
index 9ec71464d5..31030d3503 100644
--- a/src/super_gradients/training/transforms/all_transforms.py
+++ b/src/super_gradients/training/transforms/all_transforms.py
@@ -21,7 +21,10 @@
     DetectionRandomAffine,
     DetectionMixup,
     DetectionHSV,
+    DetectionRGB2BGR,
+    DetectionRandomRotate90,
     DetectionHorizontalFlip,
+    DetectionRescale,
     DetectionPaddedRescale,
     DetectionTargetsFormatTransform,
     Standardize,
@@ -79,7 +82,10 @@
     Transforms.DetectionRandomAffine: DetectionRandomAffine,
     Transforms.DetectionMixup: DetectionMixup,
     Transforms.DetectionHSV: DetectionHSV,
+    Transforms.DetectionRGB2BGR: DetectionRGB2BGR,
+    Transforms.DetectionRandomRotate90: DetectionRandomRotate90,
     Transforms.DetectionHorizontalFlip: DetectionHorizontalFlip,
+    Transforms.DetectionRescale: DetectionRescale,
     Transforms.DetectionPaddedRescale: DetectionPaddedRescale,
     Transforms.DetectionTargetsFormatTransform: DetectionTargetsFormatTransform,
     Transforms.RandomResizedCropAndInterpolation: RandomResizedCropAndInterpolation,
diff --git a/src/super_gradients/training/transforms/transforms.py b/src/super_gradients/training/transforms/transforms.py
index 0e4b92d047..8d79e5aab1 100644
--- a/src/super_gradients/training/transforms/transforms.py
+++ b/src/super_gradients/training/transforms/transforms.py
@@ -718,6 +718,140 @@ def __call__(self, sample):
         return sample
 
 
+class DetectionRescale(DetectionTransform):
+    """
+    Resize image and bounding boxes to given image dimensions without preserving aspect ratio
+    Attributes:
+        input_dim: (tuple) (rows, cols)
+        swap: (tuple) axis order passed to transpose() on the resized image; the default (2, 0, 1) moves the last axis first.
+    """
+
+    def __init__(self, input_dim: Tuple[int, int], swap=(2, 0, 1)):
+        super().__init__()
+        self.swap = swap
+        self.input_dim = input_dim
+
+    def __call__(self, sample: Dict[str, np.ndarray]):
+        img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target")
+
+        img_resized, scale_factors = self._rescale_image(img)
+
+        sample["image"] = img_resized.transpose(self.swap).astype(np.float32, copy=True)
+        sample["target"] = self._rescale_target(targets, scale_factors)
+        if crowd_targets is not None:
+            sample["crowd_target"] = self._rescale_target(crowd_targets, scale_factors)
+        return sample
+
+    def _rescale_image(self, image):
+        """Resize image to input_dim and return it together with the (sy, sx) factors that were applied."""
+        sy, sx = self.input_dim[0] / image.shape[0], self.input_dim[1] / image.shape[1]
+        resized_img = cv2.resize(
+            image,
+            dsize=(int(self.input_dim[1]), int(self.input_dim[0])),
+            interpolation=cv2.INTER_LINEAR,
+        )
+        scale_factors = sy, sx
+        return resized_img, scale_factors
+
+    def _rescale_target(self, targets: np.array, scale_factors: Tuple[float, float]) -> np.array:
+        """Rescale the targets with the per-axis factors that were used to resize the image.
+        This is done to have images and targets at the same scale.
+        :param targets: Target XYXY bboxes to rescale, shape (num_boxes, 5)
+        :param scale_factors: (sy, sx) factors that were applied to the image rows and columns
+        :return: Rescaled targets, shape (num_boxes, 5)
+        """
+        sy, sx = scale_factors
+        targets = targets.astype(np.float32, copy=True) if len(targets) > 0 else np.zeros((0, 5), dtype=np.float32)
+        targets[:, 0:4] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
+        return targets
+
+
+class DetectionRandomRotate90(DetectionTransform):
+    """
+    Rotate image and bounding boxes by a random number (0..3) of 90 degree CCW turns (see np.rot90)
+    Attributes:
+        prob: (float) probability to apply the transform.
+    """
+
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def __call__(self, sample: dict) -> dict:
+        if random.random() < self.prob:
+            k = random.randrange(0, 4)
+
+            img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target")
+
+            sample["image"] = np.ascontiguousarray(np.rot90(img, k))
+            sample["target"] = self.rotate_bboxes(targets, k, img.shape[:2])
+            if crowd_targets is not None:
+                sample["crowd_target"] = self.rotate_bboxes(crowd_targets, k, img.shape[:2])
+
+        return sample
+
+    @classmethod
+    def rotate_bboxes(cls, targets, k: int, image_shape):
+        """Rotate the XYXY boxes stored in the first four columns of targets by k CCW quarter turns.
+        :param targets: Targets array whose columns 0:4 hold (x_min, y_min, x_max, y_max)
+        :param k: Number of CCW quarter turns (see np.rot90)
+        :param image_shape: (rows, cols) of the image before the rotation
+        :return: targets itself when k == 0, otherwise a copy with rotated boxes
+        """
+        if k == 0:
+            return targets
+        rows, cols = image_shape
+        targets = targets.copy()
+        targets[:, 0:4] = cls.xyxy_bbox_rot90(targets[:, 0:4], k, rows, cols)
+        return targets
+
+    @classmethod
+    def xyxy_bbox_rot90(cls, bboxes, factor: int, rows: int, cols: int):
+        """Rotates bounding boxes by 90 degrees CCW (see np.rot90)
+        Args:
+            bboxes: Array of shape (num_boxes, 4) with columns (x_min, y_min, x_max, y_max).
+            factor: Number of CCW rotations. Must be in set {0, 1, 2, 3} See np.rot90.
+            rows: Image rows before the rotation.
+            cols: Image cols before the rotation.
+        Returns:
+            Array of shape (num_boxes, 4) with columns (x_min, y_min, x_max, y_max).
+        """
+        x_min, y_min, x_max, y_max = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
+
+        if factor == 0:
+            bbox = x_min, y_min, x_max, y_max
+        elif factor == 1:
+            bbox = y_min, cols - x_max, y_max, cols - x_min
+        elif factor == 2:
+            bbox = cols - x_max, rows - y_max, cols - x_min, rows - y_min
+        elif factor == 3:
+            bbox = rows - y_max, x_min, rows - y_min, x_max
+        else:
+            raise ValueError("Parameter factor must be in set {0, 1, 2, 3}")
+        return np.stack(bbox, axis=1)
+
+
+class DetectionRGB2BGR(DetectionTransform):
+    """
+    Swap the Red & Blue channels of the image (reverses the last axis)
+    Attributes:
+        prob: (float) probability to apply the transform.
+    """
+
+    def __init__(self, prob: float = 0.5):
+        super().__init__()
+        self.prob = prob
+
+    def __call__(self, sample: dict) -> dict:
+        if sample["image"].shape[2] < 3:
+            raise ValueError("HSV transform expects at least 3 channels, got: " + str(sample["image"].shape[2]))
+
+        if random.random() < self.prob:
+            # Reversing the last axis yields a negative-stride view; store a contiguous copy like DetectionRandomRotate90 does.
+            sample["image"] = np.ascontiguousarray(sample["image"][..., ::-1])
+        return sample
+
+
 class DetectionHSV(DetectionTransform):
     """
     Detection HSV transform.

From 68885ebc971691c30f20188d877b368bcd45d2f4 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Fri, 20 Jan 2023 17:34:47 +0200
Subject: [PATCH 2/2] Fix exception message

---
Review note: hunk offset and trailing context were adjusted to follow the
amended PATCH 1/2; the blob hashes on the `index` line were NOT regenerated.

 src/super_gradients/training/transforms/transforms.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/super_gradients/training/transforms/transforms.py b/src/super_gradients/training/transforms/transforms.py
index 8d79e5aab1..3f62269451 100644
--- a/src/super_gradients/training/transforms/transforms.py
+++ b/src/super_gradients/training/transforms/transforms.py
@@ -843,8 +843,8 @@ def __init__(self, prob: float = 0.5):
         self.prob = prob
 
     def __call__(self, sample: dict) -> dict:
-        if sample["image"].shape[2] < 3:
-            raise ValueError("HSV transform expects at least 3 channels, got: " + str(sample["image"].shape[2]))
+        if sample["image"].shape[2] != 3:
+            raise ValueError("DetectionRGB2BGR expects image to have 3 channels, got: " + str(sample["image"].shape[2]))
 
         if random.random() < self.prob:
             # Reversing the last axis yields a negative-stride view; store a contiguous copy like DetectionRandomRotate90 does.