Add minimal pretrained checkpoints for tests and fix PAF grouping interpolation #73

Merged · 14 commits · Sep 12, 2024
6 changes: 4 additions & 2 deletions docs/config.md
@@ -26,7 +26,8 @@ The config file has three main sections:
- `max_width`: (int) Maximum width the image should be padded to. If not provided, the
original image size will be retained. Default: None.
- `scale`: (float or List[float]) Factor to resize the image dimensions by, specified as either a float scalar or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions are resized by the same factor.
- `crop_hw`: (List[int]) Crop height and width of each instance (h, w) for centered-instance model.
- `crop_hw`: (Tuple[int]) Crop height and width of each instance (h, w) for the centered-instance model. If `None`, this is computed automatically from the largest instance in the `sio.Labels` file.
- `min_crop_size`: (int) Minimum crop size to be used if `crop_hw` is `None`.
- `use_augmentations_train`: (bool) True if the data augmentation should be applied to the training data, else False.
- `augmentation_config`: (only if `use_augmentations_train` is `True`)
- `random_crop`: (Optional) (Dict[float]) {"random_crop_p": None, "crop_height": None, "crop_width": None}, where *random_crop_p* is the probability of applying random crop and *crop_height* and *crop_width* are the desired output size (out_h, out_w) of the crop.
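As a quick illustration, the preprocessing and random-crop keys above might be filled in as follows. This is only a sketch in Python dict form (the shipped configs are YAML), the variable names are illustrative, and every value is a placeholder:

preprocessing_keys = {
    "max_width": None,       # keep the original width
    "scale": 1.0,            # or [scale_x, scale_y] to scale each dimension separately
    "crop_hw": None,         # None -> computed from the largest instance in the labels
    "min_crop_size": 100,    # lower bound used only when crop_hw is None
}
augmentation_config = {
    "random_crop": {
        "random_crop_p": 0.5,   # probability of applying the crop
        "crop_height": 160,     # out_h
        "crop_width": 160,      # out_w
    },
}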
@@ -156,12 +157,13 @@ The config file has three main sections:
- `use_wandb`: (bool) True to enable wandb logging.
- `save_ckpt`: (bool) True to enable checkpointing.
- `save_ckpt_path`: (str) Directory path to save the training config and checkpoint files. *Default*: "./"
- `resume_ckpt_path`: (str) Path to `.ckpt` file from which training is resumed. *Default*: `None`.
- `wandb`: (Only if `use_wandb` is `True`, else skip this)
- `entity`: (str) Entity of wandb project.
- `project`: (str) Project name for the wandb project.
- `name`: (str) Name of the current run.
- `api_key`: (str) API key. The API key is masked when saved to config files.
- `wandb_mode`: (str) "offline" if only local logging is required. Default: "None".
- `prv_runid`: (str) Previous run ID if training should be resumed from a previous ckpt. *Default*: `None`.
- `log_params`: (List[str]) List of config parameters to save in wandb logs. For example, to save the learning rate from the trainer config section, use "trainer_config.optimizer.lr" (provide the full path to the specific config parameter).
- `optimizer_name`: (str) Optimizer to be used. One of ["Adam", "AdamW"].
- `optimizer`
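For reference, a sketch of how the checkpointing and wandb keys documented above might be assembled (Python dict form for brevity; the real configs are YAML, and the entity/project/run values are placeholders):

trainer_config = {
    "save_ckpt": True,
    "save_ckpt_path": "./",
    "resume_ckpt_path": None,        # set to a .ckpt path to resume training
    "use_wandb": True,
    "wandb": {
        "entity": "my-team",         # placeholder
        "project": "my-project",     # placeholder
        "name": "run-1",             # placeholder
        "api_key": "",               # masked when saved to config files
        "wandb_mode": "offline",     # local-only logging
        "prv_runid": None,           # set to resume logging into a previous run
        "log_params": [
            "trainer_config.optimizer_name",
            "trainer_config.optimizer.lr",
        ],
    },
}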
1 change: 1 addition & 0 deletions docs/config_bottomup.yaml
@@ -96,6 +96,7 @@ trainer_config:
use_wandb: false
save_ckpt: true
save_ckpt_path: min_inst_bottomup1
resume_ckpt_path:
optimizer_name: Adam
optimizer:
lr: 0.0001
3 changes: 2 additions & 1 deletion docs/config_centroid.yaml
@@ -113,12 +113,13 @@ trainer_config:
use_wandb: false
save_ckpt: true
save_ckpt_path: 'min_inst_centroid'
resume_ckpt_path:
wandb: # sample wandb config
entity:
project: 'test_centroid_centered'
name: 'fly_unet_centered'
wandb_mode: ''
api_key: ''
prv_runid:
log_params:
- trainer_config.optimizer_name
- trainer_config.optimizer.amsgrad
1 change: 1 addition & 0 deletions docs/config_topdown_centered_instance.yaml
@@ -94,6 +94,7 @@ trainer_config:
use_wandb: false
save_ckpt: true
save_ckpt_path: 'min_inst_centered'
resume_ckpt_path:
optimizer_name: Adam
optimizer:
lr: 0.0001
222 changes: 221 additions & 1 deletion sleap_nn/data/augmentation.py
@@ -1,7 +1,6 @@
"""This module implements data pipeline blocks for augmentation operations."""

from typing import Any, Dict, Optional, Tuple, Union, Iterator

import kornia as K
import torch
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
@@ -11,6 +10,227 @@
from torch.utils.data.datapipes.datapipe import IterDataPipe


def apply_intensity_augmentation(
image: torch.Tensor,
instances: torch.Tensor,
uniform_noise_min: Optional[float] = 0.0,
uniform_noise_max: Optional[float] = 0.04,
uniform_noise_p: float = 0.0,
gaussian_noise_mean: Optional[float] = 0.02,
gaussian_noise_std: Optional[float] = 0.004,
gaussian_noise_p: float = 0.0,
contrast_min: Optional[float] = 0.5,
contrast_max: Optional[float] = 2.0,
contrast_p: float = 0.0,
brightness: Optional[float] = 0.0,
brightness_p: float = 0.0,
) -> Tuple[torch.Tensor]:
"""Apply kornia intensity augmentation on image and instances.

Args:
image: Input image. Shape: (n_samples, C, H, W)
instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
uniform_noise_min: Minimum value for uniform noise (uniform_noise_min >=0).
uniform_noise_max: Maximum value for uniform noise (uniform_noise_max <=1).
uniform_noise_p: Probability of applying random uniform noise.
gaussian_noise_mean: The mean of the gaussian distribution.
gaussian_noise_std: The standard deviation of the gaussian distribution.
gaussian_noise_p: Probability of applying random gaussian noise.
contrast_min: Minimum contrast factor to apply. Default: 0.5.
contrast_max: Maximum contrast factor to apply. Default: 2.0.
contrast_p: Probability of applying random contrast.
brightness: The brightness factor to apply. Default: 0.0.
brightness_p: Probability of applying random brightness.

Returns:
Returns tuple: (image, instances) with augmentation applied.
"""
aug_stack = []
if uniform_noise_p > 0:
aug_stack.append(
RandomUniformNoise(
noise=(uniform_noise_min, uniform_noise_max),
p=uniform_noise_p,
keepdim=True,
same_on_batch=True,
)
)
if gaussian_noise_p > 0:
aug_stack.append(
K.augmentation.RandomGaussianNoise(
mean=gaussian_noise_mean,
std=gaussian_noise_std,
p=gaussian_noise_p,
keepdim=True,
same_on_batch=True,
)
)
if contrast_p > 0:
aug_stack.append(
K.augmentation.RandomContrast(
contrast=(contrast_min, contrast_max),
p=contrast_p,
keepdim=True,
same_on_batch=True,
)
)
if brightness_p > 0:
aug_stack.append(
K.augmentation.RandomBrightness(
brightness=brightness,
p=brightness_p,
keepdim=True,
same_on_batch=True,
)
)

augmenter = AugmentationSequential(
*aug_stack,
data_keys=["input", "keypoints"],
keepdim=True,
same_on_batch=True,
)

inst_shape = instances.shape
# Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
instances = instances.reshape(inst_shape[0], -1, 2)
# (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

aug_image, aug_instances = augmenter(image, instances)

# After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
return aug_image, aug_instances.reshape(*inst_shape)
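A minimal usage sketch for the helper above; the tensor shapes follow the docstring, and the probability values are arbitrary, chosen so the example always applies both ops:

import torch
from sleap_nn.data.augmentation import apply_intensity_augmentation

image = torch.rand(1, 3, 256, 256)        # (n_samples, C, H, W)
instances = torch.rand(1, 2, 5, 2) * 256  # (n_samples, n_instances, n_nodes, 2)

aug_image, aug_instances = apply_intensity_augmentation(
    image,
    instances,
    gaussian_noise_p=1.0,  # always add gaussian noise in this example
    contrast_p=1.0,        # always apply random contrast
)
# Shapes are preserved: aug_image is (1, 3, 256, 256), aug_instances is (1, 2, 5, 2).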


def apply_geometric_augmentation(
image: torch.Tensor,
instances: torch.Tensor,
rotation: Optional[float] = 15.0,
scale: Union[
Optional[float], Tuple[float, float], Tuple[float, float, float, float]
] = None,
translate_width: Optional[float] = 0.02,
translate_height: Optional[float] = 0.02,
affine_p: float = 0.0,
erase_scale_min: Optional[float] = 0.0001,
erase_scale_max: Optional[float] = 0.01,
erase_ratio_min: Optional[float] = 1,
erase_ratio_max: Optional[float] = 1,
erase_p: float = 0.0,
mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None,
mixup_p: float = 0.0,
random_crop_height: int = 0,
random_crop_width: int = 0,
random_crop_p: float = 0.0,
) -> Tuple[torch.Tensor]:
"""Apply kornia geometric augmentation on image and instances.

Args:
image: Input image. Shape: (n_samples, C, H, W)
instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
rotation: Angles in degrees as a scalar float of the amount of rotation. A
random angle in `(-rotation, rotation)` will be sampled and applied to both
images and keypoints. Set to 0 to disable rotation augmentation.
scale: scaling factor interval. If (a, b) represents isotropic scaling, the scale
is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale
is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d.
Default: None.
translate_width: Maximum absolute fraction for horizontal translation. For example,
if translate_width=a, then horizontal shift is randomly sampled in the range
-img_width * a < dx < img_width * a. Will not translate by default.
translate_height: Maximum absolute fraction for vertical translation. For example,
if translate_height=a, then vertical shift is randomly sampled in the range
-img_height * a < dy < img_height * a. Will not translate by default.
affine_p: Probability of applying random affine transformations.
erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
erase_ratio_max: Maximum value of range of aspect ratio of erased area. Default: 1.
erase_p: Probability of applying random erase.
mixup_lambda: min-max value of mixup strength; if `None`, the strength is sampled from 0-1. Default: `None`.
mixup_p: Probability of applying random mixup v2.
random_crop_height: Desired output height of the crop. Must be int.
random_crop_width: Desired output width of the crop. Must be int.
random_crop_p: Probability of applying random crop.


Returns:
Returns tuple: (image, instances) with augmentation applied.
"""
if isinstance(scale, float):
scale = (scale, scale)
aug_stack = []
if affine_p > 0:
aug_stack.append(
K.augmentation.RandomAffine(
degrees=rotation,
translate=(translate_width, translate_height),
scale=scale,
p=affine_p,
keepdim=True,
same_on_batch=True,
)
)

if erase_p > 0:
aug_stack.append(
K.augmentation.RandomErasing(
scale=(erase_scale_min, erase_scale_max),
ratio=(erase_ratio_min, erase_ratio_max),
p=erase_p,
keepdim=True,
same_on_batch=True,
)
)
if mixup_p > 0:
aug_stack.append(
K.augmentation.RandomMixUpV2(
lambda_val=mixup_lambda,
p=mixup_p,
keepdim=True,
same_on_batch=True,
)
)
if random_crop_p > 0:
if random_crop_height > 0 and random_crop_width > 0:
aug_stack.append(
K.augmentation.RandomCrop(
size=(random_crop_height, random_crop_width),
pad_if_needed=True,
p=random_crop_p,
keepdim=True,
same_on_batch=True,
)
)
else:
raise ValueError(f"random_crop_hw height and width must be greater than 0.")

augmenter = AugmentationSequential(
*aug_stack,
data_keys=["input", "keypoints"],
keepdim=True,
same_on_batch=True,
)

inst_shape = instances.shape
# Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
instances = instances.reshape(inst_shape[0], -1, 2)
# (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)

aug_image, aug_instances = augmenter(image, instances)

# After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
# or
# After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
return aug_image, aug_instances.reshape(*inst_shape)
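And a matching sketch for the geometric helper; again the shapes come from the docstring and the parameter values are arbitrary:

import torch
from sleap_nn.data.augmentation import apply_geometric_augmentation

image = torch.rand(1, 1, 384, 384)        # (n_samples, C, H, W)
instances = torch.rand(1, 2, 5, 2) * 384  # (n_samples, n_instances, n_nodes, 2)

aug_image, aug_instances = apply_geometric_augmentation(
    image,
    instances,
    rotation=15.0,
    scale=(0.9, 1.1),
    affine_p=1.0,            # always apply the affine transform in this example
    random_crop_height=256,
    random_crop_width=256,
    random_crop_p=1.0,       # always crop to 256x256
)
# aug_image is (1, 1, 256, 256) after the crop; aug_instances keeps its shape (1, 2, 5, 2).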


class RandomUniformNoise(IntensityAugmentationBase2D):
"""Data transformer for applying random uniform noise to input images.

88 changes: 87 additions & 1 deletion sleap_nn/data/confidence_maps.py
@@ -1,13 +1,99 @@
"""Generate confidence maps."""

from typing import Dict, Iterator
from typing import Dict, Iterator, Tuple

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

from sleap_nn.data.utils import make_grid_vectors


def generate_confmaps(
instance: torch.Tensor,
img_hw: Tuple[int],
sigma: float = 1.5,
output_stride: int = 2,
) -> torch.Tensor:
"""Generate Confidence maps.

Args:
instance: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
(n_samples, n_nodes, 2).
img_hw: Image size as tuple (height, width).
sigma: The standard deviation of the Gaussian distribution that is used to
generate confidence maps. Default: 1.5.
output_stride: The relative stride to use when generating confidence maps.
A larger stride will generate smaller confidence maps. Default: 2.

Returns:
Confidence maps for the input keypoints.
"""
if instance.ndim != 3:
instance = instance.view(instance.shape[0], -1, 2)
# instances: (n_samples, n_nodes, 2)

height, width = img_hw

xv, yv = make_grid_vectors(height, width, output_stride)

confidence_maps = make_confmaps(
instance,
xv,
yv,
sigma * output_stride,
) # (n_samples, n_nodes, height/ output_stride, width/ output_stride)

return confidence_maps
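A small usage sketch for generate_confmaps, with hypothetical shapes (one sample, 5 nodes, a 256x256 image):

import torch
from sleap_nn.data.confidence_maps import generate_confmaps

instance = torch.rand(1, 5, 2) * 256  # (n_samples, n_nodes, 2)

confmaps = generate_confmaps(instance, img_hw=(256, 256), sigma=1.5, output_stride=2)
# confmaps: (1, 5, 128, 128), i.e. (n_samples, n_nodes, height/output_stride, width/output_stride)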


def generate_multiconfmaps(
instances: torch.Tensor,
img_hw: Tuple[int],
num_instances: int,
sigma: float = 1.5,
output_stride: int = 2,
is_centroids: bool = False,
) -> torch.Tensor:
"""Generate multi-instance confidence maps.

Args:
instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or
for centroids - (n_samples, n_instances, 2)
img_hw: Image size as tuple (height, width).
num_instances: Original number of instances in the frame.
sigma: The standard deviation of the Gaussian distribution that is used to
generate confidence maps. Default: 1.5.
output_stride: The relative stride to use when generating confidence maps.
A larger stride will generate smaller confidence maps. Default: 2.
is_centroids: True if confidence maps should be generated for centroids, else False.
Default: False.

Returns:
Confidence maps for the input keypoints.
"""
if is_centroids:
points = instances[:, :num_instances, :].unsqueeze(dim=-2)
# (n_samples, n_instances, 1, 2)
else:
points = instances[
:, :num_instances, :, :
] # (n_samples, n_instances, n_nodes, 2)

height, width = img_hw

xv, yv = make_grid_vectors(height, width, output_stride)

confidence_maps = make_multi_confmaps(
points,
xv,
yv,
sigma * output_stride,
) # (n_samples, n_nodes, height/ output_stride, width/ output_stride).
# If `is_centroids`, (n_samples, 1, height/ output_stride, width/ output_stride).

return confidence_maps
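And a sketch for generate_multiconfmaps covering both paths; the centroid here is just the mean of each instance's nodes, as a stand-in for however centroids are computed upstream:

import torch
from sleap_nn.data.confidence_maps import generate_multiconfmaps

instances = torch.rand(1, 2, 5, 2) * 256  # (n_samples, n_instances, n_nodes, 2)

# Node confidence maps aggregated over instances.
confmaps = generate_multiconfmaps(
    instances, img_hw=(256, 256), num_instances=2, sigma=1.5, output_stride=2
)
# confmaps: (1, 5, 128, 128)

# Centroid confidence maps: pass (n_samples, n_instances, 2) with is_centroids=True.
centroids = instances.mean(dim=2)  # stand-in centroid per instance
centroid_cms = generate_multiconfmaps(
    centroids,
    img_hw=(256, 256),
    num_instances=2,
    sigma=1.5,
    output_stride=2,
    is_centroids=True,
)
# centroid_cms: (1, 1, 128, 128)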


def make_confmaps(
points_batch: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
) -> torch.Tensor: