LitData Refactor PR1: Get individual functions for data pipelines #90

Merged · 3 commits · Oct 3, 2024
8 changes: 6 additions & 2 deletions docs/config.md
@@ -15,6 +15,9 @@ The config file has three main sections:
- `provider`: (str) Provider class to read the input sleap files. Only "LabelsReader" supported for the training pipeline.
- `train_labels_path`: (str) Path to training data (`.slp` file)
- `val_labels_path`: (str) Path to validation data (`.slp` file)
- `user_instances_only`: (bool) `True` if only user-labeled instances should be used for training. If `False`, both user-labeled and predicted instances will be used. *Default*: `True`.
- `chunk_size`: (int) Size of each chunk (in MB). *Default*: `100`.
#TODO: change in inference ckpts
- `preprocessing`:
- `is_rgb`: (bool) True if the image has 3 channels (RGB image). If input has only one
channel when this is set to `True`, then the images from single-channel
@@ -30,7 +33,6 @@ The config file has three main sections:
- `min_crop_size`: (int) Minimum crop size to be used if `crop_hw` is `None`.
- `use_augmentations_train`: (bool) True if the data augmentation should be applied to the training data, else False.
- `augmentation_config`: (only if `use_augmentations` is `True`)
- `random_crop`: (Optional) (Dict[str, float]) `{"random_crop_p": None, "crop_height": None, "crop_width": None}`, where *random_crop_p* is the probability of applying random crop and *crop_height* and *crop_width* are the desired output size (out_h, out_w) of the crop.
- `intensity`: (Optional)
- `uniform_noise_min`: (float) Minimum value for uniform noise (uniform_noise_min >=0).
- `uniform_noise_max`: (float) Maximum value for uniform noise (uniform_noise_max <=1).
@@ -57,7 +59,9 @@ The config file has three main sections:
- `mixup_lambda`: (float) Min-max value of mixup strength, in the range 0-1. *Default*: `None`.
- `mixup_p`: (float) Probability of applying random mixup v2. *Default*: `0.0`.
- `input_key`: (str) Can be `image` or `instance`. Use `instance` when the KorniaAugmenter follows the InstanceCropper; use `image` (the default) otherwise.

- `random_crop_p`: (float) Probability of applying random crop.
- `random_crop_height`: (int) Desired output height of the random crop.
- `random_crop_width`: (int) Desired output width of the random crop.
- `model_config`:
- `init_weight`: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method.
- `pre_trained_weights`: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights","ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"].
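The random-crop options documented above are only meaningful together: a positive crop probability requires positive output dimensions. A minimal sketch of that constraint, with a hypothetical helper `validate_random_crop` (not part of this PR; the parameter names follow the docs above):

```python
def validate_random_crop(random_crop_p: float,
                         random_crop_height: int,
                         random_crop_width: int) -> None:
    """Raise if random crop is enabled without positive output dimensions."""
    # Mirrors the check in sleap_nn.data.augmentation: a crop probability > 0
    # requires both output dimensions to be greater than 0.
    if random_crop_p > 0 and not (random_crop_height > 0 and random_crop_width > 0):
        raise ValueError(
            "random_crop_height and random_crop_width must be greater than 0."
        )

validate_random_crop(0.0, 0, 0)      # crop disabled: no error
validate_random_crop(1.0, 160, 160)  # crop enabled with valid dims: no error
```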
1 change: 1 addition & 0 deletions docs/config_bottomup.yaml
@@ -2,6 +2,7 @@ data_config:
  provider: LabelsReader
  train_labels_path: minimal_instance.pkg.slp
  val_labels_path: minimal_instance.pkg.slp
  user_instances_only: True
  preprocessing:
    max_width: null
    max_height: null
8 changes: 4 additions & 4 deletions docs/config_centroid.yaml
@@ -2,17 +2,14 @@ data_config:
  provider: LabelsReader
  train_labels_path: minimal_instance.pkg.slp
  val_labels_path: minimal_instance.pkg.slp
  user_instances_only: True
  preprocessing:
    max_width:
    max_height:
    scale: 0.5
    is_rgb: False
  use_augmentations_train: true
  augmentation_config: # sample augmentation_config
    random_crop:
      random_crop_p: 0
      crop_height: 160
      crop_width: 160
    intensity:
      uniform_noise_min: 0.0
      uniform_noise_max: 0.04
@@ -38,6 +35,9 @@ data_config:
      erase_p: 0
      mixup_lambda: null
      mixup_p: 0
      random_crop_p: 0
      random_crop_height: 160
      random_crop_width: 160

model_config:
init_weights: xavier
1 change: 1 addition & 0 deletions docs/config_topdown_centered_instance.yaml
@@ -2,6 +2,7 @@ data_config:
  provider: LabelsReader
  train_labels_path: minimal_instance.pkg.slp
  val_labels_path: minimal_instance.pkg.slp
  user_instances_only: True
  preprocessing:
    max_width:
    max_height:
3 changes: 2 additions & 1 deletion environment.yml
@@ -33,4 +33,5 @@ dependencies:
  - ndx-pose
  - pip
  - pip:
      - "--editable=.[dev]"
      - "--editable=.[dev]"
      - litdata
3 changes: 2 additions & 1 deletion environment_cpu.yml
@@ -32,4 +32,5 @@ dependencies:
  - ndx-pose
  - pip
  - pip:
      - "--editable=.[dev]"
      - "--editable=.[dev]"
      - litdata
3 changes: 2 additions & 1 deletion environment_mac.yml
@@ -31,4 +31,5 @@ dependencies:
  - ndx-pose
  - pip
  - pip:
      - "--editable=.[dev]"
      - "--editable=.[dev]"
      - litdata
222 changes: 221 additions & 1 deletion sleap_nn/data/augmentation.py
@@ -1,7 +1,6 @@
"""This module implements data pipeline blocks for augmentation operations."""

from typing import Any, Dict, Optional, Tuple, Union, Iterator

import kornia as K
import torch
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
@@ -11,6 +10,227 @@
from torch.utils.data.datapipes.datapipe import IterDataPipe


def apply_intensity_augmentation(
    image: torch.Tensor,
    instances: torch.Tensor,
    uniform_noise_min: Optional[float] = 0.0,
    uniform_noise_max: Optional[float] = 0.04,
    uniform_noise_p: float = 0.0,
    gaussian_noise_mean: Optional[float] = 0.02,
    gaussian_noise_std: Optional[float] = 0.004,
    gaussian_noise_p: float = 0.0,
    contrast_min: Optional[float] = 0.5,
    contrast_max: Optional[float] = 2.0,
    contrast_p: float = 0.0,
    brightness: Optional[float] = 0.0,
    brightness_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply kornia intensity augmentation on image and instances.

    Args:
        image: Input image. Shape: (n_samples, C, H, W)
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
        uniform_noise_min: Minimum value for uniform noise (uniform_noise_min >= 0).
        uniform_noise_max: Maximum value for uniform noise (uniform_noise_max <= 1).
        uniform_noise_p: Probability of applying random uniform noise.
        gaussian_noise_mean: The mean of the gaussian distribution.
        gaussian_noise_std: The standard deviation of the gaussian distribution.
        gaussian_noise_p: Probability of applying random gaussian noise.
        contrast_min: Minimum contrast factor to apply. Default: 0.5.
        contrast_max: Maximum contrast factor to apply. Default: 2.0.
        contrast_p: Probability of applying random contrast.
        brightness: The brightness factor to apply. Default: 0.0.
        brightness_p: Probability of applying random brightness.

    Returns:
        Tuple (image, instances) with augmentations applied.
    """
    aug_stack = []
    if uniform_noise_p > 0:
        aug_stack.append(
            RandomUniformNoise(
                noise=(uniform_noise_min, uniform_noise_max),
                p=uniform_noise_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if gaussian_noise_p > 0:
        aug_stack.append(
            K.augmentation.RandomGaussianNoise(
                mean=gaussian_noise_mean,
                std=gaussian_noise_std,
                p=gaussian_noise_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if contrast_p > 0:
        aug_stack.append(
            K.augmentation.RandomContrast(
                contrast=(contrast_min, contrast_max),
                p=contrast_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if brightness_p > 0:
        aug_stack.append(
            K.augmentation.RandomBrightness(
                brightness=brightness,
                p=brightness_p,
                keepdim=True,
                same_on_batch=True,
            )
        )

    augmenter = AugmentationSequential(
        *aug_stack,
        data_keys=["input", "keypoints"],
        keepdim=True,
        same_on_batch=True,
    )

    inst_shape = instances.shape
    # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    instances = instances.reshape(inst_shape[0], -1, 2)
    # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

    aug_image, aug_instances = augmenter(image, instances)

    # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    return aug_image, aug_instances.reshape(*inst_shape)
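The keypoint reshape round-trip above (flatten instances to `(n_samples, -1, 2)` so Kornia sees a single keypoint axis, then restore the original shape) can be sketched without Kornia or torch. This is a minimal sketch: NumPy arrays stand in for tensors, and `identity_augmenter` is a placeholder for the real `AugmentationSequential`:

```python
import numpy as np

def identity_augmenter(image, keypoints):
    # Placeholder for Kornia's AugmentationSequential: returns inputs unchanged.
    return image, keypoints

def augment_with_reshape(image, instances):
    """Flatten (n_samples, n_instances, n_nodes, 2) keypoints to
    (n_samples, n_instances * n_nodes, 2), augment, then restore the shape."""
    inst_shape = instances.shape
    flat = instances.reshape(inst_shape[0], -1, 2)
    aug_image, aug_flat = identity_augmenter(image, flat)
    return aug_image, aug_flat.reshape(*inst_shape)

image = np.zeros((2, 1, 8, 8))  # (n_samples, C, H, W)
instances = np.arange(2 * 3 * 4 * 2, dtype=float).reshape(2, 3, 4, 2)
_, out = augment_with_reshape(image, instances)
assert out.shape == (2, 3, 4, 2)
```

The same helper also accepts the cropped-image case, where keypoints are already `(n_samples, n_nodes, 2)`: the flatten is then a no-op and the shape still round-trips.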


def apply_geometric_augmentation(
    image: torch.Tensor,
    instances: torch.Tensor,
    rotation: Optional[float] = 15.0,
    scale: Optional[
        Union[float, Tuple[float, float], Tuple[float, float, float, float]]
    ] = None,
    translate_width: Optional[float] = 0.02,
    translate_height: Optional[float] = 0.02,
    affine_p: float = 0.0,
    erase_scale_min: Optional[float] = 0.0001,
    erase_scale_max: Optional[float] = 0.01,
    erase_ratio_min: Optional[float] = 1,
    erase_ratio_max: Optional[float] = 1,
    erase_p: float = 0.0,
    mixup_lambda: Optional[Union[float, Tuple[float, float]]] = None,
    mixup_p: float = 0.0,
    random_crop_height: int = 0,
    random_crop_width: int = 0,
    random_crop_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply kornia geometric augmentation on image and instances.

    Args:
        image: Input image. Shape: (n_samples, C, H, W)
        instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
        rotation: Angle in degrees as a scalar float of the amount of rotation. A
            random angle in `(-rotation, rotation)` will be sampled and applied to both
            images and keypoints. Set to 0 to disable rotation augmentation.
        scale: Scaling factor interval. If (a, b), an isotropic scale is randomly sampled
            from the range a <= scale <= b. If (a, b, c, d), the scale is randomly
            sampled per axis from a <= scale_x <= b and c <= scale_y <= d.
            Default: None.
        translate_width: Maximum absolute fraction for horizontal translation. For example,
            if translate_width=a, then horizontal shift is randomly sampled in the range
            -img_width * a < dx < img_width * a. Will not translate by default.
        translate_height: Maximum absolute fraction for vertical translation. For example,
            if translate_height=a, then vertical shift is randomly sampled in the range
            -img_height * a < dy < img_height * a. Will not translate by default.
        affine_p: Probability of applying random affine transformations.
        erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
        erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
        erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
        erase_ratio_max: Maximum value of range of aspect ratio of erased area. Default: 1.
        erase_p: Probability of applying random erase.
        mixup_lambda: Min-max value of mixup strength. Default is 0-1. Default: `None`.
        mixup_p: Probability of applying random mixup v2.
        random_crop_height: Desired output height of the crop. Must be int.
        random_crop_width: Desired output width of the crop. Must be int.
        random_crop_p: Probability of applying random crop.

    Returns:
        Tuple (image, instances) with augmentations applied.
    """
    if isinstance(scale, float):
        scale = (scale, scale)
    aug_stack = []
    if affine_p > 0:
        aug_stack.append(
            K.augmentation.RandomAffine(
                degrees=rotation,
                translate=(translate_width, translate_height),
                scale=scale,
                p=affine_p,
                keepdim=True,
                same_on_batch=True,
            )
        )

    if erase_p > 0:
        aug_stack.append(
            K.augmentation.RandomErasing(
                scale=(erase_scale_min, erase_scale_max),
                ratio=(erase_ratio_min, erase_ratio_max),
                p=erase_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if mixup_p > 0:
        aug_stack.append(
            K.augmentation.RandomMixUpV2(
                lambda_val=mixup_lambda,
                p=mixup_p,
                keepdim=True,
                same_on_batch=True,
            )
        )
    if random_crop_p > 0:
        if random_crop_height > 0 and random_crop_width > 0:
            aug_stack.append(
                K.augmentation.RandomCrop(
                    size=(random_crop_height, random_crop_width),
                    pad_if_needed=True,
                    p=random_crop_p,
                    keepdim=True,
                    same_on_batch=True,
                )
            )
        else:
            raise ValueError(
                "random_crop_height and random_crop_width must be greater than 0."
            )

    augmenter = AugmentationSequential(
        *aug_stack,
        data_keys=["input", "keypoints"],
        keepdim=True,
        same_on_batch=True,
    )

    inst_shape = instances.shape
    # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    instances = instances.reshape(inst_shape[0], -1, 2)
    # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

    aug_image, aug_instances = augmenter(image, instances)

    # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
    # or
    # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
    return aug_image, aug_instances.reshape(*inst_shape)
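Both functions use the same probability-gated construction: each transform is appended only when its apply-probability is positive, so an all-default call builds an empty (identity) pipeline. A minimal sketch of that pattern, where `build_stack` is an illustrative stand-in and the string labels are not Kornia classes:

```python
def build_stack(affine_p=0.0, erase_p=0.0, mixup_p=0.0, random_crop_p=0.0):
    """Collect only the transforms whose apply-probability is > 0."""
    stack = []
    if affine_p > 0:
        stack.append(("affine", affine_p))
    if erase_p > 0:
        stack.append(("erase", erase_p))
    if mixup_p > 0:
        stack.append(("mixup", mixup_p))
    if random_crop_p > 0:
        stack.append(("random_crop", random_crop_p))
    return stack

# All probabilities default to 0, so no transforms are added.
assert build_stack() == []
# Order follows the append order, matching apply_geometric_augmentation above.
assert build_stack(affine_p=0.5, random_crop_p=1.0) == [
    ("affine", 0.5),
    ("random_crop", 1.0),
]
```

This keeps disabled augmentations out of the `AugmentationSequential` entirely rather than including them with `p=0`, which avoids per-batch sampling overhead for transforms that can never fire.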


class RandomUniformNoise(IntensityAugmentationBase2D):
    """Data transformer for applying random uniform noise to input images.
