Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LitData Refactor PR1: Get individual functions for data pipelines #81

Merged
merged 5 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 221 additions & 1 deletion sleap_nn/data/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""This module implements data pipeline blocks for augmentation operations."""

from typing import Any, Dict, Optional, Tuple, Union, Iterator

import kornia as K
import torch
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
Expand All @@ -11,6 +10,227 @@
from torch.utils.data.datapipes.datapipe import IterDataPipe


def apply_intensity_augmentation(
    image: torch.Tensor,
    instances: torch.Tensor,
    uniform_noise_min: Optional[float] = 0.0,
    uniform_noise_max: Optional[float] = 0.04,
    uniform_noise_p: float = 0.0,
    gaussian_noise_mean: Optional[float] = 0.02,
    gaussian_noise_std: Optional[float] = 0.004,
    gaussian_noise_p: float = 0.0,
    contrast_min: Optional[float] = 0.5,
    contrast_max: Optional[float] = 2.0,
    contrast_p: float = 0.0,
    brightness: Optional[float] = 0.0,
    brightness_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply kornia intensity augmentation on image and instances.

    Each augmentation is added to the pipeline only when its probability is
    greater than 0, so the default configuration is a no-op pass-through.

    Args:
        image: Input image. Shape: (n_samples, C, H, W)
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
        uniform_noise_min: Minimum value for uniform noise (uniform_noise_min >=0).
        uniform_noise_max: Maximum value for uniform noise (uniform_noise_max <=1).
        uniform_noise_p: Probability of applying random uniform noise.
        gaussian_noise_mean: The mean of the gaussian distribution.
        gaussian_noise_std: The standard deviation of the gaussian distribution.
        gaussian_noise_p: Probability of applying random gaussian noise.
        contrast_min: Minimum contrast factor to apply. Default: 0.5.
        contrast_max: Maximum contrast factor to apply. Default: 2.0.
        contrast_p: Probability of applying random contrast.
        brightness: The brightness factor to apply. Default: 0.0.
        brightness_p: Probability of applying random brightness.

    Returns:
        Returns tuple: (image, instances) with augmentation applied. The
        instances tensor keeps its original shape.
    """
    aug_stack = []
    if uniform_noise_p > 0:
        aug_stack.append(
            RandomUniformNoise(
                noise=(uniform_noise_min, uniform_noise_max),
                p=uniform_noise_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if gaussian_noise_p > 0:
        aug_stack.append(
            K.augmentation.RandomGaussianNoise(
                mean=gaussian_noise_mean,
                std=gaussian_noise_std,
                p=gaussian_noise_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if contrast_p > 0:
        aug_stack.append(
            K.augmentation.RandomContrast(
                contrast=(contrast_min, contrast_max),
                p=contrast_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if brightness_p > 0:
        aug_stack.append(
            K.augmentation.RandomBrightness(
                brightness=brightness,
                p=brightness_p,
                keepdim=True,
                same_on_batch=True,
            )
        )

    augmenter = AugmentationSequential(
        *aug_stack,
        data_keys=["input", "keypoints"],
        keepdim=True,
        same_on_batch=True,
    )

    inst_shape = instances.shape
    # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    # Kornia expects keypoints as (B, N, 2), so fold all instance/node dims
    # into one axis and restore the original shape afterwards.
    instances = instances.reshape(inst_shape[0], -1, 2)
    # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

    aug_image, aug_instances = augmenter(image, instances)

    # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    return aug_image, aug_instances.reshape(*inst_shape)


def apply_geometric_augmentation(
    image: torch.Tensor,
    instances: torch.Tensor,
    rotation: Optional[float] = 15.0,
    scale: Union[
        Optional[float], Tuple[float, float], Tuple[float, float, float, float]
    ] = None,
    translate_width: Optional[float] = 0.02,
    translate_height: Optional[float] = 0.02,
    affine_p: float = 0.0,
    erase_scale_min: Optional[float] = 0.0001,
    erase_scale_max: Optional[float] = 0.01,
    erase_ratio_min: Optional[float] = 1,
    erase_ratio_max: Optional[float] = 1,
    erase_p: float = 0.0,
    mixup_lambda: Optional[Union[float, Tuple[float, float]]] = None,
    mixup_p: float = 0.0,
    random_crop_height: int = 0,
    random_crop_width: int = 0,
    random_crop_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply kornia geometric augmentation on image and instances.

    Each augmentation is added to the pipeline only when its probability is
    greater than 0, so the default configuration is a no-op pass-through.

    Args:
        image: Input image. Shape: (n_samples, C, H, W)
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
        rotation: Angles in degrees as a scalar float of the amount of rotation. A
            random angle in `(-rotation, rotation)` will be sampled and applied to both
            images and keypoints. Set to 0 to disable rotation augmentation.
        scale: scaling factor interval. If (a, b) represents isotropic scaling, the scale
            is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale
            is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d.
            Default: None.
        translate_width: Maximum absolute fraction for horizontal translation. For example,
            if translate_width=a, then horizontal shift is randomly sampled in the range
            -img_width * a < dx < img_width * a. Will not translate by default.
        translate_height: Maximum absolute fraction for vertical translation. For example,
            if translate_height=a, then vertical shift is randomly sampled in the range
            -img_height * a < dy < img_height * a. Will not translate by default.
        affine_p: Probability of applying random affine transformations.
        erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
        erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
        erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
        erase_ratio_max: Maximum value of range of aspect ratio of erased area. Default: 1.
        erase_p: Probability of applying random erase.
        mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`.
        mixup_p: Probability of applying random mixup v2.
        random_crop_height: Desired output height of the crop. Must be int.
        random_crop_width: Desired output width of the crop. Must be int.
        random_crop_p: Probability of applying random crop.

    Returns:
        Returns tuple: (image, instances) with augmentation applied. The
        instances tensor keeps its original shape.

    Raises:
        ValueError: If `random_crop_p` > 0 but `random_crop_height` or
            `random_crop_width` is not a positive integer.
    """
    # A scalar scale means isotropic scaling; kornia expects a (min, max) pair.
    if isinstance(scale, float):
        scale = (scale, scale)
    aug_stack = []
    if affine_p > 0:
        aug_stack.append(
            K.augmentation.RandomAffine(
                degrees=rotation,
                translate=(translate_width, translate_height),
                scale=scale,
                p=affine_p,
                keepdim=True,
                same_on_batch=True,
            )
        )

    if erase_p > 0:
        aug_stack.append(
            K.augmentation.RandomErasing(
                scale=(erase_scale_min, erase_scale_max),
                ratio=(erase_ratio_min, erase_ratio_max),
                p=erase_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if mixup_p > 0:
        aug_stack.append(
            K.augmentation.RandomMixUpV2(
                lambda_val=mixup_lambda,
                p=mixup_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if random_crop_p > 0:
        if random_crop_height > 0 and random_crop_width > 0:
            aug_stack.append(
                K.augmentation.RandomCrop(
                    size=(random_crop_height, random_crop_width),
                    pad_if_needed=True,
                    p=random_crop_p,
                    keepdim=True,
                    same_on_batch=True,
                )
            )
        else:
            # Was an f-string with no placeholders and referenced a
            # non-existent `random_crop_hw` parameter.
            raise ValueError(
                "random_crop_height and random_crop_width must be greater than 0."
            )

    augmenter = AugmentationSequential(
        *aug_stack,
        data_keys=["input", "keypoints"],
        keepdim=True,
        same_on_batch=True,
    )

    inst_shape = instances.shape
    # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    # Kornia expects keypoints as (B, N, 2), so fold all instance/node dims
    # into one axis and restore the original shape afterwards.
    instances = instances.reshape(inst_shape[0], -1, 2)
    # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

    aug_image, aug_instances = augmenter(image, instances)

    # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    return aug_image, aug_instances.reshape(*inst_shape)


class RandomUniformNoise(IntensityAugmentationBase2D):
"""Data transformer for applying random uniform noise to input images.

Expand Down
88 changes: 87 additions & 1 deletion sleap_nn/data/confidence_maps.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,99 @@
"""Generate confidence maps."""

from typing import Dict, Iterator
from typing import Dict, Iterator, Tuple

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

from sleap_nn.data.utils import make_grid_vectors


def generate_confmaps(
    instance: torch.Tensor,
    img_hw: Tuple[int, int],
    sigma: float = 1.5,
    output_stride: int = 2,
) -> torch.Tensor:
    """Generate confidence maps for a single instance's keypoints.

    Args:
        instance: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
            (n_samples, n_nodes, 2).
        img_hw: Image size as tuple (height, width).
        sigma: The standard deviation of the Gaussian distribution that is used to
            generate confidence maps. Default: 1.5.
        output_stride: The relative stride to use when generating confidence maps.
            A larger stride will generate smaller confidence maps. Default: 2.

    Returns:
        Confidence maps for the input keypoints with shape
        (n_samples, n_nodes, height / output_stride, width / output_stride).
    """
    # Collapse any extra instance dimension so keypoints are
    # (n_samples, n_nodes, 2) before map generation.
    if instance.ndim != 3:
        instance = instance.view(instance.shape[0], -1, 2)
    # instances: (n_samples, n_nodes, 2)

    height, width = img_hw

    # Grid vectors are sampled at `output_stride` spacing, so sigma is scaled
    # by the same factor to keep the Gaussian footprint consistent.
    xv, yv = make_grid_vectors(height, width, output_stride)

    confidence_maps = make_confmaps(
        instance,
        xv,
        yv,
        sigma * output_stride,
    )  # (n_samples, n_nodes, height/ output_stride, width/ output_stride)

    return confidence_maps


def generate_multiconfmaps(
    instances: torch.Tensor,
    img_hw: Tuple[int, int],
    num_instances: int,
    sigma: float = 1.5,
    output_stride: int = 2,
    is_centroids: bool = False,
) -> torch.Tensor:
    """Generate multi-instance confidence maps.

    Args:
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
            for centroids - (n_samples, n_instances, 2)
        img_hw: Image size as tuple (height, width).
        num_instances: Original number of instances in the frame. Instances
            beyond this count (e.g. padding) are excluded from map generation.
        sigma: The standard deviation of the Gaussian distribution that is used to
            generate confidence maps. Default: 1.5.
        output_stride: The relative stride to use when generating confidence maps.
            A larger stride will generate smaller confidence maps. Default: 2.
        is_centroids: True if confidence maps should be generated for centroids
            else False. Default: False.

    Returns:
        Confidence maps for the input keypoints. Shape:
        (n_samples, n_nodes, height / output_stride, width / output_stride),
        or (n_samples, 1, height / output_stride, width / output_stride) if
        `is_centroids`.
    """
    if is_centroids:
        # Add a singleton node axis so centroids match the
        # (n_samples, n_instances, n_nodes, 2) layout expected downstream.
        points = instances[:, :num_instances, :].unsqueeze(dim=-2)
        # (n_samples, n_instances, 1, 2)
    else:
        points = instances[
            :, :num_instances, :, :
        ]  # (n_samples, n_instances, n_nodes, 2)

    height, width = img_hw

    # Grid vectors are sampled at `output_stride` spacing, so sigma is scaled
    # by the same factor to keep the Gaussian footprint consistent.
    xv, yv = make_grid_vectors(height, width, output_stride)

    confidence_maps = make_multi_confmaps(
        points,
        xv,
        yv,
        sigma * output_stride,
    )  # (n_samples, n_nodes, height/ output_stride, width/ output_stride).
    # If `is_centroids`, (n_samples, 1, height/ output_stride, width/ output_stride).

    return confidence_maps


def make_confmaps(
points_batch: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
) -> torch.Tensor:
Expand Down
Loading
Loading