Feature/sg 747 add preprocessing (#804)
* wip

* move to imageprocessors

* wip

* add back changes

* making it work fully for yolox and almost for ppyoloe

* minor change

* working for det

* cleaning

* clean

* undo

* replace empty with none

* add _get_shift_params

* minor doc change

* replace pydantic with dataclasses and fix typing

* add docstrings

* doc improvement and use get_shift_params in transforms

* add tests

* improve comment

* rename

* add option to keep ratio in rescale

* make functions private

* remove DetectionPaddedRescale

* fix doc

* add fixes

* improve _get_center_padding_params output

* minor fix

* add empty bbox test for rescale_bboxes

* finalizing _DetectionPadding, DetectionCenterPadding and DetectionBottomRightPadding

* remove _pad_to_side

* split rescale into 2 classes

* minor addition

* Add DetectionPrediction object

* simplify DetectionPrediction class

* add round and don't rescale if no change required

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Louis-Dupont and BloodAxe committed Apr 17, 2023
1 parent 64e96b2 commit 82c03df
Showing 8 changed files with 607 additions and 137 deletions.
@@ -9,4 +9,4 @@ anchors:
yolo_type: 'yoloX'

depth_mult_factor: 0.33
-width_mult_factor: 0.5
+width_mult_factor: 0.5
@@ -24,4 +24,4 @@ val_dataset_params:
mean: [0.4802, 0.4481, 0.3975]
std: [0.2770, 0.2691, 0.2821]

-_convert_: all
+_convert_: all
41 changes: 41 additions & 0 deletions src/super_gradients/training/models/predictions.py
@@ -0,0 +1,41 @@
from typing import Tuple
from abc import ABC
from dataclasses import dataclass

import numpy as np

from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory
from super_gradients.training.datasets.data_formats.bbox_formats import convert_bboxes


@dataclass
class Prediction(ABC):
    pass


@dataclass
class DetectionPrediction(Prediction):
    """Represents a detection prediction, with bboxes represented in xyxy format."""

    bboxes_xyxy: np.ndarray
    confidence: np.ndarray
    labels: np.ndarray

    def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
        """
        :param bboxes:      BBoxes in the format specified by bbox_format
        :param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...)
        :param confidence:  Confidence scores for each bounding box
        :param labels:      Labels for each bounding box.
        :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
        """
        factory = BBoxFormatFactory()
        self.bboxes_xyxy = convert_bboxes(
            bboxes=bboxes,
            image_shape=image_shape,
            source_format=factory.get(bbox_format),
            target_format=factory.get("xyxy"),
            inplace=False,
        )
        self.confidence = confidence
        self.labels = labels
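For orientation, here is a minimal usage sketch of the new DetectionPrediction class. This snippet is not part of the diff: the array values and image shape are made up, and the only format strings taken from the code above are "xyxy" (used internally) and "cxywh" (mentioned in the docstring).

import numpy as np
from super_gradients.training.models.predictions import DetectionPrediction

# Made-up detections for a single 480x640 image, already expressed in xyxy format.
prediction = DetectionPrediction(
    bboxes=np.array([[40.0, 35.0, 60.0, 45.0]]),
    bbox_format="xyxy",        # per the docstring, other formats such as "cxywh" should also be accepted
    confidence=np.array([0.9]),
    labels=np.array([3]),
    image_shape=(480, 640),    # (H, W), used by convert_bboxes for formats that need it
)

# Whatever the input format, boxes are stored internally as xyxy.
print(prediction.bboxes_xyxy, prediction.confidence, prediction.labels)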
204 changes: 204 additions & 0 deletions src/super_gradients/training/transforms/processing.py
@@ -0,0 +1,204 @@
from typing import Tuple, List, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np

from super_gradients.training.models.predictions import Prediction, DetectionPrediction
from super_gradients.training.transforms.utils import (
    _rescale_image,
    _rescale_bboxes,
    _get_center_padding_coordinates,
    _get_bottom_right_padding_coordinates,
    _pad_image,
    _shift_bboxes,
    PaddingCoordinates,
)


@dataclass
class ProcessingMetadata(ABC):
    """Metadata including information to postprocess a prediction."""


@dataclass
class ComposeProcessingMetadata(ProcessingMetadata):
    metadata_lst: List[Union[None, ProcessingMetadata]]


@dataclass
class DetectionPadToSizeMetadata(ProcessingMetadata):
    padding_coordinates: PaddingCoordinates


@dataclass
class RescaleMetadata(ProcessingMetadata):
    original_shape: Tuple[int, int]
    scale_factor_h: float
    scale_factor_w: float


class Processing(ABC):
    """Interface for preprocessing and postprocessing methods that are
    used to prepare images for a model and process the model's output.
    Subclasses should implement the `preprocess_image` and `postprocess_predictions`
    methods according to the specific requirements of the model and task.
    """

    @abstractmethod
    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, Union[None, ProcessingMetadata]]:
        """Process an image before feeding it to the network. The image is expected to be in (H, W, C) or (H, W) format."""
        pass

    @abstractmethod
    def postprocess_predictions(self, predictions: Prediction, metadata: Union[None, ProcessingMetadata]) -> Prediction:
        """Postprocess the model output predictions."""
        pass


class ComposeProcessing(Processing):
    """Compose a list of Processing objects into a single Processing object."""

    def __init__(self, processings: List[Processing]):
        self.processings = processings

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, ComposeProcessingMetadata]:
        """Process an image before feeding it to the network."""
        processed_image, metadata_lst = image.copy(), []
        for processing in self.processings:
            processed_image, metadata = processing.preprocess_image(image=processed_image)
            metadata_lst.append(metadata)
        return processed_image, ComposeProcessingMetadata(metadata_lst=metadata_lst)

    def postprocess_predictions(self, predictions: Prediction, metadata: ComposeProcessingMetadata) -> Prediction:
        """Postprocess the model output predictions by applying each step's postprocessing in reverse order."""
        postprocessed_predictions = predictions
        for processing, metadata in zip(self.processings[::-1], metadata.metadata_lst[::-1]):
            postprocessed_predictions = processing.postprocess_predictions(postprocessed_predictions, metadata)
        return postprocessed_predictions


class ImagePermute(Processing):
    """Permute the image dimensions.
    :param permutation: Specify the new order of dims. The default value (2, 0, 1) is suitable for converting from HWC to CHW format.
    """

    def __init__(self, permutation: Tuple[int, int, int] = (2, 0, 1)):
        self.permutation = permutation

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
        processed_image = np.ascontiguousarray(image.transpose(*self.permutation))
        return processed_image, None

    def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction:
        return predictions


class NormalizeImage(Processing):
    """Normalize an image based on means and standard deviation.
    :param mean: Mean values for each channel.
    :param std:  Standard deviation values for each channel.
    """

    def __init__(self, mean: List[float], std: List[float]):
        self.mean = np.array(mean).reshape((1, 1, -1)).astype(np.float32)
        self.std = np.array(std).reshape((1, 1, -1)).astype(np.float32)

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
        return (image - self.mean) / self.std, None

    def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Prediction:
        return predictions


class _DetectionPadding(Processing, ABC):
    """Base class for detection padding methods. One should implement the `_get_padding_params` method to work with a custom padding method.
    Note: This transformation assumes that the dimensions of the input image are less than or equal to `output_shape`.
    :param output_shape: Output image shape (H, W)
    :param pad_value:    Padding value for image
    """

    def __init__(self, output_shape: Tuple[int, int], pad_value: int):
        self.output_shape = output_shape
        self.pad_value = pad_value

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]:
        padding_coordinates = self._get_padding_params(input_shape=image.shape)
        processed_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=self.pad_value)
        return processed_image, DetectionPadToSizeMetadata(padding_coordinates=padding_coordinates)

    def postprocess_predictions(self, predictions: DetectionPrediction, metadata: DetectionPadToSizeMetadata) -> DetectionPrediction:
        # Undo the padding by shifting boxes back by the top/left offsets recorded during preprocessing.
        predictions.bboxes_xyxy = _shift_bboxes(
            targets=predictions.bboxes_xyxy,
            shift_h=-metadata.padding_coordinates.top,
            shift_w=-metadata.padding_coordinates.left,
        )
        return predictions

    @abstractmethod
    def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
        pass


class DetectionCenterPadding(_DetectionPadding):
    def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
        return _get_center_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape)


class DetectionBottomRightPadding(_DetectionPadding):
    def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates:
        return _get_bottom_right_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape)


class _Rescale(Processing, ABC):
    """Resize an image to the given dimensions WITHOUT preserving aspect ratio.
    :param output_shape: (H, W)
    """

    def __init__(self, output_shape: Tuple[int, int]):
        self.output_shape = output_shape

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]:
        scale_factor_h, scale_factor_w = self.output_shape[0] / image.shape[0], self.output_shape[1] / image.shape[1]
        rescaled_image = _rescale_image(image, target_shape=self.output_shape)

        return rescaled_image, RescaleMetadata(original_shape=image.shape[:2], scale_factor_h=scale_factor_h, scale_factor_w=scale_factor_w)


class _LongestMaxSizeRescale(Processing, ABC):
    """Resize an image to the given dimensions WHILE preserving aspect ratio.
    :param output_shape: (H, W)
    """

    def __init__(self, output_shape: Tuple[int, int]):
        self.output_shape = output_shape

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]:
        height, width = image.shape[:2]
        scale_factor = min(self.output_shape[0] / height, self.output_shape[1] / width)

        if scale_factor != 1.0:
            new_height, new_width = round(height * scale_factor), round(width * scale_factor)
            image = _rescale_image(image, target_shape=(new_height, new_width))

        return image, RescaleMetadata(original_shape=(height, width), scale_factor_h=scale_factor, scale_factor_w=scale_factor)


class DetectionRescale(_Rescale):
    def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction:
        predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))
        return predictions


class DetectionLongestMaxSizeRescale(_LongestMaxSizeRescale):
    def postprocess_predictions(self, predictions: DetectionPrediction, metadata: RescaleMetadata) -> DetectionPrediction:
        predictions.bboxes_xyxy = _rescale_bboxes(targets=predictions.bboxes_xyxy, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))
        return predictions
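To show how these classes are intended to fit together, here is a rough end-to-end sketch. It is not taken from this PR: the particular pipeline composition, the 640x640 output shape, the pad value, the dummy image, and the hand-written detections are all illustrative placeholders, and the model inference/decoding step is elided.

import numpy as np
from super_gradients.training.models.predictions import DetectionPrediction
from super_gradients.training.transforms.processing import (
    ComposeProcessing,
    DetectionLongestMaxSizeRescale,
    DetectionCenterPadding,
    ImagePermute,
)

# Preprocessing pipeline: resize keeping aspect ratio, pad to a fixed size, then HWC -> CHW.
processing = ComposeProcessing(
    [
        DetectionLongestMaxSizeRescale(output_shape=(640, 640)),
        DetectionCenterPadding(output_shape=(640, 640), pad_value=114),  # pad_value is a placeholder
        ImagePermute(permutation=(2, 0, 1)),
    ]
)

image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)  # dummy (H, W, C) image
model_input, metadata = processing.preprocess_image(image=image)
# model_input now has shape (3, 640, 640); metadata records the rescale factors and padding used.

# ... run the model on model_input and decode its raw output into a DetectionPrediction ...
prediction = DetectionPrediction(
    bboxes=np.array([[100.0, 120.0, 300.0, 400.0]]),  # made-up boxes in the rescaled/padded image space
    bbox_format="xyxy",
    confidence=np.array([0.8]),
    labels=np.array([0]),
    image_shape=model_input.shape[1:],
)

# Postprocessing walks the pipeline in reverse, undoing the padding shift and the rescale,
# so the boxes end up expressed in the original image's coordinate system.
prediction = processing.postprocess_predictions(predictions=prediction, metadata=metadata)
print(prediction.bboxes_xyxy)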