Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove IterDataPipe from Inference pipeline #96

Merged
merged 11 commits into from
Oct 3, 2024
4 changes: 2 additions & 2 deletions sleap_nn/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def make_training_pipeline(

Args:
data_provider: A `Provider` that generates data examples, typically a
`LabelsReader` instance.
`LabelsReaderDP` instance.
use_augmentations: `True` if augmentations should be applied to the training
pipeline, else `False`. Default: `False`.

Expand Down Expand Up @@ -353,7 +353,7 @@ def make_training_pipeline(

Args:
data_provider: A `Provider` that generates data examples, typically a
`LabelsReader` instance.
`LabelsReaderDP` instance.
use_augmentations: `True` if augmentations should be applied to the training
pipeline, else `False`. Default: `False`.

Expand Down
143 changes: 129 additions & 14 deletions sleap_nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def process_lf(
return ex


class LabelsReader(IterDataPipe):
class LabelsReaderDP(IterDataPipe):
"""IterDataPipe for reading frames from Labels object.

This IterDataPipe will produce examples containing a frame and an sleap_io.Instance
Expand Down Expand Up @@ -152,7 +152,7 @@ def from_filename(
user_instances_only: bool = True,
instances_key: bool = True,
):
"""Create LabelsReader from a .slp filename."""
"""Create LabelsReaderDP from a .slp filename."""
labels = sio.load_slp(filename)
return cls(labels, user_instances_only, instances_key)

Expand Down Expand Up @@ -205,13 +205,13 @@ class VideoReader(Thread):
"""Thread module for reading frames from sleap-io Video object.

This module will load the frames from video and pushes them as Tensors into a buffer
queue as a tuple in the format (image, frame index, (height, width)) which are then
batched and consumed during the inference process.
queue as a dictionary with (image, frame index, video index, (height, width))
which are then batched and consumed during the inference process.

Attributes:
video: sleap_io.Video object that contains LabeledFrames that will be
video: sleap_io.Video object that contains images that will be
accessed through a torchdata DataPipe.
frame_buffer: Maximum size of the frame buffer queue.
frame_buffer: Frame buffer queue.
start_idx: start index of the frames to read. If None, 0 is set as the default.
end_idx: end index of the frames to read. If None, length of the video is set as
the default.
Expand Down Expand Up @@ -248,12 +248,13 @@ def max_height_and_width(self) -> Tuple[int, int]:
def from_filename(
cls,
filename: str,
frame_buffer: Queue,
queue_maxsize: int,
start_idx: Optional[int] = None,
end_idx: Optional[int] = None,
):
"""Create LabelsReader from a .slp filename."""
"""Create VideoReader from a .slp filename."""
video = sio.load_video(filename)
frame_buffer = Queue(maxsize=queue_maxsize)
return cls(video, frame_buffer, start_idx, end_idx)

def run(self):
Expand All @@ -265,15 +266,129 @@ def run(self):
img = np.expand_dims(img, axis=0) # (1, C, H, W)

self.frame_buffer.put(
(
torch.from_numpy(img),
torch.tensor(idx, dtype=torch.int32),
torch.Tensor(img.shape[-2:]),
)
{
"image": torch.from_numpy(img),
"frame_idx": torch.tensor(idx, dtype=torch.int32),
"video_idx": torch.tensor(0, dtype=torch.int32),
"orig_size": torch.Tensor(img.shape[-2:]),
}
)

except Exception as e:
print(f"Error when reading video frame. Stopping video reader.\n{e}")

finally:
self.frame_buffer.put((None, None, None))
self.frame_buffer.put(
{
"image": None,
"frame_idx": None,
"video_idx": None,
"orig_size": None,
}
)


class LabelsReader(Thread):
    """Thread module for reading images from a sleap-io Labels object.

    When the thread runs, every labeled frame in `labels` is read in order and
    pushed into `frame_buffer` as a dictionary with the keys `"image"`,
    `"frame_idx"`, `"video_idx"` and `"orig_size"` (plus `"instances"` when
    `instances_key` is `True`). After the last frame, or after an error, a
    sentinel dictionary whose values are all `None` is pushed so that the
    consumer knows no more samples will arrive.

    Attributes:
        labels: sleap_io.Labels object that contains the LabeledFrames to read.
        frame_buffer: Frame buffer queue that receives the samples.
        instances_key: If `True`, then instances are appended to the output
            dictionary.
        max_instances: Value returned by `get_max_instances(labels)`; every
            `"instances"` tensor is NaN-padded up to this many instances.
    """

    def __init__(
        self, labels: "sio.Labels", frame_buffer: Queue, instances_key: bool = False
    ):
        """Initialize attribute of the class."""
        super().__init__()
        self.labels = labels
        self.frame_buffer = frame_buffer
        self.instances_key = instances_key
        self.max_instances = get_max_instances(self.labels)

    def total_len(self):
        """Returns the total number of labeled frames in `labels`."""
        return len(self.labels)

    @property
    def max_height_and_width(self) -> Tuple[int, int]:
        """Return the largest `shape[1]` and `shape[2]` over all videos.

        NOTE(review): assumes `video.shape` is (frames, height, width,
        channels) so that the result is `(height, width)` -- confirm against
        sleap-io.
        """
        return max(video.shape[1] for video in self.labels.videos), max(
            video.shape[2] for video in self.labels.videos
        )

    @classmethod
    def from_filename(
        cls, filename: str, queue_maxsize: int, instances_key: bool = False
    ):
        """Create LabelsReader from a .slp filename."""
        labels = sio.load_slp(filename)
        frame_buffer = Queue(maxsize=queue_maxsize)
        return cls(labels, frame_buffer, instances_key)

    def _read_instances(self, lf) -> "torch.Tensor":
        """Return the non-empty instances of `lf`, NaN-padded to `max_instances`.

        The result is a float32 tensor of shape
        (n_samples=1, max_instances, num_nodes, 2).
        """
        points = [inst.numpy() for inst in lf if not inst.is_empty]

        if points:
            stacked = np.stack(points, axis=0)  # (num_instances, num_nodes, 2)
        else:
            # `np.stack` raises on an empty list, which used to abort the whole
            # reader at the first frame without usable instances. Build an
            # empty array instead and let the NaN padding below fill it.
            num_nodes = len(self.labels.skeletons[0].nodes)
            stacked = np.empty((0, num_nodes, 2), dtype="float32")

        # Add singleton time dimension for single frames.
        stacked = np.expand_dims(
            stacked, axis=0
        )  # (n_samples=1, num_instances, num_nodes, 2)
        instances = torch.from_numpy(stacked.astype("float32"))

        num_instances, num_nodes = instances.shape[1:3]

        # Append NaNs so that every sample has the same number of instances
        # and can be batched. The count is clamped at zero so a frame is never
        # padded when it already has `max_instances` (or more) instances.
        num_missing = max(self.max_instances - num_instances, 0)
        if num_missing > 0:
            nans = torch.full((1, num_missing, num_nodes, 2), torch.nan)
            instances = torch.cat(
                [instances, nans], dim=1
            )  # (n_samples, max_instances, num_nodes, 2)

        return instances

    def run(self):
        """Adds frames to the buffer queue."""
        try:
            for idx in range(self.total_len()):
                lf = self.labels[idx]
                img = lf.image
                img = np.transpose(img, (2, 0, 1))  # convert H,W,C to C,H,W
                img = np.expand_dims(img, axis=0)  # (1, C, H, W)

                sample = {
                    "image": torch.from_numpy(img),
                    "frame_idx": torch.tensor(lf.frame_idx, dtype=torch.int32),
                    "video_idx": torch.tensor(
                        self.labels.videos.index(lf.video), dtype=torch.int32
                    ),
                    "orig_size": torch.Tensor(img.shape[-2:]),
                }

                if self.instances_key:
                    sample["instances"] = self._read_instances(lf)

                self.frame_buffer.put(sample)

        except Exception as e:
            # Deliberately broad: an exception escaping a thread's `run` would
            # never reach the consumer, so report it and fall through to the
            # sentinel below.
            print(f"Error when reading labelled frame. Stopping labels reader.\n{e}")

        finally:
            # Sentinel telling the consumer that no more samples will arrive.
            self.frame_buffer.put(
                {
                    "image": None,
                    "frame_idx": None,
                    "video_idx": None,
                    "orig_size": None,
                }
            )
4 changes: 2 additions & 2 deletions sleap_nn/data/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import torch.nn.functional as F
from sleap_nn.data.providers import LabelsReader, VideoReader
from sleap_nn.data.providers import LabelsReaderDP, VideoReader
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
import torchvision.transforms.v2.functional as tvf
from torch.utils.data.datapipes.datapipe import IterDataPipe

Expand Down Expand Up @@ -230,7 +230,7 @@ class SizeMatcher(IterDataPipe):
def __init__(
self,
source_datapipe: IterDataPipe,
provider: Optional[Union[LabelsReader, VideoReader]] = None,
provider: Optional[Union[LabelsReaderDP, VideoReader]] = None,
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
max_height: Optional[int] = None,
max_width: Optional[int] = None,
):
Expand Down
Loading
Loading