Non contiguous clips #109
Conversation
…mes that have labels
- implemented segment stitcher in BaseDataset that chunks according to max_batching_gap
- refactored SleapDataset and BaseDataset to use frame_idx from LabeledFrame objects, rather than an enumerated index created in SleapDataset; avoids indexing-related bugs
- chunker now makes the full decision on which frames are in a batch
- cleaned up index-related code in SleapDataset, including skip-frame counts

…5 in SleapDataset init
- have max_batching_gap as part of train/val/test dataset configs separately rather than a common value; allows more flexibility and cleaner config handling
- rename variables
- lint

…alse, just stitch all the segments into a single batch

- fixed multi-GPU test case failure by forcing to CPU
- tested with no chunking on a large dataset
Walkthrough
This update enhances the dataset processing and testing pipelines. A new method, …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant SleapDataset
    participant DataUtils
    participant BaseDataset
    Client->>SleapDataset: Initialize dataset (max_batching_gap, clip_length, etc.)
    SleapDataset->>DataUtils: load_slp(labels_path)
    DataUtils-->>SleapDataset: Return (Labels, Annotated Segments)
    SleapDataset->>BaseDataset: create_chunks(...)
    BaseDataset->>BaseDataset: process_segments(i, segments, clip_length)
    Client->>SleapDataset: get_instances(frame_idx Tensor)
    SleapDataset->>SleapDataset: Retrieve instances using updated label access
```
Actionable comments posted: 3
🧹 Nitpick comments (7)
dreem/datasets/base_dataset.py (2)

82-97: Consider clarifying the docstring.
You might clarify exactly how "stitching" works and whether partial segments shorter than clip_length are allowed or padded, particularly in unusual boundary conditions. This helps new contributors and future maintainers.
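As a concrete illustration of the boundary question raised here, the following sketch mimics the stitch-then-split behavior with plain Python lists (the real code uses torch.cat/torch.split on tensors; stitch_and_split is a hypothetical name). Like torch.split, the final chunk keeps the remainder rather than being padded:

```python
# Hypothetical sketch of how stitching and fixed-length splitting interact.
# Plain lists stand in for torch.Tensor purely to illustrate the boundary case.

def stitch_and_split(segments, clip_length):
    """Concatenate annotated segments, then cut into clip_length-sized chunks.

    As with torch.split, the final chunk keeps the remainder, so it may be
    shorter than clip_length (it is not padded).
    """
    stitched = [idx for seg in segments for idx in seg]
    return [stitched[i : i + clip_length] for i in range(0, len(stitched), clip_length)]

chunks = stitch_and_split([[0, 1, 2], [10, 11, 12, 13]], clip_length=5)
# First chunk is full length; the trailing partial chunk survives as-is.
print(chunks)  # [[0, 1, 2, 10, 11], [12, 13]]
```

Documenting this remainder behavior explicitly in the docstring would answer the padded-vs-allowed question for future maintainers.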
107-135: Potential performance concern with repeated torch.cat.
Collecting segments in a list and concatenating them each time could be costly for larger datasets. Consider building a single list/array of frame indices and splitting only once, or at least limiting the number of concatenations in performance-critical contexts.

dreem/datasets/sleap_dataset.py (3)
142-149: Handle exceptions from load_slp gracefully.
If load_slp fails (e.g., invalid file, partial data), the code could crash. Adding basic error handling or logging around data_utils.load_slp would improve robustness.

Do you want me to create a follow-up to gracefully catch load errors or corrupt label data?
209-210: Guard against missing frames.
Querying sleap_labels_obj[(sleap_labels_obj.video, frame_ind)] can raise a KeyError if no labels exist for that frame. Consider verifying the presence of frame data prior to indexing to avoid runtime errors.
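A minimal sketch of the guarded-lookup pattern suggested here, with a plain dict standing in for the sleap_io Labels mapping (get_labeled_frame is a hypothetical helper, not part of the codebase):

```python
# Hypothetical guard for the labels lookup. A dict stands in for the
# Labels object's (video, frame_idx) -> LabeledFrame mapping interface.

def get_labeled_frame(labels_by_key, key):
    """Return the labeled frame for `key`, or None if no labels exist."""
    try:
        return labels_by_key[key]
    except KeyError:
        # In the real code this would be a logger.warning(...) plus a skip,
        # rather than letting the KeyError propagate at runtime.
        return None

labels = {("video0", 3): "frame-3-labels"}
assert get_labeled_frame(labels, ("video0", 3)) == "frame-3-labels"
assert get_labeled_frame(labels, ("video0", 99)) is None
```

Whether to skip the frame or raise a descriptive error on a miss is a design choice for the maintainers.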
213-214: Confirm video length bounds.
Reading frame_ind from vid_reader can fail if frame_ind is out of range. Although you catch an IndexError, you might log additional details (e.g., the maximum valid index) to simplify debugging.

dreem/datasets/data_utils.py (2)
28-39: Add missing type hints and docstring.
The function is missing:
- Type hints for the return values.
- Documentation for the annotated_segments return value.

Apply this diff to add the missing type hints and docstring:

```diff
-def load_slp(labels_path: str, open_videos: bool = True) -> Labels:
+def load_slp(labels_path: str, open_videos: bool = True) -> tuple[Labels, list[tuple[int, int]]]:
     """Read a SLEAP labels file.

     Args:
         labels_path: A string path to the SLEAP labels file.
         open_videos: If `True` (the default), attempt to open the video backend for
             I/O. If `False`, the backend will not be opened (useful for reading metadata
             when the video files are not available).

     Returns:
-        The processed `Labels` object.
+        A tuple containing:
+        - The processed `Labels` object.
+        - A list of tuples containing (start_frame, end_frame) for each annotated segment.
     """
```
58-74: Fix unused loop variable and add validation.
The function has the following issues:
- Unused loop variable i in the frames loop.
- Missing validation for an empty frames dataset.

Apply this diff to fix the issues:

```diff
+    if len(frames) == 0:
+        return Labels(
+            labeled_frames=[],
+            videos=videos,
+            skeletons=skeletons,
+            tracks=tracks,
+            provenance=provenance,
+        ), []
+
     curr_segment_start = frames[0][2]
     curr_frame = curr_segment_start
     # note that frames only contains frames with labelled instances, not all frames
-    for i, video_id, frame_idx, instance_id_start, instance_id_end in frames:
+    for _i, video_id, frame_idx, instance_id_start, instance_id_end in frames:
```

🧰 Tools
🪛 Ruff (0.8.2)
58-58: Loop control variable i not used within loop body. Rename unused i to _i (B007)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- dreem/datasets/base_dataset.py (2 hunks)
- dreem/datasets/data_utils.py (1 hunks)
- dreem/datasets/sleap_dataset.py (7 hunks)
- tests/test_datasets.py (1 hunks)
- tests/test_inference.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
dreem/datasets/data_utils.py
58-58: Loop control variable i not used within loop body. Rename unused i to _i (B007)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: Tests (windows-2022, Python 3.9)
- GitHub Check: Lint
- GitHub Check: Tests (ubuntu-22.04, Python 3.9)
🔇 Additional comments (6)
dreem/datasets/base_dataset.py (2)

107-135: Validate segment boundaries before stitching.
When loop-processing annotated_segments, the code assumes annotated_segments[0] exists. If for some reason annotated_segments is empty, you'll get an IndexError at prev_end = annotated_segments[0][1]. Consider gracefully skipping or logging a warning if no segments exist.

136-153: Check fraction vs. integer chunk sampling.
When n_chunks is in (0, 1.0], you convert it to an integer proportion of the total. Just confirm that the fraction usage aligns with user expectations (e.g., 0.5 → half the chunks). Also consider validating that 0 < n_chunks <= 1.0 is meaningful if fraction-based chunking is desired.

dreem/datasets/sleap_dataset.py (3)
105-106: Ensure the default gap value matches typical usage.
Setting default max_batching_gap=15 may be reasonable, but double-check that it represents sensible out-of-the-box behavior for typical SLEAP data. If large annotation gaps are common, a larger default might be necessary.

165-168: Use consistent parameter types.
Swapping frame_idx from list[int] to torch.Tensor is fine, but ensure all relevant references and docstrings are updated. Make sure the rest of the pipeline (e.g., process_segments) deals with Tensors consistently, or you might run into type mismatches.

249-249: Consider the track-not-found scenario.
Accessing sleap_labels_obj.tracks.index(instance.track) might fail if the track is removed or invalid. Check or handle the possibility that instance.track is not in the list.

tests/test_inference.py (1)
282-282: Thanks for specifying accelerator="cpu".
This ensures consistent CPU-only tests and avoids GPU-based test variations. The addition looks good.
```python
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
    if (
        len(frame_chunk)
        <= min(int(self.clip_length / 10), 5)
        # and frame_chunk[-1] % self.clip_length == 0
    ):
        logger.warning(
            f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
        )
        remove_idx.append(i)
if len(remove_idx) > 0:
    for i in sorted(remove_idx, reverse=True):
        self.chunked_frame_idx.pop(i)
        self.label_idx.pop(i)
```
🛠️ Refactor suggestion
Revisit “empty batch” removal threshold.
The arbitrary heuristic (≤ clip_length/10 or 5 frames) for removing “small” chunks might prematurely discard valid shorter sequences. Confirm this logic aligns with the project’s “Non contiguous clips” objectives, which support sparse annotation. If partial clips are acceptable, consider a smaller threshold or a user-configurable parameter.
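One way to make the threshold user-configurable, as suggested: a sketch with a hypothetical min_chunk_size parameter that defaults to the current heuristic (plain lists stand in for the tensor chunks; this is not the project's API):

```python
# Hypothetical user-configurable minimum chunk size, replacing the hardcoded
# min(clip_length // 10, 5) heuristic. min_chunk_size is an assumed parameter
# name, not part of the current codebase.

def filter_small_chunks(chunks, clip_length, min_chunk_size=None):
    """Drop chunks at or below the threshold; the default preserves the old heuristic."""
    if min_chunk_size is None:
        min_chunk_size = min(clip_length // 10, 5)
    return [c for c in chunks if len(c) > min_chunk_size]

chunks = [[1, 2], [3, 4, 5, 6, 7, 8]]
# With clip_length=50 the default threshold is 5, so only the 6-frame chunk survives.
print(filter_small_chunks(chunks, clip_length=50))  # [[3, 4, 5, 6, 7, 8]]
# A caller who wants to keep short sparse-annotation clips can lower the threshold.
print(filter_small_chunks(chunks, clip_length=50, min_chunk_size=1))  # both chunks kept
```

Exposing the threshold this way would let sparse-annotation users opt into keeping short clips without touching the default behavior.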
```python
def process_segments(
    self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
) -> None:
    """Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.

    Args:
        segments_to_stitch: list of segments to stitch
        i: index of the video
        clip_length: the number of frames in each chunk

    Returns: None
    """
    stitched_segment = torch.cat(segments_to_stitch)
    frame_idx_split = torch.split(stitched_segment, clip_length)
    self.chunked_frame_idx.extend(frame_idx_split)
    self.label_idx.extend(len(frame_idx_split) * [i])
```
Ensure robust handling for dynamic attributes.
The method references self.max_batching_gap, which appears to be defined only in subclasses (e.g., SleapDataset). If another subclass or a direct instance of BaseDataset does not initialize self.max_batching_gap, this will raise an AttributeError. Consider assigning a default value in BaseDataset to avoid runtime issues.

```diff
+# Example fix: Provide a default in BaseDataset if you expect each subclass to possibly override it
+self.max_batching_gap = getattr(self, "max_batching_gap", 15)
```
```python
train_ds = SleapDataset(
    slp_files=[two_flies[0]],
    video_files=[two_flies[1]],
    data_dirs="./data/sleap",
    crop_size=128,
    chunk=True,
    max_batching_gap=10,
    clip_length=clip_length,
    n_chunks=30,
)

train_ds = SleapDataset(
    slp_files=[two_flies[0]],
    video_files=[two_flies[1]],
    data_dirs="./data/sleap",
    crop_size=128,
    chunk=True,
    max_batching_gap=0,
    clip_length=clip_length,
    n_chunks=30,
)

train_ds = SleapDataset(
    slp_files=[two_flies[0]],
    video_files=[two_flies[1]],
    data_dirs="./data/sleap",
    crop_size=128,
    chunk=False,
    max_batching_gap=10,
    clip_length=clip_length,
    n_chunks=30,
)
```
🛠️ Refactor suggestion
Add assertions to verify non-contiguous segments behavior.
The test cases for the max_batching_gap parameter are missing assertions to verify that:
- The segments are correctly split when the gap exceeds max_batching_gap.
- The segments are correctly merged when the gap is within max_batching_gap.
- The non-chunked case correctly preserves all frames.

Do you want me to generate the assertions to verify the behavior of non-contiguous segments?
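A hedged sketch of what such assertions could look like, using a pure-Python stand-in for the gap-based grouping (group_segments is hypothetical; real tests would inspect train_ds.chunked_frame_idx instead):

```python
# Hypothetical model of the gap-based grouping: neighboring (start, end)
# segments are merged when the gap between them is <= max_batching_gap,
# otherwise a new batch starts.

def group_segments(segments, max_batching_gap):
    """Group (start, end) segments, merging neighbors whose gap is within the threshold."""
    groups = [[segments[0]]]
    for prev, cur in zip(segments, segments[1:]):
        if cur[0] - prev[1] > max_batching_gap:
            groups.append([cur])        # gap too large: split into a new batch
        else:
            groups[-1].append(cur)      # gap small enough: stitch together
    return groups

segments = [(0, 10), (15, 20), (40, 50)]
# gap of 5 between the first two segments, gap of 20 before the last
assert group_segments(segments, max_batching_gap=10) == [[(0, 10), (15, 20)], [(40, 50)]]
assert group_segments(segments, max_batching_gap=0) == [[(0, 10)], [(15, 20)], [(40, 50)]]
```

Analogous assertions against the dataset's actual chunk attributes would cover the three cases the review lists.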
…on dataset type - Temporarily added old chunking method for non-sleap datasets to prevent a breaking change
Actionable comments posted: 2
🧹 Nitpick comments (1)
dreem/datasets/base_dataset.py (1)

82-96: Consider a more functional approach to reduce side effects.
The method directly modifies the state variables chunked_frame_idx and label_idx. Consider returning the processed segments instead of modifying state, which would make the code more predictable and easier to test.

```diff
     def process_segments(
         self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
-    ) -> None:
+    ) -> tuple[list[torch.Tensor], list[int]]:
         """Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.

         Args:
             segments_to_stitch: list of segments to stitch
             i: index of the video
             clip_length: the number of frames in each chunk

-        Returns: None
+        Returns: Tuple of (frame_idx_split, label_idx)
         """
         stitched_segment = torch.cat(segments_to_stitch)
         frame_idx_split = torch.split(stitched_segment, clip_length)
-        self.chunked_frame_idx.extend(frame_idx_split)
-        self.label_idx.extend(len(frame_idx_split) * [i])
+        return list(frame_idx_split), [i] * len(frame_idx_split)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- dreem/datasets/base_dataset.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: Tests (windows-2022, Python 3.9)
- GitHub Check: Tests (ubuntu-22.04, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (1)
dreem/datasets/base_dataset.py (1)

105-176: Ensure robust handling of the max_batching_gap attribute.
The method references self.max_batching_gap, which is defined in SleapDataset but not in BaseDataset. This could lead to an AttributeError if the method is called incorrectly. Additionally, the empty-batch removal logic at lines 161-176 uses arbitrary thresholds that might not align with the non-contiguous clips feature.

Please verify that the thresholds for removing "small" chunks are still appropriate for non-contiguous clips:

```shell
#!/bin/bash
# Description: Find the distribution of chunk sizes to determine appropriate thresholds
# Search for test files that demonstrate chunk size patterns
rg -l "clip_length|max_batching_gap" tests/
```
```python
if self.n_chunks > 0 and self.n_chunks <= 1.0:
    n_chunks = int(self.n_chunks * len(self.chunked_frame_idx))

elif self.n_chunks <= len(self.chunked_frame_idx):
    n_chunks = int(self.n_chunks)

else:
    n_chunks = len(self.chunked_frame_idx)

if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx):
    sample_idx = np.random.choice(
        np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
    )

    self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]

    self.label_idx = [self.label_idx[i] for i in sample_idx]

# workaround for empty batch bug (needs to be changed). Check for batch with only 1/10 size of clip length. Arbitrary thresholds
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
    if (
        len(frame_chunk)
        <= min(int(self.clip_length / 10), 5)
        # and frame_chunk[-1] % self.clip_length == 0
    ):
        logger.warning(
            f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
        )
        remove_idx.append(i)
if len(remove_idx) > 0:
    for i in sorted(remove_idx, reverse=True):
        self.chunked_frame_idx.pop(i)
        self.label_idx.pop(i)
```
🛠️ Refactor suggestion
Extract duplicated chunk processing logic.
The n_chunks calculation and small-batch removal logic is duplicated between create_chunks_slp and create_chunks_other. Consider extracting these into helper methods:

```diff
+    def _calculate_n_chunks(self, total_chunks: int) -> int:
+        """Calculate the number of chunks to use based on self.n_chunks."""
+        if self.n_chunks > 0 and self.n_chunks <= 1.0:
+            return int(self.n_chunks * total_chunks)
+        elif self.n_chunks <= total_chunks:
+            return int(self.n_chunks)
+        return total_chunks
+
+    def _remove_small_chunks(self) -> None:
+        """Remove chunks that are too small to avoid empty batch issues."""
+        remove_idx = []
+        for i, frame_chunk in enumerate(self.chunked_frame_idx):
+            if len(frame_chunk) <= min(int(self.clip_length / 10), 5):
+                logger.warning(
+                    f"Warning: Batch containing frames {frame_chunk} from video "
+                    f"{self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} "
+                    "frames. Removing to avoid empty batch possibility."
+                )
+                remove_idx.append(i)
+        for i in sorted(remove_idx, reverse=True):
+            self.chunked_frame_idx.pop(i)
+            self.label_idx.pop(i)
```

Then use these helper methods in both create_chunks_slp and create_chunks_other.

Also applies to: 194-227
```python
def create_chunks(self) -> None:
    """Factory method to create chunks."""
    if type(self).__name__ == "SleapDataset":
        self.create_chunks_slp()
    else:
        self.create_chunks_other()
```
🛠️ Refactor suggestion
Replace type checking with polymorphism.
Using type(self).__name__ for routing is fragile and violates the Open/Closed Principle. Consider making create_chunks an abstract method and implementing it in each subclass:

```diff
-    def create_chunks(self) -> None:
-        """Factory method to create chunks."""
-        if type(self).__name__ == "SleapDataset":
-            self.create_chunks_slp()
-        else:
-            self.create_chunks_other()
+    @abstractmethod
+    def create_chunks(self) -> None:
+        """Create chunks for the dataset.
+
+        This method should be implemented by subclasses to define their specific
+        chunking strategy.
+        """
+        pass
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
@abstractmethod
def create_chunks(self) -> None:
    """Create chunks for the dataset.

    This method should be implemented by subclasses to define their specific
    chunking strategy.
    """
    pass
```
Actionable comments posted: 1
🧹 Nitpick comments (3)
dreem/datasets/sleap_dataset.py (3)

39-39: Validate the max_batching_gap parameter.
Consider adding validation to ensure max_batching_gap is a positive integer, as negative or zero values would not make sense in this context.

```diff
     def __init__(
         self,
         slp_files: list[str],
         video_files: list[str],
         data_dirs: Optional[list[str]] = None,
         padding: int = 5,
         crop_size: Union[int, list[int]] = 128,
         anchors: int | list[str] | str = "",
         chunk: bool = True,
         clip_length: int = 500,
         mode: str = "train",
         handle_missing: str = "centroid",
         augmentations: dict | None = None,
         n_chunks: int | float = 1.0,
         seed: int | None = None,
         verbose: bool = False,
         normalize_image: bool = True,
         max_batching_gap: int = 15,
     ):
+        if max_batching_gap <= 0:
+            raise ValueError("max_batching_gap must be a positive integer")
         self.max_batching_gap = max_batching_gap
```

Also applies to: 77-77, 106-106
143-150: Add error handling for file loading operations.
Consider adding try-except blocks to handle potential file-loading errors gracefully and provide meaningful error messages:

```diff
 # load_slp is a wrapper around sio.load_slp for frame gap checks
 self.labels = []
 self.annotated_segments = {}
 for slp_file in self.slp_files:
+    try:
         labels, annotated_segments = data_utils.load_slp(slp_file)
         self.labels.append(labels)
         self.annotated_segments[slp_file] = annotated_segments
+    except Exception as e:
+        logger.error(f"Failed to load {slp_file}: {str(e)}")
+        raise RuntimeError(f"Failed to load {slp_file}") from e
```
211-214: Enhance the frame index mismatch warning.
The warning message could be more informative by including the video file name and suggesting potential causes:

```diff
-if frame_ind != lf.frame_idx:
-    logger.warning(f"Frame index mismatch: {frame_ind} != {lf.frame_idx}")
+if frame_ind != lf.frame_idx:
+    logger.warning(
+        f"Frame index mismatch in {video_name}: requested={frame_ind}, "
+        f"actual={lf.frame_idx}. This might indicate missing frames or "
+        "incorrect frame indexing."
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- dreem/datasets/sleap_dataset.py (8 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: Tests (windows-2022, Python 3.9)
- GitHub Check: Lint
- GitHub Check: Tests (ubuntu-22.04, Python 3.9)
```python
def get_instances(
    self, label_idx: list[int], frame_idx: torch.Tensor
) -> list[Frame]:
```
🛠️ Refactor suggestion
Add input validation for the frame_idx tensor.
The parameter type change from list[int] to torch.Tensor requires validation to ensure correct tensor properties:

```diff
 def get_instances(
     self, label_idx: list[int], frame_idx: torch.Tensor
 ) -> list[Frame]:
+    if not isinstance(frame_idx, torch.Tensor):
+        raise TypeError("frame_idx must be a torch.Tensor")
+    if frame_idx.dtype not in [torch.int32, torch.int64]:
+        raise TypeError("frame_idx must contain integer values")
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
def get_instances(
    self, label_idx: list[int], frame_idx: torch.Tensor
) -> list[Frame]:
    if not isinstance(frame_idx, torch.Tensor):
        raise TypeError("frame_idx must be a torch.Tensor")
    if frame_idx.dtype not in [torch.int32, torch.int64]:
        raise TypeError("frame_idx must contain integer values")
```
Also fixes the issue that was addressed by a temporary patch in #108.
This PR implements a feature supporting non-contiguous annotated segments in training data. Currently, the pipeline expects the entire file to be annotated (except possibly a handful of missing frames), but does not support sparse annotation. This PR introduces support for this feature, as well as a redesign of chunking, and consistent indexing:
NOTE: Only SleapDataset is supported in this fix. Other Datasets will eventually need to be added. Currently, the other datasets are not in use, so test cases related to other Dataset types will fail.
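The stitching behavior this PR describes can be sketched end-to-end, under the assumption that annotated segments arrive as half-open (start, end) frame ranges; plain Python stands in for the torch-based implementation, and chunk_annotated_segments is a hypothetical name:

```python
# Illustrative sketch: stitch annotated segments separated by small gaps,
# then split the stitched runs into clip_length-sized chunks. Assumes
# half-open (start, end) segments; the real code operates on tensors.

def chunk_annotated_segments(segments, max_batching_gap, clip_length):
    """Stitch segments whose gap is <= max_batching_gap, then split by clip_length."""
    batches = []
    current = list(range(segments[0][0], segments[0][1]))
    for (s, e), (_, prev_e) in zip(segments[1:], segments):
        if s - prev_e <= max_batching_gap:
            current.extend(range(s, e))   # small gap: stitch across it
        else:
            batches.append(current)        # large gap: close the current batch
            current = list(range(s, e))
    batches.append(current)
    # split every stitched batch into clip_length-sized chunks (last may be shorter)
    return [b[i : i + clip_length] for b in batches
            for i in range(0, len(b), clip_length)]

chunks = chunk_annotated_segments([(0, 4), (6, 9), (30, 33)],
                                  max_batching_gap=5, clip_length=4)
print(chunks)  # [[0, 1, 2, 3], [6, 7, 8], [30, 31, 32]]
```

Here the first two segments (gap of 2 frames) are stitched and then split by clip_length, while the third (gap of 21) starts its own batch, mirroring the chunker making the full decision on which frames go in a batch.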
Highlights:
Summary by CodeRabbit
- New Features
- Tests