Skip to content

Commit

Permalink
Top-down Centered-instance Pipeline (#16)
Browse files Browse the repository at this point in the history
* added make_centered_bboxes & normalize_bboxes

* added make_centered_bboxes & normalize_bboxes

* created test_instance_cropping.py

* added test normalize bboxes; added find_global_peaks_rough

* black formatted

* black formatted peak_finding

* added make_grid_vectors, normalize_bboxes, integral_regression, added docstring to make_centered_bboxes, fixed find_global_peaks_rough; added crop_bboxes

* finished find_global_peaks with integral regression over centroid crops!

* reformatted with pydocstyle & black

* moved make_grid_vectors to data/utils

* removed normalize_bboxes

* added tests docstrings

* sorted imports with isort

* remove unused imports

* updated test cases for instance cropping

* added minimal_cms.pt fixture + unit tests

* added minimal_bboxes fixture; added unit tests for crop_bboxes & integral_regression

* added find_global_peaks unit tests

* finished find_local_peaks_rough!

* finished find_local_peaks!

* added unit tests for find_local_peaks and find_local_peaks_rough

* updated test cases

* added more test cases for find_local_peaks

* updated test cases

* added architectures folder

* added maxpool2d same padding, get_act_fn; added simpleconvblock, simpleupsamplingblock, encoder, decoder; added unet

* added test_unet_reference

* black formatted common.py & test_unet.py

* deleted tmp nb

* _calc_same_pad returns int

* fixed test case

* added simpleconvblock tests

* added tests

* added tests for simple upsampling block

* updated test_unet

* removed unnecessary variables

* updated augmentation random erase default values

* created data/pipelines.py

* added base config in config/data; temporary till config system settled

* updated variable defaults to 0 and edited variable names in augmentation

* updated parameter names in data/instance_cropping

* added data/pipelines topdown pipeline make_base_pipeline

* added test_pipelines

* removed configs

* updated augmentation class

* modified test

* updated pipelines docstring

* removed make_base_pipeline and updated tests

* removed empty_cache in SleapDataset

* updated test_pipelines

* updated sleapdataset to return a dict

* added key filter transformer block, removed sleap dataset, added type hinting

* updated type hints

* added coderabbit suggestions

* fixed small squeeze issue
  • Loading branch information
alckasoc committed Sep 19, 2023
1 parent c2db05f commit 47897ad
Show file tree
Hide file tree
Showing 13 changed files with 368 additions and 82 deletions.
4 changes: 3 additions & 1 deletion sleap_nn/architectures/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Common utilities for architecture and model building."""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F
Expand Down Expand Up @@ -134,7 +136,7 @@ def get_act_fn(activation: str) -> nn.Module:
return activations[activation]


def get_children_layers(model: torch.nn.Module):
def get_children_layers(model: torch.nn.Module) -> List[nn.Module]:
"""Recursively retrieves a flattened list of all children modules and submodules within the given model.
Args:
Expand Down
81 changes: 51 additions & 30 deletions sleap_nn/data/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""This module implements data pipeline blocks for augmentation operations."""
from typing import Any, Dict, Optional, Text, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Text, Tuple, Union

import kornia as K
import torch
Expand Down Expand Up @@ -100,9 +100,6 @@ class KorniaAugmenter(IterDataPipe):
Attributes:
source_dp: The input `IterDataPipe` with examples that contain `"instances"` and
`"image"` keys.
crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int],
then out_h = size[0], out_w = size[1].
crop_p: Probability of applying random crop.
rotation: Angles in degrees as a scalar float of the amount of rotation. A
random angle in `(-rotation, rotation)` will be sampled and applied to both
images and keypoints. Set to 0 to disable rotation augmentation.
Expand All @@ -125,11 +122,14 @@ class KorniaAugmenter(IterDataPipe):
contrast_p: Probability of applying random contrast.
brightness: The brightness factor to apply Default: `(1.0, 1.0)`.
brightness_p: Probability of applying random brightness.
erase_scale: Range of proportion of erased area against input image. Default: `(0.02, 0.33)`.
erase_ratio: Range of aspect ratio of erased area. Default: `(0.3, 3.3)`.
erase_scale: Range of proportion of erased area against input image. Default: `(0.0001, 0.01)`.
erase_ratio: Range of aspect ratio of erased area. Default: `(1, 1)`.
erase_p: Probability of applying random erase.
mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`.
mixup_p: Probability of applying random mixup v2.
random_crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int],
then out_h = size[0], out_w = size[1].
random_crop_p: Probability of applying random crop.
Notes:
This block expects the "image" and "instances" keys to be present in the input
Expand All @@ -150,24 +150,24 @@ def __init__(
rotation: Optional[float] = 15.0,
scale: Optional[float] = 0.05,
translate: Optional[Tuple[float, float]] = (0.02, 0.02),
affine_p: float = 0.5,
affine_p: float = 0.0,
uniform_noise: Optional[Tuple[float, float]] = (0.0, 0.04),
uniform_noise_p: float = 0.5,
uniform_noise_p: float = 0.0,
gaussian_noise_mean: Optional[float] = 0.02,
gaussian_noise_std: Optional[float] = 0.004,
gaussian_noise_p: float = 0.5,
gaussian_noise_p: float = 0.0,
contrast: Optional[Tuple[float, float]] = (0.5, 2.0),
contrast_p: float = 0.5,
contrast_p: float = 0.0,
brightness: Optional[float] = 0.0,
brightness_p: float = 0.5,
brightness_p: float = 0.0,
erase_scale: Optional[Tuple[float, float]] = (0.0001, 0.01),
erase_ratio: Optional[Tuple[float, float]] = (1, 1),
erase_p: float = 0.5,
erase_p: float = 0.0,
mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None,
mixup_p: float = 0.5,
crop_hw: Tuple[int, int] = (0, 0),
crop_p: float = 0.0,
):
mixup_p: float = 0.0,
random_crop_hw: Tuple[int, int] = (0, 0),
random_crop_p: float = 0.0,
) -> None:
"""Initialize the block and the augmentation pipeline."""
self.source_dp = source_dp
self.rotation = rotation
Expand All @@ -188,8 +188,8 @@ def __init__(
self.erase_p = erase_p
self.mixup_lambda = mixup_lambda
self.mixup_p = mixup_p
self.crop_hw = crop_hw
self.crop_p = crop_p
self.random_crop_hw = random_crop_hw
self.random_crop_p = random_crop_p

aug_stack = []
if self.affine_p > 0:
Expand Down Expand Up @@ -259,19 +259,21 @@ def __init__(
same_on_batch=True,
)
)
if self.crop_p > 0:
if self.crop_hw[0] > 0 and self.crop_hw[1] > 0:
if self.random_crop_p > 0:
if self.random_crop_hw[0] > 0 and self.random_crop_hw[1] > 0:
aug_stack.append(
K.augmentation.RandomCrop(
size=self.crop_hw,
size=self.random_crop_hw,
pad_if_needed=True,
p=self.crop_p,
p=self.random_crop_p,
keepdim=True,
same_on_batch=True,
)
)
else:
raise ValueError(f"crop_hw height and width must be greater than 0.")
raise ValueError(
f"random_crop_hw height and width must be greater than 0."
)

self.augmenter = AugmentationSequential(
*aug_stack,
Expand All @@ -280,12 +282,31 @@ def __init__(
same_on_batch=True,
)

def __iter__(self):
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Return an example dictionary with the augmented image and instances."""
for ex in self.source_dp:
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
)
aug_image, aug_instances = self.augmenter(image, instances)
yield {"image": aug_image, "instances": aug_instances.reshape(*inst_shape)}
if "instance_image" in ex and "instance" in ex:
inst_shape = ex["instance"].shape
# (B, channels, height, width), (1, num_nodes, 2)
image, instances = ex["instance_image"], ex["instance"].unsqueeze(0)
aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"instance_image": aug_image,
"instance": aug_instances.reshape(*inst_shape),
}
)
elif "image" in ex and "instances" in ex:
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
) # (B, channels, height, width), (B, num_instances x num_nodes, 2)

aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"image": aug_image,
"instances": aug_instances.reshape(*inst_shape),
}
)
yield ex
8 changes: 4 additions & 4 deletions sleap_nn/data/confidence_maps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Generate confidence maps."""
from typing import Optional
from typing import Dict, Iterator, Optional

import sleap_io as sio
import torch
Expand All @@ -10,7 +10,7 @@

def make_confmaps(
points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
):
) -> torch.Tensor:
"""Make confidence maps from a set of points from a single instance.
Args:
Expand Down Expand Up @@ -70,15 +70,15 @@ def __init__(
output_stride: int = 1,
instance_key: str = "instance",
image_key: str = "instance_image",
):
) -> None:
"""Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride."""
self.source_dp = source_dp
self.sigma = sigma
self.output_stride = output_stride
self.instance_key = instance_key
self.image_key = image_key

def __iter__(self):
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Generate confidence maps for each example."""
for example in self.source_dp:
instance = example[self.instance_key]
Expand Down
39 changes: 39 additions & 0 deletions sleap_nn/data/general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""General purpose transformers for common pipeline processing tasks."""
from typing import Callable, Dict, Iterator, List, Text

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe


class KeyFilter(IterDataPipe):
"""Transformer for filtering example keys."""

def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None:
"""Initialize KeyFilter with the source `DataPipe."""
self.dp = source_dp
self.keep_keys = set(keep_keys) if keep_keys else None

def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Return a dictionary filtered for the relevant outputs.
The input dictionary includes:
- image: the full frame image
- instances: all keypoints of all instances in the frame image
- centroids: all centroids of all instances in the frame image
- instance: the individual instance's keypoints
- instance_bbox: the individual instance's bbox
- instance_image: the individual instance's cropped image
- confidence_maps: the individual instance's heatmap
"""
for example in self.dp:
if self.keep_keys is None:
# If keep_keys is not provided, yield the entire example.
yield example
else:
# Filter the example dictionary based on keep_keys.
filtered_example = {
key: value
for key, value in example.items()
if key in self.keep_keys
}
yield filtered_example
14 changes: 7 additions & 7 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Handle calculation of instance centroids."""
from typing import Optional
from typing import Dict, Iterator, Optional

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe
Expand Down Expand Up @@ -79,15 +79,15 @@ def __init__(
self,
source_dp: IterDataPipe,
anchor_ind: Optional[int] = None,
):
) -> None:
"""Initialize InstanceCentroidFinder with the source `DataPipe."""
self.source_dp = source_dp
self.anchor_ind = anchor_ind

def __iter__(self):
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Add `"centroids"` key to example."""
for example in self.source_dp:
example["centroids"] = find_centroids(
example["instances"], anchor_ind=self.anchor_ind
for ex in self.source_dp:
ex["centroids"] = find_centroids(
ex["instances"], anchor_ind=self.anchor_ind
)
yield example
yield ex
50 changes: 23 additions & 27 deletions sleap_nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Handle cropping of instances."""
from typing import Optional
from typing import Dict, Iterator, Optional, Tuple

import numpy as np
import sleap_io as sio
Expand Down Expand Up @@ -58,50 +58,46 @@ class InstanceCropper(IterDataPipe):
Attributes:
source_dp: The previous `DataPipe` with samples that contain an `instances` key.
crop_width: Width of the crop in pixels
crop_height: Height of the crop in pixels
crop_hw: Height and Width of the crop in pixels
"""

def __init__(
self,
source_dp: IterDataPipe,
crop_width: int,
crop_height: int,
):
def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]) -> None:
"""Initialize InstanceCropper with the source `DataPipe."""
self.source_dp = source_dp
self.crop_width = crop_width
self.crop_height = crop_height
self.crop_hw = crop_hw

def __iter__(self):
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Generate instance cropped examples."""
for example in self.source_dp:
image = example["image"] # (frames, channels, height, width)
instances = example["instances"] # (frames, n_instances, n_nodes, 2)
centroids = example["centroids"] # (frames, n_instances, 2)
for ex in self.source_dp:
image = ex["image"] # (B, channels, height, width)
instances = ex["instances"] # (B, n_instances, num_nodes, 2)
centroids = ex["centroids"] # (B, n_instances, 2)
for instance, centroid in zip(instances[0], centroids[0]):
# Generate bounding boxes from centroid.
bbox = torch.unsqueeze(
make_centered_bboxes(centroid, self.crop_height, self.crop_width), 0
) # (frames, 4, 2)
instance_bbox = torch.unsqueeze(
make_centered_bboxes(centroid, self.crop_hw[0], self.crop_hw[1]), 0
) # (B, 4, 2)

box_size = (self.crop_height, self.crop_width)
box_size = (self.crop_hw[0], self.crop_hw[1])

# Generate cropped image of shape (frames, channels, crop_height, crop_width)
# Generate cropped image of shape (B, channels, crop_height, crop_width)
instance_image = crop_and_resize(
image,
boxes=bbox,
boxes=instance_bbox,
size=box_size,
)

# Access top left point (x,y) of bounding box and subtract this offset from
# position of nodes.
point = bbox[0][0]
point = instance_bbox[0][0]
center_instance = instance - point

instance_example = {
"instance_image": instance_image, # (frames, channels, crop_height, crop_width)
"bbox": bbox, # (frames, 4, 2)
"instance": center_instance, # (n_instances, 2)
"instance_image": instance_image.squeeze(
0
), # (B=1, channels, crop_height, crop_width)
"instance_bbox": instance_bbox, # (B, 4, 2)
"instance": center_instance, # (num_nodes, 2)
}
yield instance_example
ex.update(instance_example)
yield ex
6 changes: 4 additions & 2 deletions sleap_nn/data/normalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""This module implements data pipeline blocks for normalization operations."""
from typing import Dict, Iterator

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

Expand All @@ -16,11 +18,11 @@ class Normalizer(IterDataPipe):
def __init__(
self,
source_dp: IterDataPipe,
):
) -> None:
"""Initialize the block."""
self.source_dp = source_dp

def __iter__(self):
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Return an example dictionary with the augmented image and instance."""
for ex in self.source_dp:
if not torch.is_floating_point(ex["image"]):
Expand Down
Loading

0 comments on commit 47897ad

Please sign in to comment.