-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Non contiguous clips #109
Non contiguous clips #109
Changes from all commits
55f0782
1c01798
68cd663
be85d33
718e5d0
4275a0c
26cd5c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -79,7 +79,103 @@ def __init__( | |||||||||||||||||||||||||||||
self.labels = None | ||||||||||||||||||||||||||||||
self.gt_list = None | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def process_segments( | ||||||||||||||||||||||||||||||
self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int | ||||||||||||||||||||||||||||||
) -> None: | ||||||||||||||||||||||||||||||
"""Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||
segments_to_stitch: list of segments to stitch | ||||||||||||||||||||||||||||||
i: index of the video | ||||||||||||||||||||||||||||||
clip_length: the number of frames in each chunk | ||||||||||||||||||||||||||||||
Returns: None | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
stitched_segment = torch.cat(segments_to_stitch) | ||||||||||||||||||||||||||||||
frame_idx_split = torch.split(stitched_segment, clip_length) | ||||||||||||||||||||||||||||||
self.chunked_frame_idx.extend(frame_idx_split) | ||||||||||||||||||||||||||||||
self.label_idx.extend(len(frame_idx_split) * [i]) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def create_chunks(self) -> None: | ||||||||||||||||||||||||||||||
"""Factory method to create chunks.""" | ||||||||||||||||||||||||||||||
if type(self).__name__ == "SleapDataset": | ||||||||||||||||||||||||||||||
self.create_chunks_slp() | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
self.create_chunks_other() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Comment on lines
98
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Replace type checking with polymorphism. Using - def create_chunks(self) -> None:
- """Factory method to create chunks."""
- if type(self).__name__ == "SleapDataset":
- self.create_chunks_slp()
- else:
- self.create_chunks_other()
+ @abstractmethod
+ def create_chunks(self) -> None:
+ """Create chunks for the dataset.
+
+ This method should be implemented by subclasses to define their specific
+ chunking strategy.
+ """
+ pass 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||
def create_chunks_slp(self) -> None: | ||||||||||||||||||||||||||||||
"""Get indexing for data. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Creates both indexes for selecting dataset (label_idx) and frame in | ||||||||||||||||||||||||||||||
dataset (chunked_frame_idx). If chunking is false, we index directly | ||||||||||||||||||||||||||||||
using the frame ids. Setting chunking to true creates a list of lists | ||||||||||||||||||||||||||||||
containing chunk frames for indexing. This is useful for computational | ||||||||||||||||||||||||||||||
efficiency and data shuffling. To be called by subclass __init__() | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
self.chunked_frame_idx, self.label_idx = [], [] | ||||||||||||||||||||||||||||||
# go through each slp file and create chunks that respect max_batching_gap | ||||||||||||||||||||||||||||||
for i, slp_file in enumerate(self.label_files): | ||||||||||||||||||||||||||||||
annotated_segments = self.annotated_segments[slp_file] | ||||||||||||||||||||||||||||||
segments_to_stitch = [] | ||||||||||||||||||||||||||||||
prev_end = annotated_segments[0][1] # end of first segment | ||||||||||||||||||||||||||||||
for start, end in annotated_segments: | ||||||||||||||||||||||||||||||
# check if the start of current segment is within batching_max_gap of end of previous | ||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||
int(start) - int(prev_end) < self.max_batching_gap | ||||||||||||||||||||||||||||||
) or not self.chunk: # also takes care of first segment as start < prev_end | ||||||||||||||||||||||||||||||
segments_to_stitch.append(torch.arange(start, end + 1)) | ||||||||||||||||||||||||||||||
prev_end = end | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
# stitch previous set of segments before creating a new chunk | ||||||||||||||||||||||||||||||
self.process_segments(i, segments_to_stitch, self.clip_length) | ||||||||||||||||||||||||||||||
# reset segments_to_stitch as we are starting a new chunk | ||||||||||||||||||||||||||||||
segments_to_stitch = [torch.arange(start, end + 1)] | ||||||||||||||||||||||||||||||
prev_end = end | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if not self.chunk: | ||||||||||||||||||||||||||||||
self.process_segments( | ||||||||||||||||||||||||||||||
i, segments_to_stitch, self.labels[i].video.shape[0] | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
# add last chunk after the loop | ||||||||||||||||||||||||||||||
if segments_to_stitch: | ||||||||||||||||||||||||||||||
self.process_segments(i, segments_to_stitch, self.clip_length) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if self.n_chunks > 0 and self.n_chunks <= 1.0: | ||||||||||||||||||||||||||||||
n_chunks = int(self.n_chunks * len(self.chunked_frame_idx)) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
elif self.n_chunks <= len(self.chunked_frame_idx): | ||||||||||||||||||||||||||||||
n_chunks = int(self.n_chunks) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
n_chunks = len(self.chunked_frame_idx) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx): | ||||||||||||||||||||||||||||||
sample_idx = np.random.choice( | ||||||||||||||||||||||||||||||
np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx] | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
self.label_idx = [self.label_idx[i] for i in sample_idx] | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# workaround for empty batch bug (needs to be changed). Check for batch with with only 1/10 size of clip length. Arbitrary thresholds | ||||||||||||||||||||||||||||||
remove_idx = [] | ||||||||||||||||||||||||||||||
for i, frame_chunk in enumerate(self.chunked_frame_idx): | ||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||
len(frame_chunk) | ||||||||||||||||||||||||||||||
<= min(int(self.clip_length / 10), 5) | ||||||||||||||||||||||||||||||
# and frame_chunk[-1] % self.clip_length == 0 | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||||||||||
f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
remove_idx.append(i) | ||||||||||||||||||||||||||||||
if len(remove_idx) > 0: | ||||||||||||||||||||||||||||||
for i in sorted(remove_idx, reverse=True): | ||||||||||||||||||||||||||||||
self.chunked_frame_idx.pop(i) | ||||||||||||||||||||||||||||||
self.label_idx.pop(i) | ||||||||||||||||||||||||||||||
Comment on lines
+162
to
+176
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Revisit “empty batch” removal threshold.
Comment on lines
+143
to
+176
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Extract duplicated chunk processing logic. The n_chunks calculation and small batch removal logic is duplicated between + def _calculate_n_chunks(self, total_chunks: int) -> int:
+ """Calculate the number of chunks to use based on self.n_chunks."""
+ if self.n_chunks > 0 and self.n_chunks <= 1.0:
+ return int(self.n_chunks * total_chunks)
+ elif self.n_chunks <= total_chunks:
+ return int(self.n_chunks)
+ return total_chunks
+
+ def _remove_small_chunks(self) -> None:
+ """Remove chunks that are too small to avoid empty batch issues."""
+ remove_idx = []
+ for i, frame_chunk in enumerate(self.chunked_frame_idx):
+ if len(frame_chunk) <= min(int(self.clip_length / 10), 5):
+ logger.warning(
+ f"Warning: Batch containing frames {frame_chunk} from video "
+ f"{self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} "
+ "frames. Removing to avoid empty batch possibility."
+ )
+ remove_idx.append(i)
+ for i in sorted(remove_idx, reverse=True):
+ self.chunked_frame_idx.pop(i)
+ self.label_idx.pop(i) Then use these helper methods in both Also applies to: 194-227 |
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def create_chunks_other(self) -> None: | ||||||||||||||||||||||||||||||
"""Get indexing for data. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Creates both indexes for selecting dataset (label_idx) and frame in | ||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -36,6 +36,7 @@ def __init__( | |||||||||||||||||||||
seed: int | None = None, | ||||||||||||||||||||||
verbose: bool = False, | ||||||||||||||||||||||
normalize_image: bool = True, | ||||||||||||||||||||||
max_batching_gap: int = 15, | ||||||||||||||||||||||
): | ||||||||||||||||||||||
"""Initialize SleapDataset. | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -73,6 +74,7 @@ def __init__( | |||||||||||||||||||||
seed: set a seed for reproducibility | ||||||||||||||||||||||
verbose: boolean representing whether to print | ||||||||||||||||||||||
normalize_image: whether to normalize the image to [0, 1] | ||||||||||||||||||||||
max_batching_gap: the max number of frames that can be unlabelled before starting a new batch | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||
slp_files, | ||||||||||||||||||||||
|
@@ -101,6 +103,7 @@ def __init__( | |||||||||||||||||||||
self.n_chunks = n_chunks | ||||||||||||||||||||||
self.seed = seed | ||||||||||||||||||||||
self.normalize_image = normalize_image | ||||||||||||||||||||||
self.max_batching_gap = max_batching_gap | ||||||||||||||||||||||
if self.data_dirs is None: | ||||||||||||||||||||||
self.data_dirs = [] | ||||||||||||||||||||||
if isinstance(anchors, int): | ||||||||||||||||||||||
|
@@ -136,20 +139,18 @@ def __init__( | |||||||||||||||||||||
|
||||||||||||||||||||||
# if self.seed is not None: | ||||||||||||||||||||||
# np.random.seed(self.seed) | ||||||||||||||||||||||
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files] | ||||||||||||||||||||||
|
||||||||||||||||||||||
# load_slp is a wrapper around sio.load_slp for frame gap checks | ||||||||||||||||||||||
self.labels = [] | ||||||||||||||||||||||
self.annotated_segments = {} | ||||||||||||||||||||||
for slp_file in self.slp_files: | ||||||||||||||||||||||
labels, annotated_segments = data_utils.load_slp(slp_file) | ||||||||||||||||||||||
self.labels.append(labels) | ||||||||||||||||||||||
self.annotated_segments[slp_file] = annotated_segments | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files] | ||||||||||||||||||||||
# do we need this? would need to update with sleap-io | ||||||||||||||||||||||
|
||||||||||||||||||||||
# for label in self.labels: | ||||||||||||||||||||||
# label.remove_empty_instances(keep_empty_frames=False) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# note if slp is missing frames, taking last frame idx is safer than len(labels) | ||||||||||||||||||||||
# as there will be fewer labeledframes than actual frames | ||||||||||||||||||||||
self.frame_idx = [ | ||||||||||||||||||||||
torch.arange(labels[-1].frame_idx + 1) for labels in self.labels | ||||||||||||||||||||||
] | ||||||||||||||||||||||
self.skipped_frame_ct = [0 for labels in self.labels] | ||||||||||||||||||||||
# self.frame_idx = [torch.arange(len(labels)) for labels in self.labels] | ||||||||||||||||||||||
# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be | ||||||||||||||||||||||
# used in call to get_instances() | ||||||||||||||||||||||
self.create_chunks() | ||||||||||||||||||||||
|
@@ -162,18 +163,20 @@ def get_indices(self, idx: int) -> tuple: | |||||||||||||||||||||
""" | ||||||||||||||||||||||
return self.label_idx[idx], self.chunked_frame_idx[idx] | ||||||||||||||||||||||
|
||||||||||||||||||||||
def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]: | ||||||||||||||||||||||
def get_instances( | ||||||||||||||||||||||
self, label_idx: list[int], frame_idx: torch.Tensor | ||||||||||||||||||||||
) -> list[Frame]: | ||||||||||||||||||||||
Comment on lines
+166
to
+168
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add input validation for frame_idx tensor. The parameter type change from def get_instances(
self, label_idx: list[int], frame_idx: torch.Tensor
) -> list[Frame]:
+ if not isinstance(frame_idx, torch.Tensor):
+ raise TypeError("frame_idx must be a torch.Tensor")
+ if frame_idx.dtype not in [torch.int32, torch.int64]:
+ raise TypeError("frame_idx must contain integer values") 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||
"""Get an element of the dataset. | ||||||||||||||||||||||
|
||||||||||||||||||||||
Args: | ||||||||||||||||||||||
label_idx: index of the labels | ||||||||||||||||||||||
frame_idx: index of the frames | ||||||||||||||||||||||
frame_idx: indices of the frames to load in to the batch | ||||||||||||||||||||||
|
||||||||||||||||||||||
Returns: | ||||||||||||||||||||||
A list of `dreem.io.Frame` objects containing metadata and instance data for the batch/clip. | ||||||||||||||||||||||
|
||||||||||||||||||||||
""" | ||||||||||||||||||||||
video = self.labels[label_idx] | ||||||||||||||||||||||
sleap_labels_obj = self.labels[label_idx] | ||||||||||||||||||||||
video_name = self.video_files[label_idx] | ||||||||||||||||||||||
|
||||||||||||||||||||||
# get the correct crop size based on the video | ||||||||||||||||||||||
|
@@ -189,9 +192,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram | |||||||||||||||||||||
|
||||||||||||||||||||||
vid_reader = self.videos[label_idx] | ||||||||||||||||||||||
|
||||||||||||||||||||||
# img = vid_reader.get_data(0) | ||||||||||||||||||||||
|
||||||||||||||||||||||
skeleton = video.skeletons[-1] | ||||||||||||||||||||||
skeleton = sleap_labels_obj.skeletons[-1] | ||||||||||||||||||||||
|
||||||||||||||||||||||
frames = [] | ||||||||||||||||||||||
for i, frame_ind in enumerate(frame_idx): | ||||||||||||||||||||||
|
@@ -206,18 +207,13 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram | |||||||||||||||||||||
|
||||||||||||||||||||||
frame_ind = int(frame_ind) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# if slp is missing instances in some frames, frame_ind will be smaller than lf.frame_idx | ||||||||||||||||||||||
lf = video[frame_ind - self.skipped_frame_ct[label_idx]] | ||||||||||||||||||||||
if frame_ind < lf.frame_idx: | ||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||
f"Frame index {frame_ind} is trying to access frame {lf.frame_idx} of the slp file {video_name}. " | ||||||||||||||||||||||
f"This likely means there are no labelled instances in this frame. Skipping frame." | ||||||||||||||||||||||
) | ||||||||||||||||||||||
self.skipped_frame_ct[label_idx] += 1 | ||||||||||||||||||||||
continue | ||||||||||||||||||||||
# sleap-io method for indexing a Labels() object based on the frame's index | ||||||||||||||||||||||
lf = sleap_labels_obj[(sleap_labels_obj.video, frame_ind)] | ||||||||||||||||||||||
if frame_ind != lf.frame_idx: | ||||||||||||||||||||||
logger.warning(f"Frame index mismatch: {frame_ind} != {lf.frame_idx}") | ||||||||||||||||||||||
|
||||||||||||||||||||||
try: | ||||||||||||||||||||||
img = vid_reader.get_data(int(lf.frame_idx)) | ||||||||||||||||||||||
img = vid_reader.get_data(int(frame_ind)) | ||||||||||||||||||||||
except IndexError as e: | ||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||
f"Could not read frame {frame_ind} from {video_name} due to {e}" | ||||||||||||||||||||||
|
@@ -253,7 +249,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram | |||||||||||||||||||||
continue | ||||||||||||||||||||||
|
||||||||||||||||||||||
if instance.track is not None: | ||||||||||||||||||||||
gt_track_id = video.tracks.index(instance.track) | ||||||||||||||||||||||
gt_track_id = sleap_labels_obj.tracks.index(instance.track) | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
gt_track_id = -1 | ||||||||||||||||||||||
gt_track_ids.append(gt_track_id) | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -205,13 +205,47 @@ def test_sleap_dataset(two_flies): | |
n_chunks=ds_length + 10000, | ||
) | ||
|
||
train_ds = SleapDataset( | ||
slp_files=[two_flies[0]], | ||
video_files=[two_flies[1]], | ||
data_dirs="./data/sleap", | ||
crop_size=128, | ||
chunk=True, | ||
max_batching_gap=10, | ||
clip_length=clip_length, | ||
n_chunks=30, | ||
) | ||
|
||
train_ds = SleapDataset( | ||
slp_files=[two_flies[0]], | ||
video_files=[two_flies[1]], | ||
data_dirs="./data/sleap", | ||
crop_size=128, | ||
chunk=True, | ||
max_batching_gap=0, | ||
clip_length=clip_length, | ||
n_chunks=30, | ||
) | ||
|
||
train_ds = SleapDataset( | ||
slp_files=[two_flies[0]], | ||
video_files=[two_flies[1]], | ||
data_dirs="./data/sleap", | ||
crop_size=128, | ||
chunk=False, | ||
max_batching_gap=10, | ||
clip_length=clip_length, | ||
n_chunks=30, | ||
) | ||
Comment on lines
+208
to
+239
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add assertions to verify non-contiguous segments behavior. The test cases for
Do you want me to generate the assertions to verify the behavior of non-contiguous segments? |
||
|
||
|
||
def test_icy_dataset(ten_icy_particles): | ||
"""Test icy dataset logic. | ||
|
||
Args: | ||
ten_icy_particles: icy fixture used for testing | ||
""" | ||
|
||
clip_length = 8 | ||
|
||
train_ds = MicroscopyDataset( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ensure robust handling for dynamic attributes.
The method references self.max_batching_gap, which appears to be defined only in subclasses (e.g., SleapDataset). If another subclass or a direct instance of BaseDataset does not initialize self.max_batching_gap, this will raise an AttributeError. Consider assigning a default value in BaseDataset to avoid runtime issues.