Added SingleInstanceConfmapsPipeline #23

Merged
merged 107 commits on Nov 28, 2023
Changes from all commits (107 commits)
2dfcd90
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
1088e7f
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
2d0a009
created test_instance_cropping.py
alckasoc Aug 3, 2023
02ea629
added test normalize bboxes; added find_global_peaks_rough
alckasoc Aug 6, 2023
711b3aa
black formatted
alckasoc Aug 6, 2023
3a0cfb7
fixed merges
alckasoc Aug 6, 2023
9a728aa
black formatted peak_finding
alckasoc Aug 6, 2023
e84535f
added make_grid_vectors, normalize_bboxes, integral_regression, added…
alckasoc Aug 10, 2023
36f6573
finished find_global_peaks with integral regression over centroid crops!
alckasoc Aug 10, 2023
b17af28
reformatted with pydocstyle & black
alckasoc Aug 10, 2023
3ea75ae
Merge remote-tracking branch 'origin/main' into vincent/find_peaks
alckasoc Aug 10, 2023
a506579
moved make_grid_vectors to data/utils
alckasoc Aug 10, 2023
02babb1
removed normalize_bboxes
alckasoc Aug 10, 2023
373f4b1
added tests docstrings
alckasoc Aug 10, 2023
6351314
sorted imports with isort
alckasoc Aug 10, 2023
008a994
remove unused imports
alckasoc Aug 10, 2023
b45619c
updated test cases for instance cropping
alckasoc Aug 10, 2023
381a49f
added minimal_cms.pt fixture + unit tests
alckasoc Aug 11, 2023
0ad336c
added minimal_bboxes fixture; added unit tests for crop_bboxes & inte…
alckasoc Aug 11, 2023
da1ba7e
added find_global_peaks unit tests
alckasoc Aug 11, 2023
7778512
finished find_local_peaks_rough!
alckasoc Aug 17, 2023
9f7ac3f
finished find_local_peaks!
alckasoc Aug 17, 2023
b9869d6
added unit tests for find_local_peaks and find_local_peaks_rough
alckasoc Aug 17, 2023
bfd1cac
updated test cases
alckasoc Aug 17, 2023
a8b3c31
added more test cases for find_local_peaks
alckasoc Aug 17, 2023
125625d
updated test cases
alckasoc Aug 17, 2023
a25d920
added architectures folder
alckasoc Aug 17, 2023
3ba92b6
added maxpool2d same padding, get_act_fn; added simpleconvblock, simp…
alckasoc Aug 17, 2023
f9558f2
added test_unet_reference
alckasoc Aug 18, 2023
28d57ca
black formatted common.py & test_unet.py
alckasoc Aug 18, 2023
8ca4538
fixed merge conflicts
alckasoc Aug 18, 2023
6df3c20
Merge branch 'main' into vincent/unet
alckasoc Aug 18, 2023
c4792a6
Merge branch 'vincent/unet' of https://github.com/talmolab/sleap-nn i…
alckasoc Aug 18, 2023
87cd034
deleted tmp nb
alckasoc Aug 18, 2023
7004869
_calc_same_pad returns int
alckasoc Aug 19, 2023
680778d
fixed test case
alckasoc Aug 19, 2023
7cd75dc
added simpleconvblock tests
alckasoc Aug 19, 2023
79b535d
added tests
alckasoc Aug 19, 2023
691af45
added tests for simple upsampling block
alckasoc Aug 19, 2023
2520fa2
updated test_unet
alckasoc Aug 28, 2023
bcf4069
removed unnecessary variables
alckasoc Aug 30, 2023
dbccdcf
updated augmentation random erase default values
alckasoc Aug 30, 2023
029a545
created data/pipelines.py
alckasoc Aug 30, 2023
3e5ae68
added base config in config/data; temporary till config system settled
alckasoc Aug 31, 2023
1b8002b
updated variable defaults to 0 and edited variable names in augmentation
alckasoc Aug 31, 2023
f1c64f4
updated parameter names in data/instance_cropping
alckasoc Aug 31, 2023
2a22674
added data/pipelines topdown pipeline make_base_pipeline
alckasoc Aug 31, 2023
f3ddf2f
added test_pipelines
alckasoc Aug 31, 2023
c861c72
removed configs
alckasoc Sep 5, 2023
31aadc1
updated augmentation class
alckasoc Sep 6, 2023
6630155
modified test
alckasoc Sep 6, 2023
c7fc015
removed cuda cache
alckasoc Sep 6, 2023
4b439e9
added Model builder class and heads
alckasoc Sep 6, 2023
199e4d5
added type hinting for init
alckasoc Sep 6, 2023
4134473
black reformatted heads.py
alckasoc Sep 6, 2023
86f960e
updated model.py
alckasoc Sep 6, 2023
d576e63
updated test_model.py
alckasoc Sep 6, 2023
40089b7
updated test_model.py
alckasoc Sep 6, 2023
dddb6ac
updated pipelines docstring
alckasoc Sep 6, 2023
b75cacd
added from_config for Model
alckasoc Sep 7, 2023
979de3f
added more act fn to get_act_fn
alckasoc Sep 12, 2023
ff144a9
black reformatted & updated model.py & test_model.py
alckasoc Sep 12, 2023
4dacac9
updated config, typehints, black formatted & added doc strings
alckasoc Sep 12, 2023
a4fffd4
added test_heads.py
alckasoc Sep 12, 2023
ddc70c4
updated module docstring
alckasoc Sep 12, 2023
3f58407
updated Model docstring
alckasoc Sep 12, 2023
980f107
added coderabbit suggestions
alckasoc Sep 12, 2023
cf9d18f
black reformat
alckasoc Sep 13, 2023
0def8ae
added 2 helper methods for getting backbone/head; separated common an…
alckasoc Sep 13, 2023
53065a3
removed comments
alckasoc Sep 13, 2023
1a66ace
updated test_get_act_fn
alckasoc Sep 13, 2023
f762317
added multi-head feature to Model
alckasoc Sep 14, 2023
a70c5c4
black reformatted model.py
alckasoc Sep 14, 2023
f2543a2
added all test cases for heads.py
alckasoc Sep 14, 2023
e6fa8d4
reformatted test_heads.py
alckasoc Sep 14, 2023
ce86884
updated L44 in confidence_maps.py
alckasoc Sep 14, 2023
361f979
added output channels to unet
alckasoc Sep 19, 2023
81b5af0
resolved merge conflicts
alckasoc Sep 19, 2023
ca4e87d
Merge branch 'main' into vincent/models
alckasoc Sep 19, 2023
a768c4f
resolved merge conflicts + small bugs
alckasoc Sep 19, 2023
0be083f
black reformatted
alckasoc Sep 20, 2023
88bc306
added coderabbit suggestions
alckasoc Sep 20, 2023
27990c3
not sure how intermediate features + multi head would work
alckasoc Sep 20, 2023
212e94f
Separate Augmentations into Intensity and Geometric (#18)
alckasoc Sep 21, 2023
d44f8a6
pseudo code in model.py
alckasoc Sep 21, 2023
2930057
Merge branch 'vincent/models' of https://github.com/talmolab/sleap-nn…
alckasoc Sep 21, 2023
537c14b
small fix
alckasoc Sep 21, 2023
b759e7a
name property in heads.py
alckasoc Sep 21, 2023
e21e74c
name of head docstring added
alckasoc Sep 21, 2023
dce7294
Merge remote-tracking branch 'origin/main' into vincent/models
alckasoc Sep 26, 2023
99413a1
added ruff cache to gitignore; added head selection in Model class
alckasoc Sep 27, 2023
2880849
updated return value for decoder
alckasoc Oct 7, 2023
9afdebb
small change to model.py
alckasoc Oct 7, 2023
5cd8801
made model.py forward more efficient
alckasoc Oct 7, 2023
4eaf268
small comments updated in instance_cropping for clarity
alckasoc Oct 7, 2023
702009a
updated output structure of unet to dict; updated model.py attribute …
alckasoc Oct 12, 2023
3efac3c
added single instance confmaps pipeline
alckasoc Oct 16, 2023
0833b19
added test cases
alckasoc Oct 17, 2023
abe7f97
fixed batching issue
alckasoc Oct 25, 2023
89781fb
resolved merge conflict; labelsreader issue unsolved; fixed output sh…
alckasoc Oct 26, 2023
2a8ac09
updated test cases
alckasoc Oct 26, 2023
4d6e355
fixed pydocstyle
alckasoc Oct 26, 2023
c68de13
updated toml; added keep keys explicitly; added test case for singlei…
alckasoc Nov 1, 2023
3dd3b7d
updated toml
alckasoc Nov 2, 2023
84222d1
Merge branch 'main' into vincent/single_instance_pipeline
talmo Nov 22, 2023
2080316
updated test_pipelines.py test case
alckasoc Nov 26, 2023
bd09be8
black reformatted
alckasoc Nov 26, 2023
1 change: 0 additions & 1 deletion pyproject.toml
@@ -58,7 +58,6 @@ Repository = "https://github.com/talmolab/sleap-nn"
line-length = 88

[tool.ruff]
output-format = "github"
select = [
"D", # pydocstyle
]
47 changes: 24 additions & 23 deletions sleap_nn/data/augmentation.py
@@ -127,6 +127,9 @@ class KorniaAugmenter(IterDataPipe):
random_crop_hw: Desired output size `(out_h, out_w)` of the crop as a Tuple[int, int],
where out_h = size[0] and out_w = size[1].
random_crop_p: Probability of applying random crop.
image_key: The key holding the image to augment. Use `"instance_image"` when this
block follows the `InstanceCropper`; defaults to `"image"` for full frames.
instance_key: The key holding the points to augment. Use `"instance"` when this
block follows the `InstanceCropper`; defaults to `"instances"` for full frames.

Notes:
This block expects the "image" and "instances" keys to be present in the input
@@ -164,6 +167,8 @@ def __init__(
mixup_p: float = 0.0,
random_crop_hw: Tuple[int, int] = (0, 0),
random_crop_p: float = 0.0,
image_key: str = "image",
instance_key: str = "instances",
) -> None:
"""Initialize the block and the augmentation pipeline."""
self.source_dp = source_dp
@@ -187,6 +192,8 @@ def __init__(
self.mixup_p = mixup_p
self.random_crop_hw = random_crop_hw
self.random_crop_p = random_crop_p
self.image_key = image_key
self.instance_key = instance_key

aug_stack = []
if self.affine_p > 0:
@@ -282,28 +289,22 @@ def __init__(
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Return an example dictionary with the augmented image and instances."""
for ex in self.source_dp:
if "instance_image" in ex and "instance" in ex:
inst_shape = ex["instance"].shape
# (B, channels, height, width), (1, num_nodes, 2)
image, instances = ex["instance_image"], ex["instance"].unsqueeze(0)
aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"instance_image": aug_image,
"instance": aug_instances.reshape(*inst_shape),
}
)
elif "image" in ex and "instances" in ex:
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
) # (B, channels, height, width), (B, num_instances x num_nodes, 2)
inst_shape = ex[self.instance_key].shape
# Before (image_key="image", instance_key="instances"): (B=1, C, H, W), (B=1, num_instances, num_nodes, 2)
# or
# Before (image_key="instance_image", instance_key="instance"): (B=1, C, crop_H, crop_W), (B=1, num_nodes, 2)
image, instances = ex[self.image_key], ex[self.instance_key].reshape(
inst_shape[0], -1, 2
) # (B=1, C, H, W), (B=1, num_instances * num_nodes, 2) OR (B=1, num_nodes, 2)

aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"image": aug_image,
"instances": aug_instances.reshape(*inst_shape),
}
)
aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
self.image_key: aug_image,
self.instance_key: aug_instances.reshape(*inst_shape),
}
)
# After (image_key="image", instance_key="instances"): (B=1, C, H, W), (B=1, num_instances, num_nodes, 2)
# or
# After (image_key="instance_image", instance_key="instance"): (B=1, C, crop_H, crop_W), (B=1, num_nodes, 2)
yield ex
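
The new image_key/instance_key arguments let the same augmenter operate on full frames or on instance crops. A minimal sketch of both wirings follows; the module paths, constructor defaults, and the "labels.slp" path are assumptions for illustration, not part of this diff.

from sleap_nn.data.providers import LabelsReader
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.instance_centroids import InstanceCentroidFinder
from sleap_nn.data.instance_cropping import InstanceCropper
from sleap_nn.data.augmentation import KorniaAugmenter

dp = LabelsReader.from_filename(filename="labels.slp")  # placeholder labels file
dp = Normalizer(dp)

# Full frames: augment the default "image"/"instances" keys.
dp = KorniaAugmenter(dp, affine_p=1.0, image_key="image", instance_key="instances")

# Instance crops: after InstanceCropper, point the augmenter at the crop keys instead.
dp = InstanceCentroidFinder(dp)
dp = InstanceCropper(dp, crop_hw=(160, 160))
dp = KorniaAugmenter(
    dp,
    random_crop_hw=(128, 128),
    random_crop_p=1.0,
    image_key="instance_image",
    instance_key="instance",
)
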
48 changes: 28 additions & 20 deletions sleap_nn/data/confidence_maps.py
@@ -1,22 +1,21 @@
"""Generate confidence maps."""
from typing import Dict, Iterator, Optional
from typing import Dict, Iterator

import sleap_io as sio
import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

from sleap_nn.data.utils import make_grid_vectors


def make_confmaps(
points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
points_batch: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
) -> torch.Tensor:
"""Make confidence maps from a set of points from a single instance.
"""Make confidence maps from a batch of points for multiple instances.

Args:
points: A tensor of points of shape `(n_nodes, 2)` and dtype `torch.float32` where
the last axis corresponds to (x, y) pixel coordinates on the image. These
can contain NaNs to indicate missing points.
points_batch: A tensor of points of shape `(batch_size, n_nodes, 2)` and dtype `torch.float32` where
the last axis corresponds to (x, y) pixel coordinates on the image for each instance.
These can contain NaNs to indicate missing points.
xv: Sampling grid vector for x-coordinates of shape `(grid_width,)` and dtype
`torch.float32`. This can be generated by
`sleap.nn.data.utils.make_grid_vectors`.
@@ -27,21 +26,24 @@ def make_confmaps(
confidence maps.

Returns:
Confidence maps as a tensor of shape `(n_nodes, grid_height, grid_width)` of
Confidence maps as a tensor of shape `(batch_size, n_nodes, grid_height, grid_width)` of
dtype `torch.float32`.
"""
x = torch.reshape(points[:, 0], (-1, 1, 1))
y = torch.reshape(points[:, 1], (-1, 1, 1))
batch_size, n_nodes, _ = points_batch.shape

x = torch.reshape(points_batch[:, :, 0], (batch_size, n_nodes, 1, 1))
y = torch.reshape(points_batch[:, :, 1], (batch_size, n_nodes, 1, 1))

xv_reshaped = torch.reshape(xv, (1, 1, 1, -1))
yv_reshaped = torch.reshape(yv, (1, 1, -1, 1))

cm = torch.exp(
-(
(torch.reshape(xv, (1, 1, -1)) - x) ** 2
+ (torch.reshape(yv, (1, -1, 1)) - y) ** 2
)
/ (2 * sigma**2)
-((xv_reshaped - x) ** 2 + (yv_reshaped - y) ** 2) / (2 * sigma**2)
)

# Replace NaNs with 0.
cm = torch.nan_to_num(cm)

return cm
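
The function now broadcasts over a leading batch dimension. A small, self-contained sketch of the new calling convention; import paths follow this diff and shapes follow the updated docstring.

import torch
from sleap_nn.data.utils import make_grid_vectors
from sleap_nn.data.confidence_maps import make_confmaps

# One instance (batch of 1) with two nodes on a 64x64 grid.
points = torch.tensor([[[10.0, 20.0], [40.0, 30.0]]])  # (batch_size=1, n_nodes=2, 2)
xv, yv = make_grid_vectors(64, 64, 1)  # sampling grid vectors for a 64x64 image, stride 1
cms = make_confmaps(points, xv, yv, sigma=1.5)
print(cms.shape)  # torch.Size([1, 2, 64, 64]) = (batch_size, n_nodes, grid_height, grid_width)
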


@@ -59,36 +61,42 @@ class ConfidenceMapGenerator(IterDataPipe):
generate confidence maps.
output_stride: The relative stride to use when generating confidence maps.
A larger stride will generate smaller confidence maps.
instance_key: The name of the key where the instance points (n_instances, 2) are.
image_key: The name of the key where the image is: `(B, channels, height, width)` for full frames or `(B, channels, crop_height, crop_width)` for instance crops.
instance_key: The name of the key where the instance points are: `(B, num_instances, num_nodes, 2)` for `"instances"` or `(B, num_nodes, 2)` for `"instance"`.
"""

def __init__(
self,
source_dp: IterDataPipe,
sigma: int = 1.5,
output_stride: int = 1,
instance_key: str = "instance",
image_key: str = "instance_image",
image_key: str = "image",
instance_key: str = "instances",
) -> None:
"""Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride."""
self.source_dp = source_dp
self.sigma = sigma
self.output_stride = output_stride
self.instance_key = instance_key
self.image_key = image_key
self.instance_key = instance_key

def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Generate confidence maps for each example."""
for example in self.source_dp:
instance = example[self.instance_key]
if self.instance_key == "instances":
instance = instance.view(instance.shape[0], -1, 2)

width = example[self.image_key].shape[-1]
height = example[self.image_key].shape[-2]

xv, yv = make_grid_vectors(height, width, self.output_stride)

confidence_maps = make_confmaps(
instance, xv, yv, self.sigma
instance,
xv,
yv,
self.sigma,
) # (B=1, n_nodes, height, width)

example["confidence_maps"] = confidence_maps
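
With the key arguments, the same generator serves both the topdown (crop) and single-instance (full-frame) cases. A sketch under the same assumptions as the augmenter sketch above (assumed module paths, placeholder labels file):

from sleap_nn.data.providers import LabelsReader
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.confidence_maps import ConfidenceMapGenerator

dp = LabelsReader.from_filename(filename="labels.slp")  # placeholder labels file
dp = Normalizer(dp)
dp = ConfidenceMapGenerator(
    dp,
    sigma=1.5,
    output_stride=2,
    image_key="image",         # use "instance_image" when following InstanceCropper
    instance_key="instances",  # use "instance" when following InstanceCropper
)
ex = next(iter(dp))
ex["confidence_maps"]  # (B=1, n_nodes, grid_height, grid_width) for a single-instance frame
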
2 changes: 1 addition & 1 deletion sleap_nn/data/instance_centroids.py
@@ -89,5 +89,5 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
for ex in self.source_dp:
ex["centroids"] = find_centroids(
ex["instances"], anchor_ind=self.anchor_ind
)
) # (B=1, num_instances, 2)
yield ex
14 changes: 6 additions & 8 deletions sleap_nn/data/instance_cropping.py
@@ -67,9 +67,9 @@ def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]) -> None:
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Generate instance cropped examples."""
for ex in self.source_dp:
image = ex["image"] # (B=1, channels, height, width)
instances = ex["instances"] # (B=1, n_instances, num_nodes, 2)
centroids = ex["centroids"] # (B=1, n_instances, 2)
image = ex["image"] # (B=1, C, H, W)
instances = ex["instances"] # (B=1, num_instances, num_nodes, 2)
centroids = ex["centroids"] # (B=1, num_instances, 2)
for instance, centroid in zip(instances[0], centroids[0]):
# Generate bounding boxes from centroid.
instance_bbox = torch.unsqueeze(
@@ -78,7 +78,7 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:

box_size = (self.crop_hw[0], self.crop_hw[1])

# Generate cropped image of shape (B=1, channels, crop_height, crop_width)
# Generate cropped image of shape (B=1, C, crop_H, crop_W)
instance_image = crop_and_resize(
image,
boxes=instance_bbox,
@@ -91,11 +91,9 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
center_instance = instance - point

instance_example = {
"instance_image": instance_image.squeeze(
0
), # (B=1, channels, crop_height, crop_width)
"instance_image": instance_image, # (B=1, C, crop_H, crop_W)
"instance_bbox": instance_bbox, # (B=1, 4, 2)
"instance": center_instance, # (num_nodes, 2)
"instance": center_instance.unsqueeze(0), # (B=1, num_nodes, 2)
}
ex.update(instance_example)
yield ex
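
The cropper now keeps the batch dimension on its outputs, so downstream blocks see the shapes noted in the comments above. A sketch of the resulting example structure, under the same assumptions as the earlier sketches:

from sleap_nn.data.providers import LabelsReader
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.instance_centroids import InstanceCentroidFinder
from sleap_nn.data.instance_cropping import InstanceCropper

dp = LabelsReader.from_filename(filename="labels.slp")  # placeholder labels file
dp = Normalizer(dp)
dp = InstanceCentroidFinder(dp)
dp = InstanceCropper(dp, crop_hw=(160, 160))

ex = next(iter(dp))
ex["instance_image"].shape  # (B=1, C, 160, 160) -- batch dim retained, no squeeze
ex["instance"].shape        # (B=1, num_nodes, 2) -- unsqueezed to match
ex["instance_bbox"].shape   # (B=1, 4, 2)
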
84 changes: 78 additions & 6 deletions sleap_nn/data/pipelines.py
@@ -24,26 +24,25 @@ def __init__(self, data_config: DictConfig) -> None:
"""Initialize the data config."""
self.data_config = data_config

def make_training_pipeline(
self, data_provider: IterDataPipe, filename: str
) -> IterDataPipe:
def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe:
"""Create training pipeline with input data only.

Args:
data_provider: A `Provider` that generates data examples, typically a
`LabelsReader` instance.
filename: A string path to the name of the `.slp` file.

Returns:
An `IterDataPipe` instance configured to produce input examples.
"""
datapipe = data_provider.from_filename(filename=filename)
datapipe = data_provider
datapipe = Normalizer(datapipe)

if self.data_config.augmentation_config.use_augmentations:
datapipe = KorniaAugmenter(
datapipe,
**dict(self.data_config.augmentation_config.augmentations.intensity),
image_key="image",
instance_key="instances",
)

datapipe = InstanceCentroidFinder(
@@ -56,19 +55,92 @@ def make_training_pipeline(
datapipe,
random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw,
random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p,
image_key="instance_image",
instance_key="instance",
)

if self.data_config.augmentation_config.use_augmentations:
datapipe = KorniaAugmenter(
datapipe,
**dict(self.data_config.augmentation_config.augmentations.geometric),
image_key="instance_image",
instance_key="instance",
)

datapipe = ConfidenceMapGenerator(
datapipe,
sigma=self.data_config.preprocessing.conf_map_gen.sigma,
output_stride=self.data_config.preprocessing.conf_map_gen.output_stride,
image_key="instance_image",
instance_key="instance",
)
datapipe = KeyFilter(
datapipe,
keep_keys=[
"image",
"instances",
"centroids",
"instance",
"instance_bbox",
"instance_image",
"confidence_maps",
],
)

return datapipe


class SingleInstanceConfmapsPipeline:
"""Pipeline builder for single-instance confidence map models.

Attributes:
data_config: Data-related configuration.
"""

def __init__(self, data_config: DictConfig) -> None:
"""Initialize the data config."""
self.data_config = data_config

def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe:
"""Create training pipeline with input data only.

Args:
data_provider: A `Provider` that generates data examples, typically a
`LabelsReader` instance.

Returns:
An `IterDataPipe` instance configured to produce input examples.
"""
datapipe = data_provider
datapipe = Normalizer(datapipe)

if self.data_config.augmentation_config.use_augmentations:
datapipe = KorniaAugmenter(
datapipe,
**dict(self.data_config.augmentation_config.augmentations.intensity),
**dict(self.data_config.augmentation_config.augmentations.geometric),
image_key="image",
instance_key="instances",
)

if self.data_config.augmentation_config.random_crop.random_crop_p:
datapipe = KorniaAugmenter(
datapipe,
random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw,
random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p,
image_key="image",
instance_key="instances",
)

datapipe = ConfidenceMapGenerator(
datapipe,
sigma=self.data_config.preprocessing.conf_map_gen.sigma,
output_stride=self.data_config.preprocessing.conf_map_gen.output_stride,
image_key="image",
instance_key="instances",
)
datapipe = KeyFilter(
datapipe, keep_keys=["image", "instances", "confidence_maps"]
)
datapipe = KeyFilter(datapipe, keep_keys=self.data_config.general.keep_keys)

return datapipe
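
Putting the new class together: a minimal sketch of building the single-instance training pipeline. The config fields mirror what make_training_pipeline reads above, but the exact schema and values are assumptions; the same provider-first call pattern (passing a constructed LabelsReader rather than a filename) now applies to TopdownConfmapsPipeline as well.

from omegaconf import OmegaConf
from sleap_nn.data.pipelines import SingleInstanceConfmapsPipeline
from sleap_nn.data.providers import LabelsReader

# Hypothetical config; the field names follow the attributes accessed in the code above.
data_config = OmegaConf.create(
    {
        "augmentation_config": {
            "use_augmentations": False,
            "random_crop": {"random_crop_p": 0.0, "random_crop_hw": [160, 160]},
        },
        "preprocessing": {"conf_map_gen": {"sigma": 1.5, "output_stride": 2}},
        "general": {"keep_keys": ["image", "instances", "confidence_maps"]},
    }
)

provider = LabelsReader.from_filename(filename="labels.slp")  # placeholder labels file
pipeline = SingleInstanceConfmapsPipeline(data_config=data_config)
datapipe = pipeline.make_training_pipeline(data_provider=provider)

ex = next(iter(datapipe))
list(ex.keys())  # expected: ['image', 'instances', 'confidence_maps'] per keep_keys
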
1 change: 1 addition & 0 deletions sleap_nn/data/providers.py
@@ -18,6 +18,7 @@ class LabelsReader(IterDataPipe):
labels: sleap_io.Labels object that contains LabeledFrames that will be
accessed through a torchdata DataPipe
user_instances_only: If `True`, filter the labels to user-labeled instances only. Defaults to `True`.

"""

def __init__(self, labels: sio.Labels, user_instances_only: bool = True):
4 changes: 2 additions & 2 deletions sleap_nn/evaluation.py
@@ -16,7 +16,7 @@ class MatchInstance:


def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
"""Function to get a list of instances of type MatchInstance from the Labeled Frame.
"""Get a list of instances of type MatchInstance from the Labeled Frame.

Args:
labeled_frame: Input Labeled frame of type sio.LabeledFrame.
@@ -555,7 +555,7 @@ def voc_metrics(
}

def mOKS(self):
"""Returns the meanOKS value."""
"""Return the meanOKS value."""
pair_oks = np.array([oks for _, _, oks in self.positive_pairs])
return {"mOKS": pair_oks.mean()}
