LitData Refactor PR1: Get individual functions for data pipelines (#81)
* Add functions for data pipelines

* Add test cases

* Format file

* Add more test cases

* Fix augmentation test
gitttt-1234 authored Sep 11, 2024
1 parent 5a5d590 commit 3514464
Showing 18 changed files with 831 additions and 50 deletions.
222 changes: 221 additions & 1 deletion sleap_nn/data/augmentation.py
@@ -1,7 +1,6 @@
"""This module implements data pipeline blocks for augmentation operations."""

from typing import Any, Dict, Optional, Tuple, Union, Iterator

import kornia as K
import torch
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
@@ -11,6 +10,227 @@
from torch.utils.data.datapipes.datapipe import IterDataPipe


def apply_intensity_augmentation(
    image: torch.Tensor,
    instances: torch.Tensor,
    uniform_noise_min: Optional[float] = 0.0,
    uniform_noise_max: Optional[float] = 0.04,
    uniform_noise_p: float = 0.0,
    gaussian_noise_mean: Optional[float] = 0.02,
    gaussian_noise_std: Optional[float] = 0.004,
    gaussian_noise_p: float = 0.0,
    contrast_min: Optional[float] = 0.5,
    contrast_max: Optional[float] = 2.0,
    contrast_p: float = 0.0,
    brightness: Optional[float] = 0.0,
    brightness_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply kornia intensity augmentation on image and instances.

    Args:
        image: Input image. Shape: (n_samples, C, H, W).
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
            (n_samples, n_nodes, 2).
        uniform_noise_min: Minimum value for uniform noise (uniform_noise_min >= 0).
        uniform_noise_max: Maximum value for uniform noise (uniform_noise_max <= 1).
        uniform_noise_p: Probability of applying random uniform noise.
        gaussian_noise_mean: The mean of the gaussian distribution.
        gaussian_noise_std: The standard deviation of the gaussian distribution.
        gaussian_noise_p: Probability of applying random gaussian noise.
        contrast_min: Minimum contrast factor to apply. Default: 0.5.
        contrast_max: Maximum contrast factor to apply. Default: 2.0.
        contrast_p: Probability of applying random contrast.
        brightness: The brightness factor to apply. Default: 0.0.
        brightness_p: Probability of applying random brightness.

    Returns:
        A tuple (image, instances) with the augmentations applied.
    """
    aug_stack = []
    if uniform_noise_p > 0:
        aug_stack.append(
            RandomUniformNoise(
                noise=(uniform_noise_min, uniform_noise_max),
                p=uniform_noise_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if gaussian_noise_p > 0:
        aug_stack.append(
            K.augmentation.RandomGaussianNoise(
                mean=gaussian_noise_mean,
                std=gaussian_noise_std,
                p=gaussian_noise_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if contrast_p > 0:
        aug_stack.append(
            K.augmentation.RandomContrast(
                contrast=(contrast_min, contrast_max),
                p=contrast_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if brightness_p > 0:
        aug_stack.append(
            K.augmentation.RandomBrightness(
                brightness=brightness,
                p=brightness_p,
                keepdim=True,
                same_on_batch=True,
            )
        )

    augmenter = AugmentationSequential(
        *aug_stack,
        data_keys=["input", "keypoints"],
        keepdim=True,
        same_on_batch=True,
    )

    inst_shape = instances.shape
    # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # Before (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    instances = instances.reshape(inst_shape[0], -1, 2)
    # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) or (n_samples, n_nodes, 2)

    aug_image, aug_instances = augmenter(image, instances)

    # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    return aug_image, aug_instances.reshape(*inst_shape)
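

# Editor's note: the block below is a hedged usage sketch, not part of this
# commit. It assumes kornia and torch are installed and that shapes follow the
# docstring above; the tensor sizes and probabilities are illustrative only.
def _example_intensity_augmentation() -> None:
    """Sketch: apply gaussian noise and contrast jitter to dummy data."""
    image = torch.rand(1, 1, 384, 384)  # (n_samples, C, H, W)
    instances = torch.rand(1, 2, 13, 2) * 384.0  # (n_samples, n_instances, n_nodes, 2)
    aug_image, aug_instances = apply_intensity_augmentation(
        image,
        instances,
        gaussian_noise_p=0.5,
        contrast_p=0.5,
    )
    # Intensity augmentations leave both shapes unchanged.
    assert aug_image.shape == image.shape
    assert aug_instances.shape == instances.shape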


def apply_geometric_augmentation(
    image: torch.Tensor,
    instances: torch.Tensor,
    rotation: Optional[float] = 15.0,
    scale: Optional[
        Union[float, Tuple[float, float], Tuple[float, float, float, float]]
    ] = None,
    translate_width: Optional[float] = 0.02,
    translate_height: Optional[float] = 0.02,
    affine_p: float = 0.0,
    erase_scale_min: Optional[float] = 0.0001,
    erase_scale_max: Optional[float] = 0.01,
    erase_ratio_min: Optional[float] = 1,
    erase_ratio_max: Optional[float] = 1,
    erase_p: float = 0.0,
    mixup_lambda: Optional[Union[float, Tuple[float, float]]] = None,
    mixup_p: float = 0.0,
    random_crop_height: int = 0,
    random_crop_width: int = 0,
    random_crop_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply kornia geometric augmentation on image and instances.

    Args:
        image: Input image. Shape: (n_samples, C, H, W).
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
            (n_samples, n_nodes, 2).
        rotation: Angle in degrees as a scalar float of the amount of rotation. A
            random angle in `(-rotation, rotation)` will be sampled and applied to
            both images and keypoints. Set to 0 to disable rotation augmentation.
        scale: Scaling factor interval. If (a, b), an isotropic scale is randomly
            sampled from the range a <= scale <= b. If (a, b, c, d), the scale is
            randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d.
            Default: None.
        translate_width: Maximum absolute fraction for horizontal translation. For
            example, if translate_width=a, then horizontal shift is randomly sampled
            in the range -img_width * a < dx < img_width * a. Will not translate by
            default.
        translate_height: Maximum absolute fraction for vertical translation. For
            example, if translate_height=a, then vertical shift is randomly sampled
            in the range -img_height * a < dy < img_height * a. Will not translate by
            default.
        affine_p: Probability of applying random affine transformations.
        erase_scale_min: Minimum value of the range of proportion of erased area
            against the input image. Default: 0.0001.
        erase_scale_max: Maximum value of the range of proportion of erased area
            against the input image. Default: 0.01.
        erase_ratio_min: Minimum value of the range of aspect ratio of erased area.
            Default: 1.
        erase_ratio_max: Maximum value of the range of aspect ratio of erased area.
            Default: 1.
        erase_p: Probability of applying random erase.
        mixup_lambda: Min-max value of mixup strength; if `None`, the default of
            0-1 is used. Default: `None`.
        mixup_p: Probability of applying random mixup v2.
        random_crop_height: Desired output height of the crop. Must be int.
        random_crop_width: Desired output width of the crop. Must be int.
        random_crop_p: Probability of applying random crop.

    Returns:
        A tuple (image, instances) with the augmentations applied.
    """
    if isinstance(scale, float):
        scale = (scale, scale)
    aug_stack = []
    if affine_p > 0:
        aug_stack.append(
            K.augmentation.RandomAffine(
                degrees=rotation,
                translate=(translate_width, translate_height),
                scale=scale,
                p=affine_p,
                keepdim=True,
                same_on_batch=True,
            )
        )

    if erase_p > 0:
        aug_stack.append(
            K.augmentation.RandomErasing(
                scale=(erase_scale_min, erase_scale_max),
                ratio=(erase_ratio_min, erase_ratio_max),
                p=erase_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if mixup_p > 0:
        aug_stack.append(
            K.augmentation.RandomMixUpV2(
                lambda_val=mixup_lambda,
                p=mixup_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if random_crop_p > 0:
        if random_crop_height > 0 and random_crop_width > 0:
            aug_stack.append(
                K.augmentation.RandomCrop(
                    size=(random_crop_height, random_crop_width),
                    pad_if_needed=True,
                    p=random_crop_p,
                    keepdim=True,
                    same_on_batch=True,
                )
            )
        else:
            raise ValueError(
                "random_crop_height and random_crop_width must be greater than 0."
            )

    augmenter = AugmentationSequential(
        *aug_stack,
        data_keys=["input", "keypoints"],
        keepdim=True,
        same_on_batch=True,
    )

    inst_shape = instances.shape
    # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # Before (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    instances = instances.reshape(inst_shape[0], -1, 2)
    # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) or (n_samples, n_nodes, 2)

    aug_image, aug_instances = augmenter(image, instances)

    # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    return aug_image, aug_instances.reshape(*inst_shape)
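

# Editor's note: another hedged usage sketch, not part of this commit. The
# rotation/affine settings are illustrative; the keypoints are transformed
# together with the image via kornia's AugmentationSequential.
def _example_geometric_augmentation() -> None:
    """Sketch: rotate/translate a dummy image and its keypoints together."""
    image = torch.rand(1, 1, 384, 384)  # (n_samples, C, H, W)
    instances = torch.rand(1, 2, 13, 2) * 384.0  # (n_samples, n_instances, n_nodes, 2)
    aug_image, aug_instances = apply_geometric_augmentation(
        image,
        instances,
        rotation=15.0,
        affine_p=1.0,  # always apply the affine transform in this sketch
    )
    # Geometric augmentations also preserve the input shapes.
    assert aug_image.shape == image.shape
    assert aug_instances.shape == instances.shape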


class RandomUniformNoise(IntensityAugmentationBase2D):
    """Data transformer for applying random uniform noise to input images.
88 changes: 87 additions & 1 deletion sleap_nn/data/confidence_maps.py
@@ -1,13 +1,99 @@
"""Generate confidence maps."""

-from typing import Dict, Iterator
+from typing import Dict, Iterator, Tuple

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

from sleap_nn.data.utils import make_grid_vectors


def generate_confmaps(
    instance: torch.Tensor,
    img_hw: Tuple[int, int],
    sigma: float = 1.5,
    output_stride: int = 2,
) -> torch.Tensor:
    """Generate confidence maps.

    Args:
        instance: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
            (n_samples, n_nodes, 2).
        img_hw: Image size as tuple (height, width).
        sigma: The standard deviation of the Gaussian distribution that is used to
            generate confidence maps. Default: 1.5.
        output_stride: The relative stride to use when generating confidence maps.
            A larger stride will generate smaller confidence maps. Default: 2.

    Returns:
        Confidence maps for the input keypoints.
    """
    if instance.ndim != 3:
        instance = instance.view(instance.shape[0], -1, 2)
    # instance: (n_samples, n_nodes, 2)

    height, width = img_hw

    xv, yv = make_grid_vectors(height, width, output_stride)

    confidence_maps = make_confmaps(
        instance,
        xv,
        yv,
        sigma * output_stride,
    )  # (n_samples, n_nodes, height / output_stride, width / output_stride)

    return confidence_maps
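

# Editor's note: a hedged usage sketch, not part of this commit. The expected
# output shape follows the inline comment above (height and width divided by
# `output_stride`); the input sizes are illustrative assumptions.
def _example_generate_confmaps() -> None:
    """Sketch: confidence maps for one dummy instance."""
    instance = torch.rand(1, 13, 2) * 384.0  # (n_samples, n_nodes, 2)
    cms = generate_confmaps(instance, img_hw=(384, 384), sigma=1.5, output_stride=2)
    assert cms.shape == (1, 13, 192, 192)  # (n_samples, n_nodes, H/2, W/2)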


def generate_multiconfmaps(
    instances: torch.Tensor,
    img_hw: Tuple[int, int],
    num_instances: int,
    sigma: float = 1.5,
    output_stride: int = 2,
    is_centroids: bool = False,
) -> torch.Tensor:
    """Generate multi-instance confidence maps.

    Args:
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or,
            for centroids, (n_samples, n_instances, 2).
        img_hw: Image size as tuple (height, width).
        num_instances: Original number of instances in the frame.
        sigma: The standard deviation of the Gaussian distribution that is used to
            generate confidence maps. Default: 1.5.
        output_stride: The relative stride to use when generating confidence maps.
            A larger stride will generate smaller confidence maps. Default: 2.
        is_centroids: True if confidence maps should be generated for centroids,
            else False. Default: False.

    Returns:
        Confidence maps for the input keypoints.
    """
    if is_centroids:
        points = instances[:, :num_instances, :].unsqueeze(dim=-2)
        # (n_samples, n_instances, 1, 2)
    else:
        points = instances[
            :, :num_instances, :, :
        ]  # (n_samples, n_instances, n_nodes, 2)

    height, width = img_hw

    xv, yv = make_grid_vectors(height, width, output_stride)

    confidence_maps = make_multi_confmaps(
        points,
        xv,
        yv,
        sigma * output_stride,
    )  # (n_samples, n_nodes, height / output_stride, width / output_stride).
    # If `is_centroids`, (n_samples, 1, height / output_stride, width / output_stride).

    return confidence_maps
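

# Editor's note: a hedged usage sketch, not part of this commit. It exercises
# both branches of `generate_multiconfmaps`; the centroid tensor built here
# (mean over nodes) is an illustrative stand-in for real centroids.
def _example_generate_multiconfmaps() -> None:
    """Sketch: multi-instance and centroid confidence maps for dummy data."""
    instances = torch.rand(1, 2, 13, 2) * 384.0  # (n_samples, n_instances, n_nodes, 2)
    cms = generate_multiconfmaps(instances, img_hw=(384, 384), num_instances=2)
    assert cms.shape == (1, 13, 192, 192)  # (n_samples, n_nodes, H/2, W/2)

    centroids = instances.mean(dim=2)  # (n_samples, n_instances, 2)
    centroid_cms = generate_multiconfmaps(
        centroids, img_hw=(384, 384), num_instances=2, is_centroids=True
    )
    assert centroid_cms.shape == (1, 1, 192, 192)  # (n_samples, 1, H/2, W/2)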


def make_confmaps(
    points_batch: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
) -> torch.Tensor:
