Resume training and automatically compute crop size for TopDownConfmaps pipeline (#79)

* Add option to automatically compute crop size

* Move find_crop_size to Trainer

* Fix skeleton name

* Add crop size to config

* Add resumable training option

* Add tests for resuming training

* Fix tests

* Fix test for wandb folder

* LitData Refactor PR1: Get individual functions for data pipelines (#81)

* Add functions for data pipelines

* Add test cases

* Format file

* Add more test cases

* Fix augmentation test
gitttt-1234 committed Sep 11, 2024
1 parent 6c9262c commit 5e77de9
Showing 34 changed files with 1,060 additions and 98 deletions.
6 changes: 4 additions & 2 deletions docs/config.md
@@ -26,7 +26,8 @@ The config file has three main sections:
- `max_width`: (int) Maximum width the image should be padded to. If not provided, the
original image size will be retained. Default: None.
- `scale`: (float or List[float]) Factor to resize the image dimensions by, specified as either a float scalar or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions are resized by the same factor.
- `crop_hw`: (List[int]) Crop height and width of each instance (h, w) for centered-instance model.
- `crop_hw`: (Tuple[int]) Crop height and width of each instance (h, w) for the centered-instance model. If `None`, the crop size is computed automatically from the largest instance in the `sio.Labels` file.
- `min_crop_size`: (int) Minimum crop size to use when `crop_hw` is `None` (see the sketch below).
- `use_augmentations_train`: (bool) True if the data augmentation should be applied to the training data, else False.
- `augmentation_config`: (only if `use_augmentations` is `True`)
- `random_crop`: (Optional) (Dict[float]) {"random_crop_p": None, "crop_height": None, "crop_width": None}, where *random_crop_p* is the probability of applying random crop and *crop_height* and *crop_width* are the desired output size (out_h, out_w) of the crop.
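For illustration, a minimal sketch of how the cropping keys above might be set in a config file (the nesting shown and the value 100 are placeholders, not canonical defaults):

```yaml
preprocessing:
  scale: 1.0
  crop_hw:            # left empty (None): crop size is computed automatically
  min_crop_size: 100  # lower bound applied to the auto-computed crop size
```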
@@ -156,12 +157,13 @@ The config file has three main sections:
- `use_wandb`: (bool) True to enable wandb logging.
- `save_ckpt`: (bool) True to enable checkpointing.
- `save_ckpt_path`: (str) Directory path to save the training config and checkpoint files. *Default*: "./"
- `resume_ckpt_path`: (str) Path to `.ckpt` file from which training is resumed. *Default*: `None`.
- `wandb`: (Only if `use_wandb` is `True`, else skip this)
- `entity`: (str) Entity of wandb project.
- `project`: (str) Project name for the wandb project.
- `name`: (str) Name of the current run.
- `api_key`: (str) API key. The API key is masked when saved to config files.
- `wandb_mode`: (str) "offline" if only local logging is required. Default: "None".
- `prv_runid`: (str) Previous run ID if training should be resumed from an earlier checkpoint (see the sketch after this section). *Default*: `None`.
- `log_params`: (List[str]) List of config parameters to save in wandb logs. For example, to save the learning rate from the trainer config section, use "trainer_config.optimizer.lr" (provide the full path to the specific config parameter).
- `optimizer_name`: (str) Optimizer to be used. One of ["Adam", "AdamW"].
- `optimizer`
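As a minimal sketch of the resume workflow described above, a run is resumed by pointing `resume_ckpt_path` at a saved checkpoint and, when wandb logging is on, `prv_runid` at the earlier run (the checkpoint filename and run ID below are placeholders):

```yaml
trainer_config:
  save_ckpt: true
  save_ckpt_path: min_inst_centered
  resume_ckpt_path: min_inst_centered/last.ckpt  # checkpoint to resume from
  use_wandb: true
  wandb:
    prv_runid: abc123xy  # wandb run ID to keep logging into
```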
1 change: 1 addition & 0 deletions docs/config_bottomup.yaml
@@ -96,6 +96,7 @@ trainer_config:
use_wandb: false
save_ckpt: true
save_ckpt_path: min_inst_bottomup1
resume_ckpt_path:
optimizer_name: Adam
optimizer:
lr: 0.0001
3 changes: 2 additions & 1 deletion docs/config_centroid.yaml
@@ -113,12 +113,13 @@ trainer_config:
use_wandb: false
save_ckpt: true
save_ckpt_path: 'min_inst_centroid'
resume_ckpt_path:
wandb: # sample wandb config
entity:
project: 'test_centroid_centered'
name: 'fly_unet_centered'
wandb_mode: ''
api_key: ''
prv_runid:
log_params:
- trainer_config.optimizer_name
- trainer_config.optimizer.amsgrad
1 change: 1 addition & 0 deletions docs/config_topdown_centered_instance.yaml
@@ -94,6 +94,7 @@ trainer_config:
use_wandb: false
save_ckpt: true
save_ckpt_path: 'min_inst_centered'
resume_ckpt_path:
optimizer_name: Adam
optimizer:
lr: 0.0001
222 changes: 221 additions & 1 deletion sleap_nn/data/augmentation.py
@@ -1,7 +1,6 @@
"""This module implements data pipeline blocks for augmentation operations."""

from typing import Any, Dict, Optional, Tuple, Union, Iterator

import kornia as K
import torch
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
@@ -11,6 +10,227 @@
from torch.utils.data.datapipes.datapipe import IterDataPipe


def apply_intensity_augmentation(
image: torch.Tensor,
instances: torch.Tensor,
uniform_noise_min: Optional[float] = 0.0,
uniform_noise_max: Optional[float] = 0.04,
uniform_noise_p: float = 0.0,
gaussian_noise_mean: Optional[float] = 0.02,
gaussian_noise_std: Optional[float] = 0.004,
gaussian_noise_p: float = 0.0,
contrast_min: Optional[float] = 0.5,
contrast_max: Optional[float] = 2.0,
contrast_p: float = 0.0,
brightness: Optional[float] = 0.0,
brightness_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply kornia intensity augmentation on image and instances.
Args:
image: Input image. Shape: (n_samples, C, H, W)
instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
uniform_noise_min: Minimum value for uniform noise (uniform_noise_min >=0).
uniform_noise_max: Maximum value for uniform noise (uniform_noise_max <=1).
uniform_noise_p: Probability of applying random uniform noise.
gaussian_noise_mean: The mean of the gaussian distribution.
gaussian_noise_std: The standard deviation of the gaussian distribution.
gaussian_noise_p: Probability of applying random gaussian noise.
contrast_min: Minimum contrast factor to apply. Default: 0.5.
contrast_max: Maximum contrast factor to apply. Default: 2.0.
contrast_p: Probability of applying random contrast.
        brightness: The brightness factor to apply. Default: 0.0.
        brightness_p: Probability of applying random brightness.

    Returns:
        A tuple `(image, instances)` with the augmentations applied.
"""
aug_stack = []
if uniform_noise_p > 0:
aug_stack.append(
RandomUniformNoise(
noise=(uniform_noise_min, uniform_noise_max),
p=uniform_noise_p,
keepdim=True,
same_on_batch=True,
)
)
if gaussian_noise_p > 0:
aug_stack.append(
K.augmentation.RandomGaussianNoise(
mean=gaussian_noise_mean,
std=gaussian_noise_std,
p=gaussian_noise_p,
keepdim=True,
same_on_batch=True,
)
)
if contrast_p > 0:
aug_stack.append(
K.augmentation.RandomContrast(
contrast=(contrast_min, contrast_max),
p=contrast_p,
keepdim=True,
same_on_batch=True,
)
)
if brightness_p > 0:
aug_stack.append(
K.augmentation.RandomBrightness(
brightness=brightness,
p=brightness_p,
keepdim=True,
same_on_batch=True,
)
)

augmenter = AugmentationSequential(
*aug_stack,
data_keys=["input", "keypoints"],
keepdim=True,
same_on_batch=True,
)

inst_shape = instances.shape
# Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
instances = instances.reshape(inst_shape[0], -1, 2)
# (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

aug_image, aug_instances = augmenter(image, instances)

# After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
return aug_image, aug_instances.reshape(*inst_shape)


def apply_geometric_augmentation(
image: torch.Tensor,
instances: torch.Tensor,
rotation: Optional[float] = 15.0,
scale: Union[
Optional[float], Tuple[float, float], Tuple[float, float, float, float]
] = None,
translate_width: Optional[float] = 0.02,
translate_height: Optional[float] = 0.02,
affine_p: float = 0.0,
erase_scale_min: Optional[float] = 0.0001,
erase_scale_max: Optional[float] = 0.01,
erase_ratio_min: Optional[float] = 1,
erase_ratio_max: Optional[float] = 1,
erase_p: float = 0.0,
mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None,
mixup_p: float = 0.0,
random_crop_height: int = 0,
random_crop_width: int = 0,
random_crop_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply kornia geometric augmentation on image and instances.
Args:
image: Input image. Shape: (n_samples, C, H, W)
instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
rotation: Angles in degrees as a scalar float of the amount of rotation. A
random angle in `(-rotation, rotation)` will be sampled and applied to both
images and keypoints. Set to 0 to disable rotation augmentation.
scale: scaling factor interval. If (a, b) represents isotropic scaling, the scale
is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale
is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d.
Default: None.
translate_width: Maximum absolute fraction for horizontal translation. For example,
if translate_width=a, then horizontal shift is randomly sampled in the range
-img_width * a < dx < img_width * a. Will not translate by default.
translate_height: Maximum absolute fraction for vertical translation. For example,
if translate_height=a, then vertical shift is randomly sampled in the range
-img_height * a < dy < img_height * a. Will not translate by default.
affine_p: Probability of applying random affine transformations.
erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
erase_ratio_max: Maximum value of range of aspect ratio of erased area. Default: 1.
erase_p: Probability of applying random erase.
        mixup_lambda: Min-max value of mixup strength. If `None`, kornia samples
            the strength from 0-1. Default: `None`.
mixup_p: Probability of applying random mixup v2.
random_crop_height: Desired output height of the crop. Must be int.
random_crop_width: Desired output width of the crop. Must be int.
        random_crop_p: Probability of applying random crop.

    Returns:
        A tuple `(image, instances)` with the augmentations applied.
"""
if isinstance(scale, float):
scale = (scale, scale)
aug_stack = []
if affine_p > 0:
aug_stack.append(
K.augmentation.RandomAffine(
degrees=rotation,
translate=(translate_width, translate_height),
scale=scale,
p=affine_p,
keepdim=True,
same_on_batch=True,
)
)

if erase_p > 0:
aug_stack.append(
K.augmentation.RandomErasing(
scale=(erase_scale_min, erase_scale_max),
ratio=(erase_ratio_min, erase_ratio_max),
p=erase_p,
keepdim=True,
same_on_batch=True,
)
)
if mixup_p > 0:
aug_stack.append(
K.augmentation.RandomMixUpV2(
lambda_val=mixup_lambda,
p=mixup_p,
keepdim=True,
same_on_batch=True,
)
)
if random_crop_p > 0:
if random_crop_height > 0 and random_crop_width > 0:
aug_stack.append(
K.augmentation.RandomCrop(
size=(random_crop_height, random_crop_width),
pad_if_needed=True,
p=random_crop_p,
keepdim=True,
same_on_batch=True,
)
)
else:
raise ValueError(f"random_crop_hw height and width must be greater than 0.")

augmenter = AugmentationSequential(
*aug_stack,
data_keys=["input", "keypoints"],
keepdim=True,
same_on_batch=True,
)

inst_shape = instances.shape
# Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
instances = instances.reshape(inst_shape[0], -1, 2)
# (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

aug_image, aug_instances = augmenter(image, instances)

# After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
return aug_image, aug_instances.reshape(*inst_shape)


class RandomUniformNoise(IntensityAugmentationBase2D):
"""Data transformer for applying random uniform noise to input images.
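A minimal usage sketch for the two functions added above, following the tensor shapes given in their docstrings. Random tensors stand in for real frames and keypoints, and the probabilities are set to 1.0 so each transform always fires:

```python
import torch

from sleap_nn.data.augmentation import (
    apply_geometric_augmentation,
    apply_intensity_augmentation,
)

image = torch.rand(1, 1, 384, 384)  # (n_samples, C, H, W)
instances = torch.rand(1, 2, 13, 2) * 384  # (n_samples, n_instances, n_nodes, 2)

# Photometric only: uniform noise plus random contrast; geometry is untouched.
aug_image, aug_instances = apply_intensity_augmentation(
    image, instances, uniform_noise_p=1.0, contrast_p=1.0
)

# Geometric: a random affine (rotation + translation) applied to the image and
# keypoints together so they stay in register.
aug_image, aug_instances = apply_geometric_augmentation(
    aug_image, aug_instances, rotation=15.0, affine_p=1.0
)

# Shapes are preserved when no random crop is requested.
assert aug_image.shape == image.shape
assert aug_instances.shape == instances.shape
```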
88 changes: 87 additions & 1 deletion sleap_nn/data/confidence_maps.py
@@ -1,13 +1,99 @@
"""Generate confidence maps."""

from typing import Dict, Iterator
from typing import Dict, Iterator, Tuple

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

from sleap_nn.data.utils import make_grid_vectors


def generate_confmaps(
instance: torch.Tensor,
    img_hw: Tuple[int, int],
sigma: float = 1.5,
output_stride: int = 2,
) -> torch.Tensor:
"""Generate Confidence maps.
Args:
instance: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
(n_samples, n_nodes, 2).
img_hw: Image size as tuple (height, width).
sigma: The standard deviation of the Gaussian distribution that is used to
generate confidence maps. Default: 1.5.
output_stride: The relative stride to use when generating confidence maps.
            A larger stride will generate smaller confidence maps. Default: 2.

    Returns:
Confidence maps for the input keypoints.
"""
if instance.ndim != 3:
instance = instance.view(instance.shape[0], -1, 2)
# instances: (n_samples, n_nodes, 2)

height, width = img_hw

xv, yv = make_grid_vectors(height, width, output_stride)

confidence_maps = make_confmaps(
instance,
xv,
yv,
sigma * output_stride,
) # (n_samples, n_nodes, height/ output_stride, width/ output_stride)

return confidence_maps


def generate_multiconfmaps(
instances: torch.Tensor,
    img_hw: Tuple[int, int],
num_instances: int,
sigma: float = 1.5,
output_stride: int = 2,
is_centroids: bool = False,
) -> torch.Tensor:
"""Generate multi-instance confidence maps.
Args:
instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
for centroids - (n_samples, n_instances, 2)
img_hw: Image size as tuple (height, width).
num_instances: Original number of instances in the frame.
sigma: The standard deviation of the Gaussian distribution that is used to
generate confidence maps. Default: 1.5.
output_stride: The relative stride to use when generating confidence maps.
A larger stride will generate smaller confidence maps. Default: 2.
        is_centroids: True if confidence maps should be generated for centroids,
            else False. Default: False.

    Returns:
Confidence maps for the input keypoints.
"""
if is_centroids:
points = instances[:, :num_instances, :].unsqueeze(dim=-2)
# (n_samples, n_instances, 1, 2)
else:
points = instances[
:, :num_instances, :, :
] # (n_samples, n_instances, n_nodes, 2)

height, width = img_hw

xv, yv = make_grid_vectors(height, width, output_stride)

confidence_maps = make_multi_confmaps(
points,
xv,
yv,
sigma * output_stride,
) # (n_samples, n_nodes, height/ output_stride, width/ output_stride).
# If `is_centroids`, (n_samples, 1, height/ output_stride, width/ output_stride).

return confidence_maps


def make_confmaps(
points_batch: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
) -> torch.Tensor:
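And a corresponding sketch for the confidence-map helpers, assuming a 384×384 frame with two instances and 13 nodes (random keypoints as placeholders; the mean-over-nodes centroid is a stand-in, not the library's own computation):

```python
import torch

from sleap_nn.data.confidence_maps import generate_confmaps, generate_multiconfmaps

img_hw = (384, 384)
instance = torch.rand(1, 13, 2) * 384  # (n_samples, n_nodes, 2)
instances = torch.rand(1, 2, 13, 2) * 384  # (n_samples, n_instances, n_nodes, 2)

# Single-instance maps, e.g. for the centered-instance (TopDownConfmaps) model.
cms = generate_confmaps(instance, img_hw, sigma=1.5, output_stride=2)
# cms.shape == (1, 13, 192, 192): one map per node at 1/2 resolution.

# Multi-instance maps over all instances in the frame.
multi_cms = generate_multiconfmaps(
    instances, img_hw, num_instances=2, sigma=1.5, output_stride=2
)
# multi_cms.shape == (1, 13, 192, 192)

# For centroid maps, pass (n_samples, n_instances, 2) points with is_centroids=True.
centroids = instances.mean(dim=2)  # placeholder centroid: mean over nodes
centroid_cms = generate_multiconfmaps(
    centroids, img_hw, num_instances=2, is_centroids=True
)
# centroid_cms.shape == (1, 1, 192, 192)
```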