Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new detection transforms that are used in PPYoloE #641

Merged
merged 3 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/super_gradients/common/object_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ class Transforms:
DetectionRandomAffine = "DetectionRandomAffine"
DetectionMixup = "DetectionMixup"
DetectionHSV = "DetectionHSV"
DetectionRGB2BGR = "DetectionRGB2BGR"
DetectionRandomRotate90 = "DetectionRandomRotate90"
DetectionHorizontalFlip = "DetectionHorizontalFlip"
DetectionRescale = "DetectionRescale"
DetectionPaddedRescale = "DetectionPaddedRescale"
DetectionTargetsFormatTransform = "DetectionTargetsFormatTransform"
RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation"
Expand Down
6 changes: 6 additions & 0 deletions src/super_gradients/training/transforms/all_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
DetectionRandomAffine,
DetectionMixup,
DetectionHSV,
DetectionRGB2BGR,
DetectionRandomRotate90,
DetectionHorizontalFlip,
DetectionRescale,
DetectionPaddedRescale,
DetectionTargetsFormatTransform,
Standardize,
Expand Down Expand Up @@ -79,7 +82,10 @@
Transforms.DetectionRandomAffine: DetectionRandomAffine,
Transforms.DetectionMixup: DetectionMixup,
Transforms.DetectionHSV: DetectionHSV,
Transforms.DetectionRGB2BGR: DetectionRGB2BGR,
Transforms.DetectionRandomRotate90: DetectionRandomRotate90,
Transforms.DetectionHorizontalFlip: DetectionHorizontalFlip,
Transforms.DetectionRescale: DetectionRescale,
Transforms.DetectionPaddedRescale: DetectionPaddedRescale,
Transforms.DetectionTargetsFormatTransform: DetectionTargetsFormatTransform,
Transforms.RandomResizedCropAndInterpolation: RandomResizedCropAndInterpolation,
Expand Down
120 changes: 120 additions & 0 deletions src/super_gradients/training/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,126 @@ def __call__(self, sample):
return sample


class DetectionRescale(DetectionTransform):
    """
    Stretch the image to a fixed size (aspect ratio is NOT preserved) and scale the boxes to match.

    Attributes:
        input_dim: (tuple) target image size given as (rows, cols).
        swap: axis permutation applied to the resized image via ``ndarray.transpose``.
    """

    def __init__(self, input_dim: Tuple[int, int], swap=(2, 0, 1)):
        super().__init__()
        self.input_dim = input_dim
        self.swap = swap

    def __call__(self, sample: Dict[str, np.array]):
        """
        Resize ``sample["image"]`` and rescale ``sample["target"]`` (and ``sample["crowd_target"]`` when present).

        :param sample: Dict holding "image", "target" and optionally "crowd_target".
        :return: The same dict, updated in place.
        """
        image = sample["image"]
        boxes = sample["target"]
        crowd_boxes = sample.get("crowd_target")

        resized, factors = self._rescale_image(image)

        # Reorder the axes as configured and hand back a float32 copy
        sample["image"] = resized.transpose(self.swap).astype(np.float32, copy=True)
        sample["target"] = self._rescale_target(boxes, factors)
        if crowd_boxes is not None:
            sample["crowd_target"] = self._rescale_target(crowd_boxes, factors)
        return sample

    def _rescale_image(self, image):
        """
        Resize the image to ``self.input_dim`` with bilinear interpolation.

        :param image: Image array; its first two axes are read as (rows, cols).
        :return: Tuple (resized image, (row scale factor, column scale factor)).
        """
        out_rows, out_cols = self.input_dim[0], self.input_dim[1]
        scale_factors = (out_rows / image.shape[0], out_cols / image.shape[1])

        # cv2 expects dsize as (width, height), i.e. (cols, rows)
        resized_img = cv2.resize(
            image,
            dsize=(int(out_cols), int(out_rows)),
            interpolation=cv2.INTER_LINEAR,
        )
        return resized_img, scale_factors

    def _rescale_target(self, targets: np.array, scale_factors: Tuple[float, float]) -> np.array:
        """
        Rescale the targets with the same factors that were used to resize the image,
        so that image and targets stay at the same scale.

        :param targets: Target XYXY bboxes to rescale, shape (num_boxes, 5)
        :param scale_factors: (row scale factor, column scale factor) applied to the image
        :return: Rescaled targets as a new float32 array, shape (num_boxes, 5)
        """
        sy, sx = scale_factors
        if len(targets) > 0:
            rescaled = targets.astype(np.float32, copy=True)
        else:
            # No boxes: return an empty array with the expected layout
            rescaled = np.zeros((0, 5), dtype=np.float32)

        # X coordinates scale with the column factor, Y coordinates with the row factor
        rescaled[:, 0:4] *= np.array([[sx, sy, sx, sy]], dtype=rescaled.dtype)
        return rescaled


class DetectionRandomRotate90(DetectionTransform):
    """
    Randomly rotate the image and its bounding boxes by a multiple of 90 degrees (see np.rot90).

    Attributes:
        prob: (float) probability to apply the transform.
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def __call__(self, sample: dict) -> dict:
        """
        With probability ``self.prob``, rotate "image", "target" and (when present) "crowd_target".

        :param sample: Dict holding "image", "target" and optionally "crowd_target".
        :return: The same dict, updated in place when the transform is applied.
        """
        if random.random() < self.prob:
            # Number of 90-degree CCW turns; 0 is a possible draw and leaves the content unchanged
            k = random.randrange(0, 4)

            img, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target")

            sample["image"] = np.ascontiguousarray(np.rot90(img, k))
            # Boxes are rotated relative to the shape of the image BEFORE rotation
            sample["target"] = self.rotate_bboxes(targets, k, img.shape[:2])
            if crowd_targets is not None:
                sample["crowd_target"] = self.rotate_bboxes(crowd_targets, k, img.shape[:2])

        return sample

    @classmethod
    def rotate_bboxes(cls, targets, k: int, image_shape):
        """
        Rotate the XYXY boxes stored in the first four columns of ``targets``.

        :param targets: 2-D array whose columns 0..3 are (x_min, y_min, x_max, y_max).
        :param k: Number of CCW rotations. Must be in set {0, 1, 2, 3}.
        :param image_shape: (rows, cols) of the image before rotation.
        :return: ``targets`` itself when k == 0, otherwise a rotated copy.
        """
        if k == 0:
            return targets
        rows, cols = image_shape
        targets = targets.copy()
        targets[:, 0:4] = cls.xyxy_bbox_rot90(targets[:, 0:4], k, rows, cols)
        return targets

    @classmethod
    def xyxy_bbox_rot90(cls, bboxes, factor: int, rows: int, cols: int):
        """Rotates bounding boxes by 90 degrees CCW (see np.rot90)
        Args:
            bboxes: 2-D array of bounding boxes, one (x_min, y_min, x_max, y_max) per row.
            factor: Number of CCW rotations. Must be in set {0, 1, 2, 3} See np.rot90.
            rows: Image rows (before rotation).
            cols: Image cols (before rotation).
        Returns:
            Array of rotated boxes, one (x_min, y_min, x_max, y_max) per row.
        Raises:
            ValueError: If factor is not in set {0, 1, 2, 3}.
        """
        x_min, y_min, x_max, y_max = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]

        if factor == 0:
            bbox = x_min, y_min, x_max, y_max
        elif factor == 1:
            bbox = y_min, cols - x_max, y_max, cols - x_min
        elif factor == 2:
            bbox = cols - x_max, rows - y_max, cols - x_min, rows - y_min
        elif factor == 3:
            bbox = rows - y_max, x_min, rows - y_min, x_max
        else:
            raise ValueError("Parameter factor must be in set {0, 1, 2, 3}, got: " + str(factor))
        return np.stack(bbox, axis=1)


class DetectionRGB2BGR(DetectionTransform):
    """
    Detection change Red & Blue channel of the image
    Attributes:
        prob: (float) probability to apply the transform.
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def __call__(self, sample: dict) -> dict:
        """
        With probability ``self.prob``, reverse the order of the last (channel) axis of ``sample["image"]``.

        :param sample: Dict holding "image", a 3-dimensional array whose last axis has 3 channels.
        :return: The same dict, updated in place when the transform is applied.
        :raises ValueError: If the image is not 3-dimensional or does not have 3 channels.
        """
        image = sample["image"]

        # Fail with a clear error instead of an IndexError on shape[2] for e.g. 2-D grayscale input
        if image.ndim != 3:
            raise ValueError("DetectionRGB2BGR expects a 3-dimensional image, got shape: " + str(image.shape))
        if image.shape[2] != 3:
            raise ValueError("DetectionRGB2BGR expects image to have 3 channels, got: " + str(image.shape[2]))

        if random.random() < self.prob:
            # [..., ::-1] alone yields a negative-stride view; materialize it so downstream consumers
            # that require positive strides (e.g. torch.from_numpy) can use the result
            sample["image"] = np.ascontiguousarray(image[..., ::-1])
        return sample


class DetectionHSV(DetectionTransform):
"""
Detection HSV transform.
Expand Down