Commit c68de13

updated toml; added keep keys explicitly; added test case for single instance pipeline

alckasoc committed Nov 1, 2023
1 parent 4d6e355
Showing 4 changed files with 95 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -58,7 +58,7 @@ Repository = "https://github.com/talmolab/sleap-nn"
line-length = 88

[tool.ruff]
-output-format = "github"
+format = "github"
select = [
"D", # pydocstyle
]
17 changes: 15 additions & 2 deletions sleap_nn/data/pipelines.py
@@ -74,7 +74,18 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe:
image_key="instance_image",
instance_key="instance",
)
-datapipe = KeyFilter(datapipe, keep_keys=None)
+datapipe = KeyFilter(
+    datapipe,
+    keep_keys=[
+        "image",
+        "instances",
+        "centroids",
+        "instance",
+        "instance_bbox",
+        "instance_image",
+        "confidence_maps",
+    ],
+)

return datapipe

@@ -128,6 +139,8 @@ def make_training_pipeline(self, data_provider: IterDataPipe) -> IterDataPipe:
image_key="image",
instance_key="instances",
)
-datapipe = KeyFilter(datapipe, keep_keys=None)
+datapipe = KeyFilter(
+    datapipe, keep_keys=["image", "instances", "confidence_maps"]
+)

return datapipe
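Both hunks above replace keep_keys=None with an explicit whitelist, so every example dictionary leaving the pipeline carries a fixed, known schema. The snippet below is a minimal sketch of what such a key-filtering datapipe does; it is hypothetical and only borrows the constructor shape visible in the diff (a source datapipe plus a keep_keys list), not the actual sleap_nn.data.pipelines.KeyFilter implementation.

```python
from typing import Any, Dict, Iterator, List, Optional

from torchdata.datapipes.iter import IterDataPipe


class WhitelistKeyFilter(IterDataPipe):
    """Minimal, hypothetical sketch of a key-filtering datapipe."""

    def __init__(self, source: IterDataPipe, keep_keys: Optional[List[str]] = None):
        self.source = source
        self.keep_keys = keep_keys

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        for example in self.source:
            if self.keep_keys is None:
                # No whitelist provided: pass the example through unchanged.
                yield example
            else:
                # Keep only the explicitly requested keys.
                yield {k: v for k, v in example.items() if k in self.keep_keys}
```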
29 changes: 24 additions & 5 deletions sleap_nn/data/providers.py
@@ -1,5 +1,6 @@
"""This module implements pipeline blocks for reading input data such as labels."""
from typing import Dict, Iterator

import numpy as np
import sleap_io as sio
import torch
@@ -8,25 +9,40 @@


class LabelsReader(IterDataPipe):
"""Reading frames from Labels object DataPipe.
"""Datapipe for reading frames from Labels object.
This DataPipe will produce examples containing a frame and an sleap_io.Instance
from a sleap_io.Labels instance.
Attributes:
labels: sleap_io.Labels object that contains LabeledFrames that will be
accessed through a torchdata DataPipe
+user_instances_only: If True, keep only user-labeled instances; otherwise keep all instances. Defaults to True.
"""

-def __init__(self, labels: sio.Labels):
+def __init__(self, labels: sio.Labels, user_instances_only: bool = True):
"""Initialize labels attribute of the class."""
-self.labels = labels
+self.labels = copy.deepcopy(labels)

+# Filter to user instances
+if user_instances_only:
+    filtered_lfs = []
+    for lf in self.labels:
+        if lf.user_instances is not None and len(lf.user_instances) > 0:
+            lf.instances = lf.user_instances
+            filtered_lfs.append(lf)
+    self.labels = sio.Labels(
+        videos=labels.videos,
+        skeletons=labels.skeletons,
+        labeled_frames=filtered_lfs,
+    )

@classmethod
-def from_filename(cls, filename: str):
+def from_filename(cls, filename: str, user_instances_only: bool = True):
"""Create LabelsReader from a .slp filename."""
labels = sio.load_slp(filename)
-return cls(labels)
+return cls(labels, user_instances_only)

def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Return an example dictionary containing the following elements.
@@ -38,15 +54,18 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""
for lf in self.labels:
image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW

instances = []
for inst in lf:
instances.append(inst.numpy())
instances = np.stack(instances, axis=0)

# Add singleton time dimension for single frames.
image = np.expand_dims(image, axis=0) # (1, C, H, W)
instances = np.expand_dims(
instances, axis=0
) # (1, num_instances, num_nodes, 2)

yield {
"image": torch.from_numpy(image),
"instances": torch.from_numpy(instances.astype("float32")),
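As a usage sketch (not part of the diff itself): user_instances_only is the default, so constructing a reader from a Labels object or a filename drops frames without user-labeled instances before iteration. The .slp filename below is hypothetical; the yielded keys and shapes match the example dictionaries produced by __iter__ above.

```python
import sleap_io as sio

from sleap_nn.data.providers import LabelsReader

# Hypothetical labels file; any .slp with user-labeled frames works here.
labels = sio.load_slp("minimal_instance.pkg.slp")

# Keep only frames carrying user-labeled instances (user_instances_only=True
# is the default); predicted-only frames are dropped up front.
reader = LabelsReader(labels, user_instances_only=True)

# Equivalent construction straight from the filename.
reader = LabelsReader.from_filename(
    "minimal_instance.pkg.slp", user_instances_only=True
)

example = next(iter(reader))
print(example["image"].shape)      # (1, C, H, W) with a singleton time dim
print(example["instances"].shape)  # (1, num_instances, num_nodes, 2)
```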
55 changes: 55 additions & 0 deletions tests/data/test_pipelines.py
@@ -229,3 +229,58 @@ def test_singleinstanceconfmapspipeline(minimal_instance):
assert gt_key == key
assert sample["image"].shape == (1, 1, 384, 384)
assert sample["confidence_maps"].shape == (1, 2, 192, 192)

base_singleinstance_data_config = OmegaConf.create(
{
"preprocessing": {
"conf_map_gen": {"sigma": 1.5, "output_stride": 2},
},
"augmentation_config": {
"random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)},
"use_augmentations": True,
"augmentations": {
"intensity": {
"uniform_noise": (0.0, 0.04),
"uniform_noise_p": 0.5,
"gaussian_noise_mean": 0.02,
"gaussian_noise_std": 0.004,
"gaussian_noise_p": 0.5,
"contrast": (0.5, 2.0),
"contrast_p": 0.5,
"brightness": 0.0,
"brightness_p": 0.5,
},
"geometric": {
"rotation": 15.0,
"scale": 0.05,
"translate": (0.02, 0.02),
"affine_p": 0.5,
"erase_scale": (0.0001, 0.01),
"erase_ratio": (1, 1),
"erase_p": 0.5,
"mixup_lambda": None,
"mixup_p": 0.5,
},
},
},
}
)

pipeline = SingleInstanceConfmapsPipeline(
data_config=base_singleinstance_data_config
)
data_provider = LabelsReader(labels=labels)
datapipe = pipeline.make_training_pipeline(data_provider=data_provider)

sample = next(iter(datapipe))

gt_sample_keys = [
"image",
"instances",
"confidence_maps",
]

for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
assert gt_key == key
assert sample["image"].shape == (1, 1, 160, 160)
assert sample["confidence_maps"].shape == (1, 2, 80, 80)

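A quick sanity check of the shape assertions in the new test (a standalone sketch, not part of the commit): with a 160x160 random crop and a conf_map_gen output_stride of 2, the confidence maps come out at 80x80, and the same stride relation gives the 192x192 maps asserted for the 384x384 images in the earlier test.

```python
# Standalone arithmetic check for the asserted shapes: confidence maps are
# produced at input_size // output_stride per spatial dimension.
def confmap_hw(input_hw, output_stride):
    return tuple(s // output_stride for s in input_hw)

assert confmap_hw((160, 160), 2) == (80, 80)    # random-crop case in the new test
assert confmap_hw((384, 384), 2) == (192, 192)  # full-frame case in the earlier test
```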