Non contiguous clips #109

Merged
merged 7 commits on Feb 8, 2025
96 changes: 96 additions & 0 deletions dreem/datasets/base_dataset.py
@@ -79,7 +79,103 @@ def __init__(
        self.labels = None
        self.gt_list = None

    def process_segments(
        self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
    ) -> None:
        """Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.

        Args:
            i: index of the video
            segments_to_stitch: list of segments to stitch
            clip_length: the number of frames in each chunk

        Returns: None
        """
        stitched_segment = torch.cat(segments_to_stitch)
        frame_idx_split = torch.split(stitched_segment, clip_length)
        self.chunked_frame_idx.extend(frame_idx_split)
        self.label_idx.extend(len(frame_idx_split) * [i])

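As an aside, the stitch-and-split behavior above can be seen on toy values — a standalone sketch with made-up numbers, not part of the PR:

```python
# Standalone sketch of process_segments' stitch-and-split (toy values).
import torch

segments = [torch.arange(0, 5), torch.arange(8, 12)]  # two annotated segments
stitched = torch.cat(segments)     # tensor([0, 1, 2, 3, 4, 8, 9, 10, 11])
chunks = torch.split(stitched, 4)  # clip_length = 4
# -> (tensor([0, 1, 2, 3]), tensor([4, 8, 9, 10]), tensor([11]))
```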
Comment on lines +82 to +97

Contributor

⚠️ Potential issue

Ensure robust handling for dynamic attributes.
The method references self.max_batching_gap, which appears to be defined only in subclasses (e.g., SleapDataset). If another subclass or a direct instance of BaseDataset does not initialize self.max_batching_gap, this will raise an AttributeError. Consider assigning a default value in BaseDataset to avoid runtime issues.

+ # Example fix: Provide a default in BaseDataset if you expect each subclass to possibly override it
+ self.max_batching_gap = getattr(self, "max_batching_gap", 15)
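Equivalently, the default could live in the base class constructor — a sketch only, with the signature abbreviated to the one relevant parameter:

```python
import torch


class BaseDataset(torch.utils.data.Dataset):
    """Sketch: define the default at the base level so every subclass inherits it."""

    def __init__(self, max_batching_gap: int = 15, **kwargs):
        super().__init__()
        # Subclasses that never pass the argument still get a defined value.
        self.max_batching_gap = max_batching_gap
```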

    def create_chunks(self) -> None:
        """Factory method to create chunks."""
        if type(self).__name__ == "SleapDataset":
            self.create_chunks_slp()
        else:
            self.create_chunks_other()

Comment on lines 98 to +104

Contributor

🛠️ Refactor suggestion

Replace type checking with polymorphism.

Using type(self).__name__ for routing is fragile and violates the Open/Closed Principle. Consider making create_chunks an abstract method and implementing it in each subclass.

-    def create_chunks(self) -> None:
-        """Factory method to create chunks."""
-        if type(self).__name__ == "SleapDataset":
-            self.create_chunks_slp()
-        else:
-            self.create_chunks_other()
+    @abstractmethod
+    def create_chunks(self) -> None:
+        """Create chunks for the dataset.
+        
+        This method should be implemented by subclasses to define their specific
+        chunking strategy.
+        """
+        pass

    def create_chunks_slp(self) -> None:
        """Get indexing for data.

        Creates both indexes for selecting dataset (label_idx) and frame in
        dataset (chunked_frame_idx). If chunking is false, we index directly
        using the frame ids. Setting chunking to true creates a list of lists
        containing chunk frames for indexing. This is useful for computational
        efficiency and data shuffling. To be called by subclass __init__()
        """
        self.chunked_frame_idx, self.label_idx = [], []
        # go through each slp file and create chunks that respect max_batching_gap
        for i, slp_file in enumerate(self.label_files):
            annotated_segments = self.annotated_segments[slp_file]
            segments_to_stitch = []
            prev_end = annotated_segments[0][1]  # end of first segment
            for start, end in annotated_segments:
                # check if the start of the current segment is within max_batching_gap of the end of the previous one
                if (
                    int(start) - int(prev_end) < self.max_batching_gap
                ) or not self.chunk:  # also takes care of first segment as start < prev_end
                    segments_to_stitch.append(torch.arange(start, end + 1))
                    prev_end = end
                else:
                    # stitch previous set of segments before creating a new chunk
                    self.process_segments(i, segments_to_stitch, self.clip_length)
                    # reset segments_to_stitch as we are starting a new chunk
                    segments_to_stitch = [torch.arange(start, end + 1)]
                    prev_end = end

            if not self.chunk:
                self.process_segments(
                    i, segments_to_stitch, self.labels[i].video.shape[0]
                )
            else:
                # add last chunk after the loop
                if segments_to_stitch:
                    self.process_segments(i, segments_to_stitch, self.clip_length)

        if self.n_chunks > 0 and self.n_chunks <= 1.0:
            n_chunks = int(self.n_chunks * len(self.chunked_frame_idx))

        elif self.n_chunks <= len(self.chunked_frame_idx):
            n_chunks = int(self.n_chunks)

        else:
            n_chunks = len(self.chunked_frame_idx)

        if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx):
            sample_idx = np.random.choice(
                np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
            )

            self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]

            self.label_idx = [self.label_idx[i] for i in sample_idx]

        # workaround for empty batch bug (needs to be changed). Check for batches with only ~1/10 of clip_length frames. Arbitrary threshold.
        remove_idx = []
        for i, frame_chunk in enumerate(self.chunked_frame_idx):
            if (
                len(frame_chunk)
                <= min(int(self.clip_length / 10), 5)
                # and frame_chunk[-1] % self.clip_length == 0
            ):
                logger.warning(
                    f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
                )
                remove_idx.append(i)
        if len(remove_idx) > 0:
            for i in sorted(remove_idx, reverse=True):
                self.chunked_frame_idx.pop(i)
                self.label_idx.pop(i)
Comment on lines +162 to +176

Contributor

🛠️ Refactor suggestion

Revisit “empty batch” removal threshold.
The heuristic (dropping chunks with ≤ min(clip_length/10, 5) frames) might prematurely discard valid shorter sequences. Confirm this logic aligns with the project’s “Non contiguous clips” objectives, which support sparse annotation. If partial clips are acceptable, consider a smaller threshold or a user-configurable parameter.
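One possible shape for that knob — a sketch assuming a hypothetical `min_chunk_size` constructor argument, which is not part of this PR:

```python
# Hypothetical: expose the small-chunk floor instead of hardcoding it.
def _chunk_size_floor(self) -> int:
    """Return the minimum chunk length below which chunks are dropped."""
    if getattr(self, "min_chunk_size", None) is not None:
        return self.min_chunk_size  # user-configured floor
    return min(self.clip_length // 10, 5)  # current heuristic as fallback
```

The removal loop would then compare against `self._chunk_size_floor()` instead of the inline expression.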

Comment on lines +143 to +176

Contributor

🛠️ Refactor suggestion

Extract duplicated chunk processing logic.

The n_chunks calculation and small batch removal logic is duplicated between create_chunks_slp and create_chunks_other. Consider extracting these into helper methods.

+    def _calculate_n_chunks(self, total_chunks: int) -> int:
+        """Calculate the number of chunks to use based on self.n_chunks."""
+        if self.n_chunks > 0 and self.n_chunks <= 1.0:
+            return int(self.n_chunks * total_chunks)
+        elif self.n_chunks <= total_chunks:
+            return int(self.n_chunks)
+        return total_chunks
+
+    def _remove_small_chunks(self) -> None:
+        """Remove chunks that are too small to avoid empty batch issues."""
+        remove_idx = []
+        for i, frame_chunk in enumerate(self.chunked_frame_idx):
+            if len(frame_chunk) <= min(int(self.clip_length / 10), 5):
+                logger.warning(
+                    f"Warning: Batch containing frames {frame_chunk} from video "
+                    f"{self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} "
+                    "frames. Removing to avoid empty batch possibility."
+                )
+                remove_idx.append(i)
+        for i in sorted(remove_idx, reverse=True):
+            self.chunked_frame_idx.pop(i)
+            self.label_idx.pop(i)

Then use these helper methods in both create_chunks_slp and create_chunks_other.

Also applies to: 194-227
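
For instance, both methods could end with a single shared call — a sketch using the helper names suggested above (`_finalize_chunks` is a hypothetical name):

```python
def _finalize_chunks(self) -> None:
    """Sketch: shared tail for create_chunks_slp and create_chunks_other."""
    n_chunks = self._calculate_n_chunks(len(self.chunked_frame_idx))
    if 0 < n_chunks < len(self.chunked_frame_idx):
        sample_idx = np.random.choice(
            np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
        )
        self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]
        self.label_idx = [self.label_idx[i] for i in sample_idx]
    self._remove_small_chunks()
```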


    def create_chunks_other(self) -> None:
        """Get indexing for data.

        Creates both indexes for selecting dataset (label_idx) and frame in
74 changes: 74 additions & 0 deletions dreem/datasets/data_utils.py
@@ -11,6 +11,80 @@
import pandas as pd
import sleap_io as sio
import torch
from sleap_io.io.slp import (
    read_hdf5_attrs,
    read_tracks,
    read_videos,
    read_skeletons,
    read_points,
    read_pred_points,
    read_instances,
    read_metadata,
    read_hdf5_dataset,
)
from sleap_io import Labels, LabeledFrame


def load_slp(
    labels_path: str, open_videos: bool = True
) -> tuple[Labels, list[tuple[int, int]]]:
    """Read a SLEAP labels file.

    Args:
        labels_path: A string path to the SLEAP labels file.
        open_videos: If `True` (the default), attempt to open the video backend for
            I/O. If `False`, the backend will not be opened (useful for reading
            metadata when the video files are not available).

    Returns:
        A tuple of the processed `Labels` object and a list of `(start, end)`
        tuples marking the contiguous annotated segments in the file.
    """
    tracks = read_tracks(labels_path)
    videos = read_videos(labels_path, open_backend=open_videos)
    skeletons = read_skeletons(labels_path)
    points = read_points(labels_path)
    pred_points = read_pred_points(labels_path)
    format_id = read_hdf5_attrs(labels_path, "metadata", "format_id")
    instances = read_instances(
        labels_path, skeletons, tracks, points, pred_points, format_id
    )
    metadata = read_metadata(labels_path)
    provenance = metadata.get("provenance", dict())

    frames = read_hdf5_dataset(labels_path, "frames")
    labeled_frames = []
    annotated_segments = []
    curr_segment_start = frames[0][2]
    curr_frame = curr_segment_start
    # note that frames only contains frames with labelled instances, not all frames
    for i, video_id, frame_idx, instance_id_start, instance_id_end in frames:
        labeled_frames.append(
            LabeledFrame(
                video=videos[video_id],
                frame_idx=int(frame_idx),
                instances=instances[instance_id_start:instance_id_end],
            )
        )
        if frame_idx == curr_frame:
            pass
        elif frame_idx == curr_frame + 1:
            curr_frame = frame_idx
        elif frame_idx > curr_frame + 1:
            annotated_segments.append((curr_segment_start, curr_frame))
            curr_segment_start = frame_idx
            curr_frame = frame_idx

    # add last segment
    annotated_segments.append((curr_segment_start, curr_frame))

    labels = Labels(
        labeled_frames=labeled_frames,
        videos=videos,
        skeletons=skeletons,
        tracks=tracks,
        provenance=provenance,
    )
    labels.provenance["filename"] = labels_path

    return labels, annotated_segments


def pad_bbox(bbox: ArrayLike, padding: int = 16) -> torch.Tensor:
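For intuition, the segment-detection loop in `load_slp` can be isolated as a standalone helper — a sketch with a hypothetical name, not part of the PR:

```python
# Sketch: collapse sorted labelled-frame indices into (start, end) runs.
def find_segments(frame_idxs: list[int]) -> list[tuple[int, int]]:
    segments = []
    start = prev = frame_idxs[0]
    for idx in frame_idxs[1:]:
        if idx > prev + 1:  # gap found: close the current run
            segments.append((start, prev))
            start = idx
        prev = idx
    segments.append((start, prev))  # close the final run
    return segments

# find_segments([0, 1, 2, 7, 8, 12]) -> [(0, 2), (7, 8), (12, 12)]
```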
52 changes: 24 additions & 28 deletions dreem/datasets/sleap_dataset.py
@@ -36,6 +36,7 @@ def __init__(
        seed: int | None = None,
        verbose: bool = False,
        normalize_image: bool = True,
        max_batching_gap: int = 15,
    ):
        """Initialize SleapDataset.

@@ -73,6 +74,7 @@ def __init__(
            seed: set a seed for reproducibility
            verbose: boolean representing whether to print
            normalize_image: whether to normalize the image to [0, 1]
            max_batching_gap: the max number of frames that can be unlabelled before starting a new batch
        """
        super().__init__(
            slp_files,
@@ -101,6 +103,7 @@ def __init__(
        self.n_chunks = n_chunks
        self.seed = seed
        self.normalize_image = normalize_image
        self.max_batching_gap = max_batching_gap
        if self.data_dirs is None:
            self.data_dirs = []
        if isinstance(anchors, int):
@@ -136,20 +139,18 @@ def __init__(

# if self.seed is not None:
# np.random.seed(self.seed)
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files]

# load_slp is a wrapper around sio.load_slp for frame gap checks
self.labels = []
self.annotated_segments = {}
for slp_file in self.slp_files:
labels, annotated_segments = data_utils.load_slp(slp_file)
self.labels.append(labels)
self.annotated_segments[slp_file] = annotated_segments

self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files]
# do we need this? would need to update with sleap-io

# for label in self.labels:
# label.remove_empty_instances(keep_empty_frames=False)

# note if slp is missing frames, taking last frame idx is safer than len(labels)
# as there will be fewer labeledframes than actual frames
self.frame_idx = [
torch.arange(labels[-1].frame_idx + 1) for labels in self.labels
]
self.skipped_frame_ct = [0 for labels in self.labels]
# self.frame_idx = [torch.arange(len(labels)) for labels in self.labels]
# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
# used in call to get_instances()
self.create_chunks()
@@ -162,18 +163,20 @@ def get_indices(self, idx: int) -> tuple:
"""
return self.label_idx[idx], self.chunked_frame_idx[idx]

def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:
def get_instances(
self, label_idx: list[int], frame_idx: torch.Tensor
) -> list[Frame]:
Comment on lines +166 to +168

Contributor

🛠️ Refactor suggestion

Add input validation for frame_idx tensor.

The parameter type change from list[int] to torch.Tensor requires validation to ensure correct tensor properties.

 def get_instances(
     self, label_idx: list[int], frame_idx: torch.Tensor
 ) -> list[Frame]:
+    if not isinstance(frame_idx, torch.Tensor):
+        raise TypeError("frame_idx must be a torch.Tensor")
+    if frame_idx.dtype not in [torch.int32, torch.int64]:
+        raise TypeError("frame_idx must contain integer values")

"""Get an element of the dataset.

Args:
label_idx: index of the labels
frame_idx: index of the frames
frame_idx: indices of the frames to load in to the batch

Returns:
A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip.

"""
video = self.labels[label_idx]
sleap_labels_obj = self.labels[label_idx]
video_name = self.video_files[label_idx]

# get the correct crop size based on the video
@@ -189,9 +192,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:

        vid_reader = self.videos[label_idx]

-        # img = vid_reader.get_data(0)
-
-        skeleton = video.skeletons[-1]
+        skeleton = sleap_labels_obj.skeletons[-1]

        frames = []
        for i, frame_ind in enumerate(frame_idx):
@@ -206,18 +207,13 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:

            frame_ind = int(frame_ind)

-            # if slp is missing instances in some frames, frame_ind will be smaller than lf.frame_idx
-            lf = video[frame_ind - self.skipped_frame_ct[label_idx]]
-            if frame_ind < lf.frame_idx:
-                logger.warning(
-                    f"Frame index {frame_ind} is trying to access frame {lf.frame_idx} of the slp file {video_name}. "
-                    f"This likely means there are no labelled instances in this frame. Skipping frame."
-                )
-                self.skipped_frame_ct[label_idx] += 1
-                continue
+            # sleap-io method for indexing a Labels() object based on the frame's index
+            lf = sleap_labels_obj[(sleap_labels_obj.video, frame_ind)]
+            if frame_ind != lf.frame_idx:
+                logger.warning(f"Frame index mismatch: {frame_ind} != {lf.frame_idx}")

            try:
-                img = vid_reader.get_data(int(lf.frame_idx))
+                img = vid_reader.get_data(int(frame_ind))
            except IndexError as e:
                logger.warning(
                    f"Could not read frame {frame_ind} from {video_name} due to {e}"
@@ -253,7 +249,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:
                continue

            if instance.track is not None:
-                gt_track_id = video.tracks.index(instance.track)
+                gt_track_id = sleap_labels_obj.tracks.index(instance.track)
            else:
                gt_track_id = -1
            gt_track_ids.append(gt_track_id)
34 changes: 34 additions & 0 deletions tests/test_datasets.py
@@ -205,13 +205,47 @@ def test_sleap_dataset(two_flies):
        n_chunks=ds_length + 10000,
    )

    train_ds = SleapDataset(
        slp_files=[two_flies[0]],
        video_files=[two_flies[1]],
        data_dirs="./data/sleap",
        crop_size=128,
        chunk=True,
        max_batching_gap=10,
        clip_length=clip_length,
        n_chunks=30,
    )

    train_ds = SleapDataset(
        slp_files=[two_flies[0]],
        video_files=[two_flies[1]],
        data_dirs="./data/sleap",
        crop_size=128,
        chunk=True,
        max_batching_gap=0,
        clip_length=clip_length,
        n_chunks=30,
    )

    train_ds = SleapDataset(
        slp_files=[two_flies[0]],
        video_files=[two_flies[1]],
        data_dirs="./data/sleap",
        crop_size=128,
        chunk=False,
        max_batching_gap=10,
        clip_length=clip_length,
        n_chunks=30,
    )
Comment on lines +208 to +239

Contributor

🛠️ Refactor suggestion

Add assertions to verify non-contiguous segments behavior.

The test cases for the max_batching_gap parameter are missing assertions to verify that:

  1. The segments are correctly split when the gap exceeds max_batching_gap.
  2. The segments are correctly merged when the gap is within max_batching_gap.
  3. The non-chunked case correctly preserves all frames.

Do you want me to generate the assertions to verify the behavior of non-contiguous segments?
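A possible shape for those assertions — a sketch only; it assumes the `two_flies` fixture contains annotated segments separated by gaps wider than the configured `max_batching_gap`, and that `torch` is imported in the test module:

```python
# After the chunk=True, max_batching_gap=10 construction: no chunk should
# jump across a gap of 10 or more frames (stitching requires gap < 10).
for chunk in train_ds.chunked_frame_idx:
    assert torch.all(chunk[1:] - chunk[:-1] < 10)

# After the chunk=True, max_batching_gap=0 construction: every chunk should
# be a single contiguous run of frames.
for chunk in train_ds.chunked_frame_idx:
    assert torch.all(chunk[1:] - chunk[:-1] == 1)

# After the chunk=False construction: all annotated frames should be preserved.
total = sum(len(c) for c in train_ds.chunked_frame_idx)
expected = sum(
    end - start + 1
    for segs in train_ds.annotated_segments.values()
    for start, end in segs
)
assert total == expected
```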



def test_icy_dataset(ten_icy_particles):
    """Test icy dataset logic.

    Args:
        ten_icy_particles: icy fixture used for testing
    """
    clip_length = 8

    train_ds = MicroscopyDataset(
2 changes: 1 addition & 1 deletion tests/test_inference.py
@@ -279,7 +279,7 @@ def __getitem__(self, idx):

     dl = torch.utils.data.DataLoader(DummyDataset())
     model = GTRRunner()
-    trainer = Trainer(max_steps=1, min_steps=1)
+    trainer = Trainer(max_steps=1, min_steps=1, accelerator="cpu")
     trainer.fit(model, dl)
     trainer.save_checkpoint(ckpt_path)

Expand Down
Loading