Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 747 add preprocessing #804

Merged
merged 41 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
43176f2
wip
Louis-Dupont Mar 26, 2023
5a0023b
move to imageprocessors
Louis-Dupont Mar 26, 2023
1aacdfa
Merge branch 'master' into feature/SG-747-add_image_processor
Louis-Dupont Mar 26, 2023
89c48a5
wip
Louis-Dupont Mar 27, 2023
6958813
add back changes
Louis-Dupont Mar 27, 2023
4ae57b1
making it work fully for yolox and almost for ppyoloe
Louis-Dupont Mar 27, 2023
2700b80
minor change
Louis-Dupont Mar 27, 2023
b48c596
working for det
Louis-Dupont Mar 28, 2023
e5366c5
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 28, 2023
0ac4fe8
cleaning
Louis-Dupont Mar 28, 2023
24c16c8
clean
Louis-Dupont Mar 28, 2023
2735cf8
undo
Louis-Dupont Mar 28, 2023
3587cee
replace empty with none
Louis-Dupont Mar 28, 2023
4a50611
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 28, 2023
6a4250e
add _get_shift_params
Louis-Dupont Mar 28, 2023
061aa5d
minor doc change
Louis-Dupont Mar 28, 2023
0031494
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 29, 2023
2464398
replace pydantic with dataclasses and fix typing
Louis-Dupont Mar 29, 2023
d4c0774
add docstrings
Louis-Dupont Mar 29, 2023
cf19765
doc improvment and use get_shift_params in transforms
Louis-Dupont Mar 29, 2023
7e8ad22
add tests
Louis-Dupont Mar 29, 2023
90f708e
improve comment
Louis-Dupont Mar 29, 2023
8830ba9
rename
Louis-Dupont Mar 29, 2023
efd58d4
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 29, 2023
74379c6
add option to keep ratio in rescale
Louis-Dupont Mar 29, 2023
efbde36
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 29, 2023
efd023e
make functions private
Louis-Dupont Mar 29, 2023
008b77b
remove DetectionPaddedRescale
Louis-Dupont Mar 29, 2023
77addfa
fix doc
Louis-Dupont Mar 29, 2023
d6c0f9b
add fixes
Louis-Dupont Mar 30, 2023
0cb58e2
improve _get_center_padding_params output
Louis-Dupont Mar 30, 2023
f0baed7
minor fix
Louis-Dupont Mar 30, 2023
1a32cf2
add empty bbox test for rescale_bboxes
Louis-Dupont Mar 30, 2023
dcfd902
finalizing _DetectionPadding, DetectionCenterPadding and DetectionBot…
Louis-Dupont Mar 30, 2023
858ecc0
remove _pad_to_side
Louis-Dupont Mar 30, 2023
a19f591
split rescale into 2 classes
Louis-Dupont Mar 30, 2023
3229c54
minor addition
Louis-Dupont Mar 30, 2023
b012d46
Add DetectionPrediction object
Louis-Dupont Apr 2, 2023
3571780
simplify DetectionPrediction class
Louis-Dupont Apr 3, 2023
7b73edb
add round and don't rescale if no change required
Louis-Dupont Apr 3, 2023
68e5097
Merge branch 'master' into feature/SG-747-add_preprocessing
BloodAxe Apr 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ anchors:
yolo_type: 'yoloX'

depth_mult_factor: 0.33
width_mult_factor: 0.5
width_mult_factor: 0.5
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ val_dataset_params:
mean: [0.4802, 0.4481, 0.3975]
std: [0.2770, 0.2691, 0.2821]

_convert_: all
_convert_: all
183 changes: 183 additions & 0 deletions src/super_gradients/training/transforms/processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from typing import Tuple, List, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np

from super_gradients.training.transforms.utils import (
rescale_image,
rescale_bboxes,
shift_image,
shift_bboxes,
rescale_and_pad_to_size,
rescale_xyxy_bboxes,
get_center_padding_params,
)


@dataclass
class ProcessingMetadata(ABC):
    """Base class for the information a processing step records in order to postprocess a prediction.

    Declares no fields of its own; concrete subclasses add whatever they need.
    """


@dataclass
class ComposeProcessingMetadata(ProcessingMetadata):
    """Metadata of a composed pipeline, gathering the metadata of each individual step."""

    # One entry per processing step, in the order the steps were applied.
    # An entry is None when the corresponding step returned no metadata.
    metadata_lst: List[Union[None, ProcessingMetadata]]


@dataclass
class DetectionPadToSizeMetadata(ProcessingMetadata):
    """Shifts introduced by padding, kept so that predicted boxes can be shifted back."""

    # Shift along the width axis; its negation is applied to the boxes on postprocessing.
    shift_w: float
    # Shift along the height axis; its negation is applied to the boxes on postprocessing.
    shift_h: float


@dataclass
class RescaleMetadata(ProcessingMetadata):
    """Information recorded by a rescale step so that its effect on predictions can be reverted."""

    # First two dimensions of the image as it was before rescaling.
    original_size: Tuple[int, int]
    # Scale factor along the first image axis: output_shape[0] / image.shape[0].
    sy: float
    # Scale factor along the second image axis: output_shape[1] / image.shape[1].
    sx: float


@dataclass
class DetectionPaddedRescaleMetadata(ProcessingMetadata):
    """Information recorded by `DetectionPaddedRescale` so that predicted boxes can be rescaled back."""

    # Ratio returned by `rescale_and_pad_to_size`; predictions are rescaled by 1 / r on postprocessing.
    r: float


class Processing(ABC):
    """Interface of a processing step made of an image preprocessing and the matching prediction postprocessing.

    `preprocess_image` returns, next to the processed image, the metadata that `postprocess_predictions`
    later receives together with the predictions.
    """

    @abstractmethod
    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, Union[None, ProcessingMetadata]]:
        """Process an image before feeding it to the network.

        :param image: Image to process. Expected to be in (H, W, C) or (H, W).
        :return:      Tuple of the processed image and the metadata of this step (None if the step records nothing).
        """

    @abstractmethod
    def postprocess_predictions(self, predictions: np.ndarray, metadata: Union[None, ProcessingMetadata]) -> np.ndarray:
        """Postprocess the model output predictions.

        :param predictions: Predictions to postprocess.
        :param metadata:    Metadata returned by `preprocess_image` for the corresponding image.
        :return:            The postprocessed predictions.
        """


class ComposeProcessing(Processing):
    """Compose a list of Processing objects into a single Processing object.

    :param processings: Processing steps. They are applied in list order when preprocessing,
                        and in reverse order when postprocessing.
    """

    def __init__(self, processings: List[Processing]):
        self.processings = processings

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, ComposeProcessingMetadata]:
        """Run every processing step on the image, before feeding it to the network.

        :param image: Image to process. The steps are run on a copy of it.
        :return:      Tuple of the processed image and the metadata of every step, in the order the steps were applied.
        """
        processed_image, metadata_lst = image.copy(), []
        for processing in self.processings:
            processed_image, step_metadata = processing.preprocess_image(image=processed_image)
            metadata_lst.append(step_metadata)
        return processed_image, ComposeProcessingMetadata(metadata_lst=metadata_lst)

    def postprocess_predictions(self, predictions: np.ndarray, metadata: ComposeProcessingMetadata) -> np.ndarray:
        """Postprocess the model output predictions by reverting the steps, last applied step first.

        :param predictions: Predictions to postprocess.
        :param metadata:    Metadata returned by `preprocess_image`; must hold exactly one entry per processing step.
        :return:            The postprocessed predictions.
        :raises ValueError: If the number of metadata entries differs from the number of processing steps.
        """
        # zip() would silently drop the unmatched steps, returning predictions that are only partially reverted.
        if len(metadata.metadata_lst) != len(self.processings):
            raise ValueError(
                f"Expected one metadata entry per processing step ({len(self.processings)}), " f"but got {len(metadata.metadata_lst)}."
            )

        postprocessed_predictions = predictions
        # Distinct loop name so the `metadata` argument is not shadowed while iterating.
        for processing, step_metadata in zip(self.processings[::-1], metadata.metadata_lst[::-1]):
            postprocessed_predictions = processing.postprocess_predictions(postprocessed_predictions, step_metadata)
        return postprocessed_predictions


class ImagePermute(Processing):
    """Reorder the dimensions of an image. Predictions are left untouched.

    :param permutation: Specify new order of dims. Default value (2, 0, 1) suitable for converting from HWC to CHW format.
    """

    def __init__(self, permutation: Tuple[int, int, int] = (2, 0, 1)):
        self.permutation = permutation

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
        """Permute the image dimensions and return the result as a contiguous array, with no metadata."""
        permuted_image = image.transpose(*self.permutation)
        return np.ascontiguousarray(permuted_image), None

    def postprocess_predictions(self, predictions: np.ndarray, metadata: None) -> np.ndarray:
        """Return the predictions unchanged."""
        return predictions


class NormalizeImage(Processing):
    """Normalize an image by subtracting a mean and dividing by a standard deviation. Predictions are left untouched.

    :param mean: Mean values for each channel.
    :param std:  Standard deviation values for each channel.
    """

    def __init__(self, mean: List[float], std: List[float]):
        # Values are laid out along the last axis of a (1, 1, N) array and stored as float32.
        broadcast_shape = (1, 1, -1)
        self.mean = np.array(mean).reshape(broadcast_shape).astype(np.float32)
        self.std = np.array(std).reshape(broadcast_shape).astype(np.float32)

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
        """Return `(image - mean) / std`, with no metadata."""
        centered_image = image - self.mean
        return centered_image / self.std, None

    def postprocess_predictions(self, predictions: np.ndarray, metadata: None) -> np.ndarray:
        """Return the predictions unchanged."""
        return predictions


class DetectionPaddedRescale(Processing):
    """Apply padding rescaling to image and bboxes to `output_size` shape (rows, cols).

    :param output_size: Target input dimension.
    :param swap:        Image axes to be rearranged; forwarded as `swap` to `rescale_and_pad_to_size`.
    :param pad_value:   Padding value for image.
    """

    def __init__(self, output_size: Tuple[int, int], swap: Tuple[int, ...] = (2, 0, 1), pad_value: int = 114):
        self.output_size = output_size
        self.swap = swap
        self.pad_value = pad_value

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPaddedRescaleMetadata]:
        """Rescale and pad the image with `rescale_and_pad_to_size`.

        :param image: Image to process.
        :return:      Tuple of the processed image and the metadata holding the ratio `r` returned by the helper.
        """
        rescaled_image, r = rescale_and_pad_to_size(image=image, output_size=self.output_size, swap=self.swap, pad_val=self.pad_value)
        return rescaled_image, DetectionPaddedRescaleMetadata(r=r)

    # `metadata` used to be declared as `metadata=DetectionPaddedRescaleMetadata`, i.e. with the class object as a
    # default value instead of a type annotation. It is now a required, annotated argument, like in `Processing`.
    def postprocess_predictions(self, predictions: np.ndarray, metadata: DetectionPaddedRescaleMetadata) -> np.ndarray:
        """Rescale the predicted boxes by the inverse of the ratio used when preprocessing.

        :param predictions: Predictions to postprocess (presumably boxes in xyxy format, as the helper name suggests - TODO confirm).
        :param metadata:    Metadata returned by `preprocess_image` for the corresponding image.
        :return:            Output of `rescale_xyxy_bboxes` called with `r = 1 / metadata.r`.
        """
        return rescale_xyxy_bboxes(targets=predictions, r=1 / metadata.r)


class DetectionPadToSize(Processing):
    """Preprocessing transform to pad image and bboxes to `output_size` shape (rows, cols).
    Center padding, so that the input image with its bboxes is located in the center of the produced image.

    Note: This transformation assumes that the dimensions of the input image are equal to or less than `output_size`.

    :param output_size: Output image size (rows, cols)
    :param pad_value:   Padding value for image
    """

    def __init__(self, output_size: Tuple[int, int], pad_value: int):
        self.output_size = output_size
        self.pad_value = pad_value

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]:
        """Pad the image and record the resulting shifts.

        :param image: Image to pad.
        :return:      Tuple of the padded image and the metadata holding the shifts along height and width.
        """
        padding_params = get_center_padding_params(input_size=image.shape, output_size=self.output_size)
        shift_h, shift_w, pad_h, pad_w = padding_params

        padded_image = shift_image(image, pad_h, pad_w, self.pad_value)
        padding_metadata = DetectionPadToSizeMetadata(shift_h=shift_h, shift_w=shift_w)
        return padded_image, padding_metadata

    def postprocess_predictions(self, predictions: np.ndarray, metadata: DetectionPadToSizeMetadata) -> np.ndarray:
        """Shift the predicted boxes by the negated shifts recorded when preprocessing."""
        undo_shift_w, undo_shift_h = -metadata.shift_w, -metadata.shift_h
        return shift_bboxes(targets=predictions, shift_w=undo_shift_w, shift_h=undo_shift_h)


class _Rescale(Processing, ABC):
    """Resize image to given image dimensions without preserving aspect ratio.

    Only the preprocessing is implemented here; subclasses provide `postprocess_predictions`.

    :param output_shape: (rows, cols)
    """

    def __init__(self, output_shape: Tuple[int, int]):
        self.output_shape = output_shape

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]:
        """Rescale the image to `output_shape`.

        :param image: Image to rescale.
        :return:      Tuple of the rescaled image and the metadata holding the original size and both scale factors.
        """
        # Each axis gets its own factor: output dimension divided by input dimension.
        sy = self.output_shape[0] / image.shape[0]
        sx = self.output_shape[1] / image.shape[1]

        rescaled_image = rescale_image(image, target_shape=self.output_shape)
        rescale_metadata = RescaleMetadata(original_size=image.shape[:2], sy=sy, sx=sx)
        return rescaled_image, rescale_metadata


class DetectionRescale(_Rescale):
    """Rescale step for detection: predicted boxes are rescaled with the inverse of the preprocessing scale factors."""

    def postprocess_predictions(self, predictions: np.ndarray, metadata: RescaleMetadata) -> np.ndarray:
        """Rescale the predicted boxes with the factors `(1 / sy, 1 / sx)`."""
        inverse_scale_factors = (1 / metadata.sy, 1 / metadata.sx)
        return rescale_bboxes(targets=predictions, scale_factors=inverse_scale_factors)


class SegmentationRescale(_Rescale):
    """Rescale step for segmentation: predictions are rescaled to the size the image had before preprocessing."""

    def postprocess_predictions(self, predictions: np.ndarray, metadata: RescaleMetadata) -> np.ndarray:
        """Rescale the predictions to `metadata.original_size`."""
        original_size = metadata.original_size
        return rescale_image(predictions, target_shape=original_size)
Loading