Skip to content

Commit

Permalink
Open video file only at init (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad authored Jul 23, 2024
1 parent 3690105 commit a2a544e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 12 deletions.
24 changes: 18 additions & 6 deletions dreem/datasets/microscopy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
seed,
)

self.videos = videos
self.vid_files = videos
self.tracks = tracks
self.chunk = chunk
self.clip_length = clip_length
Expand Down Expand Up @@ -92,13 +92,19 @@ def __init__(
parser(self.tracks[video_idx]) for video_idx in range(len(self.tracks))
]

self.videos = []
for vid_file in self.vid_files:
if not isinstance(vid_file, list):
self.videos.append(data_utils.LazyTiffStack(vid_file))
else:
self.videos.append([Image.open(frame_file) for frame_file in vid_file])
self.frame_idx = [
(
torch.arange(Image.open(video).n_frames)
if isinstance(video, str)
else torch.arange(len(video))
)
for video in self.videos
for video in self.vid_files
]

# Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
Expand Down Expand Up @@ -128,17 +134,14 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

video = self.videos[label_idx]

if not isinstance(video, list):
video = data_utils.LazyTiffStack(self.videos[label_idx])

frames = []
for frame_id in frame_idx:
instances, gt_track_ids, centroids = [], [], []

img = (
video.get_section(frame_id)
if not isinstance(video, list)
else np.array(Image.open(video[frame_id]))
else np.array(video[frame_id])
)

lf = labels[labels["FRAME"].astype(int) == frame_id.item()]
Expand Down Expand Up @@ -202,3 +205,12 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
)

return frames

def __del__(self):
"""Handle file closing before deletion."""
for vid_reader in self.videos:
if not isinstance(vid_reader, list):
vid_reader.close()
else:
for frame_reader in vid_reader:
frame_reader.close()
9 changes: 7 additions & 2 deletions dreem/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
# if self.seed is not None:
# np.random.seed(self.seed)
self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files]

self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files]
# do we need this? would need to update with sleap-io

# for label in self.labels:
Expand Down Expand Up @@ -140,7 +140,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

video_name = self.video_files[label_idx]

vid_reader = imageio.get_reader(video_name, "ffmpeg")
vid_reader = self.videos[label_idx]

img = vid_reader.get_data(0)

Expand Down Expand Up @@ -370,3 +370,8 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
frames.append(frame)

return frames

def __del__(self):
"""Handle file closing before garbage collection."""
for reader in self.videos:
reader.close()
2 changes: 1 addition & 1 deletion dreem/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_tracker_cfg(self) -> dict:
tracker_cfg[key] = val
return tracker_cfg

def get_gtr_runner(self, ckpt_path=None) -> "GTRRunner":
def get_gtr_runner(self, ckpt_path: str | None = None) -> "GTRRunner":
"""Get lightning module for training, validation, and inference.
Args:
Expand Down
6 changes: 4 additions & 2 deletions dreem/io/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import h5py
from numpy.typing import ArrayLike
from typing import Self
from typing import Self, Any

logger = logging.getLogger("dreem.io")

Expand Down Expand Up @@ -246,7 +246,9 @@ def to_slp(
)
raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}")

def to_h5(self, frame_group=h5py.Group, label=None, **kwargs: dict) -> h5py.Group:
def to_h5(
self, frame_group: h5py.Group, label: Any = None, **kwargs: dict
) -> h5py.Group:
"""Convert instance to an h5 group".
By default we always save:
Expand Down
2 changes: 1 addition & 1 deletion dreem/models/gtr_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
scheduler_cfg: dict | None = None,
metrics: dict[str, list[str]] | None = None,
persistent_tracking: dict[str, bool] | None = None,
test_save_path="./test_results.h5",
test_save_path: str = "./test_results.h5",
):
"""Initialize a lightning module for GTR.
Expand Down

0 comments on commit a2a544e

Please sign in to comment.