Skip to content

Commit

Permalink
Fix changes
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Oct 3, 2024
1 parent ef744e1 commit 80023b3
Show file tree
Hide file tree
Showing 21 changed files with 104 additions and 102 deletions.
4 changes: 2 additions & 2 deletions sleap_nn/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def make_training_pipeline(
Args:
data_provider: A `Provider` that generates data examples, typically a
`LabelsReader` instance.
`LabelsReaderDP` instance.
use_augmentations: `True` if augmentations should be applied to the training
pipeline, else `False`. Default: `False`.
Expand Down Expand Up @@ -353,7 +353,7 @@ def make_training_pipeline(
Args:
data_provider: A `Provider` that generates data examples, typically a
`LabelsReader` instance.
`LabelsReaderDP` instance.
use_augmentations: `True` if augmentations should be applied to the training
pipeline, else `False`. Default: `False`.
Expand Down
8 changes: 4 additions & 4 deletions sleap_nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def process_lf(
return ex


class LabelsReader(IterDataPipe):
class LabelsReaderDP(IterDataPipe):
"""IterDataPipe for reading frames from Labels object.
This IterDataPipe will produce examples containing a frame and a sleap_io.Instance
Expand Down Expand Up @@ -152,7 +152,7 @@ def from_filename(
user_instances_only: bool = True,
instances_key: bool = True,
):
"""Create LabelsReader from a .slp filename."""
"""Create LabelsReaderDP from a .slp filename."""
labels = sio.load_slp(filename)
return cls(labels, user_instances_only, instances_key)

Expand Down Expand Up @@ -252,7 +252,7 @@ def from_filename(
start_idx: Optional[int] = None,
end_idx: Optional[int] = None,
):
"""Create LabelsReader from a .slp filename."""
"""Create VideoReader from a .slp filename."""
video = sio.load_video(filename)
frame_buffer = Queue(maxsize=queue_maxsize)
return cls(video, frame_buffer, start_idx, end_idx)
Expand Down Expand Up @@ -288,7 +288,7 @@ def run(self):
)


class LabelReader(Thread):
class LabelsReader(Thread):
"""Thread module for reading images from sleap-io Labels object.
This module will load the images from `.slp` files and pushes them as Tensors into a
Expand Down
4 changes: 2 additions & 2 deletions sleap_nn/data/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import torch.nn.functional as F
from sleap_nn.data.providers import LabelsReader, VideoReader
from sleap_nn.data.providers import LabelsReaderDP, VideoReader
import torchvision.transforms.v2.functional as tvf
from torch.utils.data.datapipes.datapipe import IterDataPipe

Expand Down Expand Up @@ -230,7 +230,7 @@ class SizeMatcher(IterDataPipe):
def __init__(
self,
source_datapipe: IterDataPipe,
provider: Optional[Union[LabelsReader, VideoReader]] = None,
provider: Optional[Union[LabelsReaderDP, VideoReader]] = None,
max_height: Optional[int] = None,
max_width: Optional[int] = None,
):
Expand Down
49 changes: 25 additions & 24 deletions sleap_nn/inference/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import lightning as L
import litdata as ld
from omegaconf import OmegaConf
from sleap_nn.data.providers import LabelReader, VideoReader
from sleap_nn.data.providers import LabelsReader, VideoReader
from sleap_nn.data.resizing import (
resize_image,
apply_pad_to_stride,
Expand Down Expand Up @@ -58,9 +58,9 @@ class Predictor(ABC):
preprocess_config: Preprocessing config with keys: [`batch_size`,
`scale`, `is_rgb`, `max_stride`]. Default: {"batch_size": 4, "scale": 1.0,
"is_rgb": False, "max_stride": 1}
provider: Provider for inference pipeline. One of ["LabelReader", "VideoReader"].
Default: LabelReader.
pipeline: If provider is LabelReader, pipeline is a `DataLoader` object. If provider
provider: Provider for inference pipeline. One of ["LabelsReader", "VideoReader"].
Default: LabelsReader.
pipeline: If provider is LabelsReader, pipeline is a `DataLoader` object. If provider
is VideoReader, pipeline is an instance of `sleap_nn.data.providers.VideoReader`
class. Default: None.
inference_model: Instance of one of the inference models ["TopDownInferenceModel",
Expand All @@ -75,8 +75,8 @@ class Predictor(ABC):
"is_rgb": False,
"max_stride": 1,
}
provider: Union[LabelReader, VideoReader] = LabelReader
pipeline: Optional[Union[LabelReader, VideoReader]] = None
provider: Union[LabelsReader, VideoReader] = LabelsReader
pipeline: Optional[Union[LabelsReader, VideoReader]] = None
inference_model: Optional[
Union[
TopDownInferenceModel, SingleInstanceInferenceModel, BottomUpInferenceModel
Expand Down Expand Up @@ -265,6 +265,7 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
if self.instances_key:
instances.append(frame["instances"].unsqueeze(dim=0))
if imgs:
# TODO: all preprocessing should be moved into InferenceModels to be exportable.
imgs = torch.concatenate(imgs, dim=0)
fidxs = torch.tensor(fidxs, dtype=torch.int32)
vidxs = torch.tensor(vidxs, dtype=torch.int32)
Expand Down Expand Up @@ -522,7 +523,7 @@ def from_trained_models(
An instance of `TopDownPredictor` with the loaded models.
One of the two models can be left as `None` to perform inference with ground
truth data. This will only work with `LabelReader` as the provider.
truth data. This will only work with `LabelsReader` as the provider.
"""
if centroid_ckpt_path is not None:
Expand Down Expand Up @@ -591,7 +592,7 @@ def make_pipeline(
Args:
provider: (str) Provider class to read the input sleap files.
Either "LabelReader" or "VideoReader".
Either "LabelsReader" or "VideoReader".
data_path: (str) Path to `.slp` file or `.mp4` to run inference on.
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
video_start_idx: (int) Start index of the frames to read. Default: None.
Expand All @@ -603,9 +604,9 @@ def make_pipeline(
"""
self.provider = provider

# LabelReader provider
if self.provider == "LabelReader":
provider = LabelReader
# LabelsReader provider
if self.provider == "LabelsReader":
provider = LabelsReader

if self.centroid_config is not None:
max_stride = (
Expand Down Expand Up @@ -660,7 +661,7 @@ def make_pipeline(

else:
raise Exception(
"Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider"
"Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider"
)

def _make_labeled_frames_from_generator(
Expand Down Expand Up @@ -882,7 +883,7 @@ def make_pipeline(
Args:
provider: (str) Provider class to read the input sleap files.
Either "LabelReader" or "VideoReader".
Either "LabelsReader" or "VideoReader".
data_path: (str) Path to `.slp` file or `.mp4` to run inference on.
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
video_start_idx: (int) Start index of the frames to read. Default: None.
Expand All @@ -895,9 +896,9 @@ def make_pipeline(
"""
self.provider = provider

# LabelReader provider
if self.provider == "LabelReader":
provider = LabelReader
# LabelsReader provider
if self.provider == "LabelsReader":
provider = LabelsReader

max_stride = self.confmap_config.model_config.backbone_config.max_stride

Expand Down Expand Up @@ -938,7 +939,7 @@ def make_pipeline(

else:
raise Exception(
"Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider"
"Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider"
)

def _make_labeled_frames_from_generator(
Expand Down Expand Up @@ -1191,7 +1192,7 @@ def make_pipeline(
Args:
provider: (str) Provider class to read the input sleap files.
Either "LabelReader" or "VideoReader".
Either "LabelsReader" or "VideoReader".
data_path: (str) Path to `.slp` file or `.mp4` to run inference on.
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
video_start_idx: (int) Start index of the frames to read. Default: None.
Expand All @@ -1202,9 +1203,9 @@ def make_pipeline(
Thread is started in Predictor._predict_generator() method.
"""
self.provider = provider
# LabelReader provider
if self.provider == "LabelReader":
provider = LabelReader
# LabelsReader provider
if self.provider == "LabelsReader":
provider = LabelsReader

max_stride = self.bottomup_config.model_config.backbone_config.max_stride

Expand Down Expand Up @@ -1245,7 +1246,7 @@ def make_pipeline(

else:
raise Exception(
"Provider not recognised. Please use either `LabelReader` or `VideoReader` as provider"
"Provider not recognised. Please use either `LabelsReader` or `VideoReader` as provider"
)

def _make_labeled_frames_from_generator(
Expand Down Expand Up @@ -1346,7 +1347,7 @@ def main(
max_width: int = None,
max_height: int = None,
is_rgb: bool = False,
provider: str = "LabelReader",
provider: str = "LabelsReader",
batch_size: int = 4,
queue_maxsize: int = 8,
videoreader_start_idx: Optional[int] = None,
Expand Down Expand Up @@ -1397,7 +1398,7 @@ def main(
is set to False, then we convert the image to grayscale (single-channel)
image. Default: False.
provider: (str) Provider class to read the input sleap files.
Either "LabelReader" or "VideoReader". Default: LabelReader.
Either "LabelsReader" or "VideoReader". Default: LabelsReader.
batch_size: (int) Number of samples per batch. Default: 4.
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
videoreader_start_idx: (int) Start index of the frames to read. Default: None.
Expand Down
4 changes: 2 additions & 2 deletions sleap_nn/inference/topdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CentroidCrop(L.LightningModule):

def __init__(
self,
torch_model: L.LightningModule = None,
torch_model: Optional[L.LightningModule] = None,
output_stride: int = 1,
peak_threshold: float = 0.0,
max_instances: Optional[int] = None,
Expand All @@ -70,7 +70,7 @@ def __init__(
input_scale: float = 1.0,
max_stride: int = 1,
use_gt_centroids: bool = False,
anchor_ind: int = None,
anchor_ind: Optional[int] = None,
**kwargs,
):
"""Initialise the model attributes."""
Expand Down
4 changes: 0 additions & 4 deletions sleap_nn/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
import shutil
import torch
import sleap_io as sio
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
import lightning as L
import litdata as ld
from sleap_nn.data.providers import LabelsReader
from sleap_nn.data.pipelines import (
TopdownConfmapsPipeline,
SingleInstanceConfmapsPipeline,
Expand Down Expand Up @@ -145,8 +143,6 @@ def _get_data_chunks(self, func, train_labels, val_labels):
def _create_data_loaders(self):
"""Create a DataLoader for train, validation and test sets using the data_config."""
self.provider = self.config.data_config.provider
if self.provider == "LabelsReader":
self.provider = LabelsReader

train_labels = sio.load_slp(self.config.data_config.train_labels_path)
val_labels = sio.load_slp(self.config.data_config.val_labels_path)
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from sleap_nn.data.normalization import apply_normalization
from sleap_nn.data.providers import process_lf
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.providers import LabelsReader
from sleap_nn.data.providers import LabelsReaderDP


def test_uniform_noise(minimal_instance):
"""Test RandomUniformNoise module."""
p = LabelsReader.from_filename(minimal_instance)
p = LabelsReaderDP.from_filename(minimal_instance)
p = Normalizer(p)

sample = next(iter(p))
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_apply_geometric_augmentation(minimal_instance):

def test_kornia_augmentation(minimal_instance):
"""Test KorniaAugmenter module."""
p = LabelsReader.from_filename(minimal_instance)
p = LabelsReaderDP.from_filename(minimal_instance)

p = Normalizer(p)
p = KorniaAugmenter(
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_kornia_augmentation(minimal_instance):
assert pts.shape == (1, 2, 2, 2)

# Test RandomCrop value error.
p = LabelsReader.from_filename(minimal_instance)
p = LabelsReaderDP.from_filename(minimal_instance)
p = Normalizer(p)
with pytest.raises(
ValueError, match="crop_hw height and width must be greater than 0."
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_confmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from sleap_nn.data.instance_cropping import InstanceCropper
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.resizing import Resizer
from sleap_nn.data.providers import LabelsReader, process_lf
from sleap_nn.data.providers import LabelsReaderDP, process_lf
from sleap_nn.data.utils import make_grid_vectors
import numpy as np


def test_confmaps(minimal_instance):
"""Test ConfidenceMapGenerator module."""
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = LabelsReaderDP.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
datapipe = Normalizer(datapipe)
datapipe = InstanceCropper(datapipe, (100, 100))
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_confmaps(minimal_instance):
def test_multi_confmaps(minimal_instance):
"""Test MultiConfidenceMapGenerator module."""
# centroids = True
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = LabelsReaderDP.from_filename(minimal_instance)
datapipe = Normalizer(datapipe)
datapipe = InstanceCentroidFinder(datapipe)
datapipe1 = MultiConfidenceMapGenerator(
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_multi_confmaps(minimal_instance):
torch.testing.assert_close(gt, cms[0][0], atol=0.001, rtol=0.0)

# centroids = False (for instances)
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = LabelsReaderDP.from_filename(minimal_instance)
datapipe = Normalizer(datapipe)
datapipe = Resizer(datapipe, scale=2)
datapipe = InstanceCentroidFinder(datapipe)
Expand Down
4 changes: 2 additions & 2 deletions tests/data/test_edge_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import sleap_io as sio
from sleap_nn.data.utils import make_grid_vectors
from sleap_nn.data.providers import LabelsReader, process_lf
from sleap_nn.data.providers import LabelsReaderDP, process_lf
from sleap_nn.data.edge_maps import (
distance_to_edge,
make_edge_maps,
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_generate_pafs(minimal_instance):


def test_part_affinity_fields_generator(minimal_instance):
provider = LabelsReader.from_filename(minimal_instance)
provider = LabelsReaderDP.from_filename(minimal_instance)
paf_generator = PartAffinityFieldsGenerator(
provider,
sigma=8,
Expand Down
4 changes: 2 additions & 2 deletions tests/data/test_instance_centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
InstanceCentroidFinder,
generate_centroids,
)
from sleap_nn.data.providers import LabelsReader, process_lf
from sleap_nn.data.providers import LabelsReaderDP, process_lf


def test_generate_centroids(minimal_instance):
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_generate_centroids(minimal_instance):
def test_instance_centroids(minimal_instance):
"""Test InstanceCentroidFinder and generate_centroids functions."""
# Undefined anchor_ind.
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = LabelsReaderDP.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
sample = next(iter(datapipe))
instances = sample["instances"]
Expand Down
4 changes: 2 additions & 2 deletions tests/data/test_instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from sleap_nn.data.normalization import Normalizer, apply_normalization
from sleap_nn.data.resizing import SizeMatcher, Resizer, PadToStride
from sleap_nn.data.providers import LabelsReader, process_lf
from sleap_nn.data.providers import LabelsReaderDP, process_lf


def test_find_instance_crop_size(minimal_instance):
Expand Down Expand Up @@ -44,7 +44,7 @@ def test_make_centered_bboxes():

def test_instance_cropper(minimal_instance):
"""Test InstanceCropper module."""
provider = LabelsReader.from_filename(minimal_instance)
provider = LabelsReaderDP.from_filename(minimal_instance)
provider.max_instances = 3
datapipe = Normalizer(provider)
datapipe = SizeMatcher(datapipe, provider)
Expand Down
Loading

0 comments on commit 80023b3

Please sign in to comment.