
Non contiguous clips #109

Merged: 7 commits, Feb 8, 2025
Conversation

@shaikh58 (Contributor) commented Feb 7, 2025

This PR adds support for non-contiguous annotated segments in training data. Currently the pipeline expects the entire file to be annotated (except possibly a handful of missing frames) and does not support sparse annotation. This PR introduces that support, along with a redesign of chunking and consistent indexing.

NOTE: Only SleapDataset is supported in this fix. The other Dataset types are not currently in use; support for them will need to be added eventually, and test cases for those types will fail.

Highlights:

  • Training files can have any number of frames annotated, with gaps of any length between segments of annotated frames
  • A new max_batching_gap config parameter controls how many consecutive unannotated frames are allowed before a new batch is started; the default of 15 is set in the SleapDataset initializer
  • Chunking redesign: the create_chunks() method now makes all decisions about which frames enter a batch, while SleapDataset only has to load each image. It returns a complete list of batches whose frame ids correspond to the actual LabeledFrame ids, respecting max_batching_gap. There is no longer an enumerated frame_id based on video length, so missing frames no longer need to be tracked (which makes a previous bug fix redundant)
  • Wraps the sleap-io loader to return the start/end indices of the annotated frames
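The stitch-then-split behavior described above can be sketched as a small standalone helper. This is a hypothetical illustration of the logic, not the actual create_chunks() implementation (which operates on torch tensors):

```python
def build_batches(segments, clip_length, max_batching_gap=15):
    """Stitch annotated segments whose gaps are within max_batching_gap,
    then split each stitched run into clips of at most clip_length frames.

    segments: sorted list of (start, end) inclusive frame-index ranges.
    Returns a list of batches, each a list of frame indices.
    """
    batches = []
    current = []  # frame indices of the run being stitched
    prev_end = None
    for start, end in segments:
        if prev_end is not None and start - prev_end > max_batching_gap:
            # gap too large: flush the current run into clips
            batches.extend(current[i:i + clip_length]
                           for i in range(0, len(current), clip_length))
            current = []
        current.extend(range(start, end + 1))
        prev_end = end
    if current:
        batches.extend(current[i:i + clip_length]
                       for i in range(0, len(current), clip_length))
    return batches
```

For example, with segments (0, 9), (12, 19), (50, 59), clip_length=8, and the default gap of 15, the first two segments stitch together (gap of 3) while the third starts a new run (gap of 31), yielding five batches in total.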

Summary by CodeRabbit

  • New Features

    • Enhanced video dataset creation with flexible stitching of video segments for smoother analysis.
    • Improved support for SLEAP label files enabling more effective annotation integration.
    • Added a configurable batching gap to optimize data processing and instance retrieval.
    • Updated training initialization to enforce CPU-based processing during inference.
  • Tests

    • Expanded testing scenarios to validate new dataset parameters and CPU-driven inference setups.

…mes that have labels

- implemented a segment stitcher in BaseDataset that chunks according to max_batching_gap
- refactored SleapDataset and BaseDataset to use frame_idx from LabeledFrame objects rather than an enumerated index created in SleapDataset, avoiding indexing-related bugs
- the chunker now makes the full decision on which frames are in a batch
- cleaned up index-related code in SleapDataset, including skip-frame counts

…5 in SleapDataset init

- made max_batching_gap part of the train/val/test dataset configs separately rather than a single shared value; this allows more flexibility and cleaner config handling
- renamed variables
- linted

…alse, just stitch all the segments into a single batch

- fixed a multi-GPU test-case failure by forcing the test to CPU
- tested with chunking disabled on a large dataset
coderabbitai bot (Contributor) commented Feb 7, 2025

Walkthrough

This update enhances the dataset processing and testing pipelines. A new method, process_segments, has been added to BaseDataset to stitch video frame segments, with corresponding updates in the create_chunks logic. The data_utils module now includes a load_slp function that parses SLEAP labels and constructs annotated segments. In SleapDataset, a max_batching_gap parameter and updated instance retrieval (get_instances) improve handling of frame indices and label loading. Additionally, tests now cover these changes, and the Trainer initialization has been modified to include accelerator="cpu".

Changes

  • dreem/…/base_dataset.py, dreem/…/sleap_dataset.py — Added the process_segments method and revised create_chunks in BaseDataset; introduced the max_batching_gap parameter and updated the get_instances signature in SleapDataset for improved segment handling.
  • dreem/…/data_utils.py — Introduced the load_slp function with new imports (Labels, LabeledFrame) for parsing SLEAP labels and assembling annotated segments.
  • tests/…/test_datasets.py, tests/…/test_inference.py — Extended SleapDataset test cases to cover max_batching_gap variations; updated Trainer initialization to include accelerator="cpu".

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant SleapDataset
    participant DataUtils
    participant BaseDataset

    Client->>SleapDataset: Initialize dataset (max_batching_gap, clip_length, etc.)
    SleapDataset->>DataUtils: load_slp(labels_path)
    DataUtils-->>SleapDataset: Return (Labels, Annotated Segments)
    SleapDataset->>BaseDataset: create_chunks(...)
    BaseDataset->>BaseDataset: process_segments(i, segments, clip_length)
    Client->>SleapDataset: get_instances(frame_idx Tensor)
    SleapDataset->>SleapDataset: Retrieve instances using updated label access



@shaikh58 linked an issue Feb 7, 2025 that may be closed by this pull request
coderabbitai bot (Contributor) left a comment
Actionable comments posted: 3

🧹 Nitpick comments (7)
dreem/datasets/base_dataset.py (2)

82-97: Consider clarifying docstring.
You might clarify exactly how “stitching” works and whether partial segments shorter than clip_length are allowed or padded, particularly in unusual boundary conditions. This helps new contributors and future maintainers.


107-135: Potential performance concern when using torch.cat repeatedly.
Collecting segments in a list and concatenating them each time could be costly for larger datasets. Consider building a single list/array of frame indices and splitting only once, or at least limit the number of concatenations in performance-critical contexts.
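The single-concatenation alternative suggested here can be sketched in plain Python (illustrative only; the real code works with torch tensors and torch.split):

```python
from itertools import chain

def split_once(segments, clip_length):
    """Flatten all segment index lists once, then slice into clips,
    instead of concatenating repeatedly inside the loop."""
    flat = list(chain.from_iterable(segments))
    return [flat[i:i + clip_length] for i in range(0, len(flat), clip_length)]
```

For example, `split_once([[1, 2, 3], [4, 5]], 2)` returns `[[1, 2], [3, 4], [5]]`, touching the input only once regardless of how many segments are stitched.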

dreem/datasets/sleap_dataset.py (3)

142-149: Handle exceptions from load_slp gracefully.
If load_slp fails (e.g., invalid file, partial data), the code could crash. Adding basic error handling or logging around data_utils.load_slp would improve robustness.

Do you want me to create a follow-up to gracefully catch load errors or corrupt label data?


209-210: Guard against missing frames.
Querying sleap_labels_obj[(sleap_labels_obj.video, frame_ind)] can raise a KeyError if no labels exist for that frame. Consider verifying the presence of frame data prior to indexing to avoid runtime errors.
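A guarded lookup along these lines might look like the following sketch, which uses a plain dict keyed by (video, frame_idx) to stand in for the sleap-io Labels object (names are illustrative, not the actual API):

```python
import logging

logger = logging.getLogger(__name__)

def get_labeled_frame(labels_by_key, video, frame_ind):
    """Return the labeled frame for (video, frame_ind), or None with a
    warning instead of raising KeyError when the frame has no labels."""
    key = (video, frame_ind)
    if key not in labels_by_key:
        logger.warning("No labels for frame %s of %s; skipping.", frame_ind, video)
        return None
    return labels_by_key[key]
```

Callers then skip None results rather than crashing mid-epoch on an unannotated frame.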


213-214: Confirm video length bounds.
Reading frame_ind from vid_reader can fail if frame_ind is out of range. Although you catch an IndexError, you might log additional details (e.g., the maximum valid index). This can simplify debugging.

dreem/datasets/data_utils.py (2)

28-39: Add missing type hints and docstring.

The function is missing:

  1. Type hints for the return values.
  2. Documentation for the annotated_segments return value.

Apply this diff to add the missing type hints and docstring:

-def load_slp(labels_path: str, open_videos: bool = True) -> Labels:
+def load_slp(labels_path: str, open_videos: bool = True) -> tuple[Labels, list[tuple[int, int]]]:
     """Read a SLEAP labels file.
 
     Args:
         labels_path: A string path to the SLEAP labels file.
         open_videos: If `True` (the default), attempt to open the video backend for
             I/O. If `False`, the backend will not be opened (useful for reading metadata
             when the video files are not available).
 
     Returns:
-        The processed `Labels` object.
+        A tuple containing:
+            - The processed `Labels` object.
+            - A list of tuples containing (start_frame, end_frame) for each annotated segment.
     """

58-74: Fix unused loop variable and add validation.

The function has the following issues:

  1. Unused loop variable i in frames loop.
  2. Missing validation for empty frames dataset.

Apply this diff to fix the issues:

+    if len(frames) == 0:
+        return Labels(
+            labeled_frames=[],
+            videos=videos,
+            skeletons=skeletons,
+            tracks=tracks,
+            provenance=provenance,
+        ), []
+
     curr_segment_start = frames[0][2]
     curr_frame = curr_segment_start
     # note that frames only contains frames with labelled instances, not all frames
-    for i, video_id, frame_idx, instance_id_start, instance_id_end in frames:
+    for _i, video_id, frame_idx, instance_id_start, instance_id_end in frames:
🧰 Tools
🪛 Ruff (0.8.2)

58-58: Loop control variable i not used within loop body

Rename unused i to _i

(B007)
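The segment construction that load_slp performs can be illustrated by a small standalone function (an assumed reconstruction of the gap-detection logic for illustration, not the actual implementation):

```python
def annotated_segments_from(frame_idxs):
    """Given the sorted frame indices that carry annotations, return
    (start, end) inclusive ranges of consecutive annotated frames."""
    segments = []
    if not frame_idxs:
        return segments
    start = prev = frame_idxs[0]
    for idx in frame_idxs[1:]:
        if idx != prev + 1:          # gap: close the current segment
            segments.append((start, prev))
            start = idx
        prev = idx
    segments.append((start, prev))
    return segments
```

For example, annotated frames `[0, 1, 2, 10, 11, 30]` produce the segments `[(0, 2), (10, 11), (30, 30)]`, and an empty input yields an empty list, which is why the empty-frames validation above matters.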

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9699962 and 718e5d0.

📒 Files selected for processing (5)
  • dreem/datasets/base_dataset.py (2 hunks)
  • dreem/datasets/data_utils.py (1 hunks)
  • dreem/datasets/sleap_dataset.py (7 hunks)
  • tests/test_datasets.py (1 hunks)
  • tests/test_inference.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
dreem/datasets/data_utils.py

58-58: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Tests (windows-2022, Python 3.9)
  • GitHub Check: Lint
  • GitHub Check: Tests (ubuntu-22.04, Python 3.9)
🔇 Additional comments (6)
dreem/datasets/base_dataset.py (2)

107-135: Validate segment boundaries before stitching.
When loop-processing annotated_segments, the code assumes annotated_segments[0] exists. If for some reason annotated_segments is empty, you’ll get an IndexError at prev_end = annotated_segments[0][1]. Consider gracefully skipping or logging a warning if no segments exist.


136-153: Check fraction vs. integer chunk sampling.
When n_chunks is in (0, 1.0], you convert it to an integer proportion of the total. Just confirm that the fraction usage aligns with user expectations (e.g., 0.5 → half the chunks). Also consider validating that 0 < n_chunks <= 1.0 is meaningful if fraction-based chunking is desired.

dreem/datasets/sleap_dataset.py (3)

105-106: Ensure default gap value matches typical usage.
Setting default max_batching_gap=15 may be reasonable, but double-check it represents a sensible out-of-the-box behavior for typical SLEAP data. If large annotation gaps are common, a larger default might be necessary.


165-168: Use consistent parameter types.
Swapping frame_idx from list[int] to torch.Tensor is fine, but ensure all relevant references and docstrings are updated. Make sure the rest of the pipeline (e.g., process_segments) deals with Tensors consistently, or you might run into type mismatches.


249-249: Consider track not found scenario.
Accessing sleap_labels_obj.tracks.index(instance.track) might also fail if track is removed or invalid. Check or handle the possibility that instance.track is not in the list.

tests/test_inference.py (1)

282-282: Thanks for specifying accelerator="cpu".
This ensures consistent CPU-only tests and avoids GPU-based test variations. The addition looks good.

Comment on lines +155 to +169
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
if (
len(frame_chunk)
<= min(int(self.clip_length / 10), 5)
# and frame_chunk[-1] % self.clip_length == 0
):
logger.warning(
f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
)
remove_idx.append(i)
if len(remove_idx) > 0:
for i in sorted(remove_idx, reverse=True):
self.chunked_frame_idx.pop(i)
self.label_idx.pop(i)
🛠️ Refactor suggestion

Revisit “empty batch” removal threshold.
The arbitrary heuristic (≤ min(clip_length/10, 5) frames) for removing “small” chunks might prematurely discard valid shorter sequences. Confirm this logic aligns with the project’s “Non contiguous clips” objectives, which support sparse annotation. If partial clips are acceptable, consider a smaller threshold or a user-configurable parameter.

Comment on lines +82 to +97
def process_segments(
self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
) -> None:
"""Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.

Args:
segments_to_stitch: list of segments to stitch
i: index of the video
clip_length: the number of frames in each chunk
Returns: None
"""
stitched_segment = torch.cat(segments_to_stitch)
frame_idx_split = torch.split(stitched_segment, clip_length)
self.chunked_frame_idx.extend(frame_idx_split)
self.label_idx.extend(len(frame_idx_split) * [i])

⚠️ Potential issue

Ensure robust handling for dynamic attributes.
The method references self.max_batching_gap, which appears to be defined only in subclasses (e.g., SleapDataset). If another subclass or a direct instance of BaseDataset does not initialize self.max_batching_gap, this will raise an AttributeError. Consider assigning a default value in BaseDataset to avoid runtime issues.

+ # Example fix: Provide a default in BaseDataset if you expect each subclass to possibly override it
+ self.max_batching_gap = getattr(self, "max_batching_gap", 15)

Comment on lines +208 to +239
train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
data_dirs="./data/sleap",
crop_size=128,
chunk=True,
max_batching_gap=10,
clip_length=clip_length,
n_chunks=30,
)

train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
data_dirs="./data/sleap",
crop_size=128,
chunk=True,
max_batching_gap=0,
clip_length=clip_length,
n_chunks=30,
)

train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
data_dirs="./data/sleap",
crop_size=128,
chunk=False,
max_batching_gap=10,
clip_length=clip_length,
n_chunks=30,
)
🛠️ Refactor suggestion

Add assertions to verify non-contiguous segments behavior.

The test cases for max_batching_gap parameter are missing assertions to verify that:

  1. The segments are correctly split when the gap exceeds max_batching_gap.
  2. The segments are correctly merged when the gap is within max_batching_gap.
  3. The non-chunked case correctly preserves all frames.

Do you want me to generate the assertions to verify the behavior of non-contiguous segments?
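In the spirit of the missing assertions, a reusable check could look like this (hypothetical helper; the batch structure mirrors the chunked_frame_idx lists shown in the diffs, but the names are illustrative):

```python
def max_gap_within_batches(batches):
    """Largest jump between consecutive frame indices inside any batch.

    Useful for asserting that no batch bridges a gap larger than
    max_batching_gap (plus one, since adjacent frames differ by one).
    """
    worst = 0
    for batch in batches:
        for a, b in zip(batch, batch[1:]):
            worst = max(worst, b - a)
    return worst
```

With max_batching_gap=10 a test could assert `max_gap_within_batches(ds.chunked_frame_idx) <= 11`, and with max_batching_gap=0 that every batch is strictly contiguous.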

…on dataset type

- Temporarily added old chunking method for non-sleap datasets to prevent a breaking change
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
dreem/datasets/base_dataset.py (1)

82-96: Consider a more functional approach to reduce side effects.

The method directly modifies state variables chunked_frame_idx and label_idx. Consider returning the processed segments instead of modifying state, which would make the code more predictable and easier to test.

-    def process_segments(
-        self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
-    ) -> None:
+    def process_segments(
+        self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
+    ) -> tuple[list[torch.Tensor], list[int]]:
         """Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.

         Args:
             segments_to_stitch: list of segments to stitch
             i: index of the video
             clip_length: the number of frames in each chunk
-        Returns: None
+        Returns: Tuple of (frame_idx_split, label_idx)
         """
         stitched_segment = torch.cat(segments_to_stitch)
         frame_idx_split = torch.split(stitched_segment, clip_length)
-        self.chunked_frame_idx.extend(frame_idx_split)
-        self.label_idx.extend(len(frame_idx_split) * [i])
+        return list(frame_idx_split), [i] * len(frame_idx_split)
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 718e5d0 and 4275a0c.

📒 Files selected for processing (1)
  • dreem/datasets/base_dataset.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Tests (windows-2022, Python 3.9)
  • GitHub Check: Tests (ubuntu-22.04, Python 3.9)
  • GitHub Check: Lint
🔇 Additional comments (1)
dreem/datasets/base_dataset.py (1)

105-176: Ensure robust handling of max_batching_gap attribute.

The method references self.max_batching_gap which is defined in SleapDataset but not in BaseDataset. This could lead to AttributeError if the method is called incorrectly.

Additionally, the empty batch removal logic at lines 161-176 uses arbitrary thresholds that might not align with the non-contiguous clips feature.

Please verify if the thresholds for removing "small" chunks are still appropriate for non-contiguous clips:

#!/bin/bash
# Description: Find the distribution of chunk sizes to determine appropriate thresholds

# Search for test files that demonstrate chunk size patterns
rg -l "clip_length|max_batching_gap" tests/

Comment on lines +143 to +176
if self.n_chunks > 0 and self.n_chunks <= 1.0:
n_chunks = int(self.n_chunks * len(self.chunked_frame_idx))

elif self.n_chunks <= len(self.chunked_frame_idx):
n_chunks = int(self.n_chunks)

else:
n_chunks = len(self.chunked_frame_idx)

if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx):
sample_idx = np.random.choice(
np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
)

self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]

self.label_idx = [self.label_idx[i] for i in sample_idx]

# workaround for empty batch bug (needs to be changed). Check for batch with with only 1/10 size of clip length. Arbitrary thresholds
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
if (
len(frame_chunk)
<= min(int(self.clip_length / 10), 5)
# and frame_chunk[-1] % self.clip_length == 0
):
logger.warning(
f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
)
remove_idx.append(i)
if len(remove_idx) > 0:
for i in sorted(remove_idx, reverse=True):
self.chunked_frame_idx.pop(i)
self.label_idx.pop(i)
🛠️ Refactor suggestion

Extract duplicated chunk processing logic.

The n_chunks calculation and small batch removal logic is duplicated between create_chunks_slp and create_chunks_other. Consider extracting these into helper methods.

+    def _calculate_n_chunks(self, total_chunks: int) -> int:
+        """Calculate the number of chunks to use based on self.n_chunks."""
+        if self.n_chunks > 0 and self.n_chunks <= 1.0:
+            return int(self.n_chunks * total_chunks)
+        elif self.n_chunks <= total_chunks:
+            return int(self.n_chunks)
+        return total_chunks
+
+    def _remove_small_chunks(self) -> None:
+        """Remove chunks that are too small to avoid empty batch issues."""
+        remove_idx = []
+        for i, frame_chunk in enumerate(self.chunked_frame_idx):
+            if len(frame_chunk) <= min(int(self.clip_length / 10), 5):
+                logger.warning(
+                    f"Warning: Batch containing frames {frame_chunk} from video "
+                    f"{self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} "
+                    "frames. Removing to avoid empty batch possibility."
+                )
+                remove_idx.append(i)
+        for i in sorted(remove_idx, reverse=True):
+            self.chunked_frame_idx.pop(i)
+            self.label_idx.pop(i)

Then use these helper methods in both create_chunks_slp and create_chunks_other.

Also applies to: 194-227

Comment on lines 98 to +104
def create_chunks(self) -> None:
"""Factory method to create chunks."""
if type(self).__name__ == "SleapDataset":
self.create_chunks_slp()
else:
self.create_chunks_other()

🛠️ Refactor suggestion

Replace type checking with polymorphism.

Using type(self).__name__ for routing is fragile and violates the Open/Closed Principle. Consider making create_chunks an abstract method and implementing it in each subclass.

-    def create_chunks(self) -> None:
-        """Factory method to create chunks."""
-        if type(self).__name__ == "SleapDataset":
-            self.create_chunks_slp()
-        else:
-            self.create_chunks_other()
+    @abstractmethod
+    def create_chunks(self) -> None:
+        """Create chunks for the dataset.
+        
+        This method should be implemented by subclasses to define their specific
+        chunking strategy.
+        """
+        pass

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
dreem/datasets/sleap_dataset.py (3)

39-39: Validate max_batching_gap parameter.

Consider adding validation to ensure max_batching_gap is a positive integer, as negative or zero values would not make sense in this context.

 def __init__(
     self,
     slp_files: list[str],
     video_files: list[str],
     data_dirs: Optional[list[str]] = None,
     padding: int = 5,
     crop_size: Union[int, list[int]] = 128,
     anchors: int | list[str] | str = "",
     chunk: bool = True,
     clip_length: int = 500,
     mode: str = "train",
     handle_missing: str = "centroid",
     augmentations: dict | None = None,
     n_chunks: int | float = 1.0,
     seed: int | None = None,
     verbose: bool = False,
     normalize_image: bool = True,
     max_batching_gap: int = 15,
 ):
+    if max_batching_gap <= 0:
+        raise ValueError("max_batching_gap must be a positive integer")
     self.max_batching_gap = max_batching_gap

Also applies to: 77-77, 106-106


143-150: Add error handling for file loading operations.

Consider adding try-except blocks to handle potential file loading errors gracefully and provide meaningful error messages.

 # load_slp is a wrapper around sio.load_slp for frame gap checks
 self.labels = []
 self.annotated_segments = {}
 for slp_file in self.slp_files:
+    try:
         labels, annotated_segments = data_utils.load_slp(slp_file)
         self.labels.append(labels)
         self.annotated_segments[slp_file] = annotated_segments
+    except Exception as e:
+        logger.error(f"Failed to load {slp_file}: {str(e)}")
+        raise RuntimeError(f"Failed to load {slp_file}") from e

211-214: Enhance frame index mismatch warning.

The warning message could be more informative by including the video file name and suggesting potential causes.

-            if frame_ind != lf.frame_idx:
-                logger.warning(f"Frame index mismatch: {frame_ind} != {lf.frame_idx}")
+            if frame_ind != lf.frame_idx:
+                logger.warning(
+                    f"Frame index mismatch in {video_name}: requested={frame_ind}, "
+                    f"actual={lf.frame_idx}. This might indicate missing frames or "
+                    "incorrect frame indexing."
+                )
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 4275a0c and 26cd5c0.

📒 Files selected for processing (1)
  • dreem/datasets/sleap_dataset.py (8 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Tests (windows-2022, Python 3.9)
  • GitHub Check: Lint
  • GitHub Check: Tests (ubuntu-22.04, Python 3.9)

Comment on lines +166 to +168
def get_instances(
self, label_idx: list[int], frame_idx: torch.Tensor
) -> list[Frame]:
🛠️ Refactor suggestion

Add input validation for frame_idx tensor.

The parameter type change from list[int] to torch.Tensor requires validation to ensure correct tensor properties.

 def get_instances(
     self, label_idx: list[int], frame_idx: torch.Tensor
 ) -> list[Frame]:
+    if not isinstance(frame_idx, torch.Tensor):
+        raise TypeError("frame_idx must be a torch.Tensor")
+    if frame_idx.dtype not in [torch.int32, torch.int64]:
+        raise TypeError("frame_idx must contain integer values")

@shaikh58 (Contributor, Author) commented Feb 8, 2025

Also fixes issue that was addressed by temporary patch in #108

@shaikh58 merged commit 0408334 into main Feb 8, 2025
4 checks passed
Successfully merging this pull request may close these issues:

Support non-contiguous sets of annotated frames in training data