Skip to content

Commit

Permalink
Add multi-video support
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Sep 25, 2024
1 parent bf7fadf commit 62dc540
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 82 deletions.
17 changes: 11 additions & 6 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def save_nwb(
append: bool = True,
frame_inds: Optional[list[int]] = None,
frame_path: Optional[str] = None,
image_format: str = "png",
):
"""Save a SLEAP dataset to NWB format.
Expand All @@ -80,12 +81,14 @@ def save_nwb(
tracked predictions that is used for downstream analysis.
append: If `True` (the default), append to existing NWB file. File will be
created if it does not exist.
frame_inds: Optional list of frame indices to save when saving in training data
format. If `None`, all frames will be saved. No effect if `as_training` is
`False`.
frame_path: The path to save the extracted frame images to when saving in
training data format. If `None`, the path is the video filename without the
extension. No effect if `as_training` is `False`.
frame_inds: Optional list of labeled frame indices within the Labels to save
when saving in training data format. If `None`, all labeled frames in the
labels will be saved. No effect if `as_training` is `False`.
frame_path: The path to a folder to save the extracted frame images to when
saving in training data format. If `None`, the path is the NWB filename
without the extension. No effect if `as_training` is `False`.
image_format: The image format to use when saving extracted frame images.
Defaults to "png". No effect if `as_training` is `False`.
See also: nwb.write_nwb, nwb.append_nwb, nwb.append_nwb_training
"""
Expand All @@ -96,6 +99,7 @@ def save_nwb(
as_training=as_training,
frame_inds=frame_inds,
frame_path=frame_path,
image_format=image_format,
)
else:
nwb.write_nwb(
Expand All @@ -104,6 +108,7 @@ def save_nwb(
as_training=as_training,
frame_inds=frame_inds,
frame_path=frame_path,
image_format=image_format,
)


Expand Down
188 changes: 112 additions & 76 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,97 +57,125 @@
from sleap_io.io.utils import convert_predictions_to_dataframe


def write_video_to_path(
video: Video,
frame_inds: Optional[list[int]] = None,
def save_frame_images(
labels: Labels,
save_path: str | Path,
image_format: str = "png",
frame_path: Optional[str] = None,
) -> tuple[dict[int, str], Video, ImageSeries]:
"""Write individual frames of a video to a path.
frame_inds: Optional[list[int]] = None,
) -> dict[int, str]:
"""Save frames of a labels project to individual image files.
Args:
video: The video to write.
frame_inds: The indices of the frames to write. If None, all frames are written.
image_format: The format of the image to write. Default is .png
frame_path: The directory to save the frames to. If None, the path is the video
filename without the extension.
labels: A Labels object.
save_path: The directory to save the frames to.
image_format: The format of the image to write. Default is "png".
frame_inds: The indices of the labeled frames to write. If `None`, all labeled
frames in `labels` are written.
Returns:
A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
A dictionary mapping labeled frame indices to file paths.
"""
index_data = {}
if frame_inds is None:
frame_inds = list(range(video.backend.num_frames))
frame_inds = list(range(len(labels.labeled_frames)))

if isinstance(video.filename, list):
save_path = video.filename[0].split(".")[0]
else:
save_path = video.filename.split(".")[0]
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)

if frame_path is not None:
save_path = frame_path
lf_ind_to_img_path = {}
for lf_ind in frame_inds:
lf = labels.labeled_frames[lf_ind]

try:
os.makedirs(save_path, exist_ok=True)
except PermissionError:
filename_with_extension = video.filename.split("/")[-1]
filename = filename_with_extension.split(".")[0]
save_path = input("Permission denied. Enter a new path:") + "/" + filename
os.makedirs(save_path, exist_ok=True)
if lf.video.exists():
image = lf.image
else:
# TODO: Error handling for missing video files.
continue

video_ind = labels.videos.index(lf.video)
fidx = lf.frame_idx
img_path = save_path / f"video_{video_ind}.frame_{fidx}.{image_format}"

if "cv2" in sys.modules:
cv2.imwrite(img_path.as_posix(), image)
else:
iio.imwrite(img_path.as_posix(), image)

lf_ind_to_img_path[lf_ind] = img_path
return lf_ind_to_img_path

if "cv2" in sys.modules:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the video filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
cv2.imwrite(frame_path, frame)
else:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
iio.imwrite(frame_path, frame)

image_series = ImageSeries(
name="video",
external_file=os.listdir(save_path),
starting_frame=[0 for _ in range(len(os.listdir(save_path)))],
rate=30.0, # TODO - change to `video.backend.fps` when available
)
return index_data, video, image_series

def make_image_series(
labels: Labels, lf_ind_to_img_path: dict[int, str]
) -> tuple[dict[Video, ImageSeries], dict[int, int]]:
"""Create an NWB ImageSeries from a dictionary of labeled frame indices to image paths.
Args:
labels: A Labels object.
lf_ind_to_img_path: A dictionary mapping labeled frame indices to image paths.
Returns:
A tuple of:
video_image_series: A dictionary mapping SLEAP Videos to corresponding NWB
ImageSeries.
lf_ind_to_series_ind: A dictionary mapping labeled frame indices to indices
within the ImageSeries.
"""
# Get the video for each labeled frame.
lf_ind_to_video = {
lf_ind: lf.video for lf_ind, lf in enumerate(labels.labeled_frames)
}

# Group the labeled frames by video.
video_to_lf_inds = {}
for lf_ind, video in lf_ind_to_video.items():
if video not in video_to_lf_inds:
video_to_lf_inds[video] = []
video_to_lf_inds[video].append(lf_ind)

# Create an ImageSeries for each video.
video_image_series = {}
lf_ind_to_series_ind = {}
for video, lf_inds in video_to_lf_inds.items():
image_files = [lf_ind_to_img_path[lf_ind] for lf_ind in lf_inds]
video_ind = labels.videos.index(video)
image_series = ImageSeries(
name=f"video_{video_ind}",
external_file=image_files,
starting_frame=[0 for _ in range(len(image_files))],
# TODO: Include the frame index within the source video that the image came from.
# TODO: Include more metadata from the video backend.
rate=30.0, # TODO - change to `video.backend.fps` when available
)
video_image_series[video] = image_series
for series_ind, lf_ind in enumerate(lf_inds):
lf_ind_to_series_ind[lf_ind] = series_ind

return video_image_series, lf_ind_to_series_ind


def labels_to_pose_training(
labels: Labels,
skeletons_list: list[NWBSkeleton], # type: ignore[return]
video_info: tuple[dict[int, str], Video, ImageSeries],
video_image_series: dict[Video, ImageSeries],
lf_ind_to_series_ind: dict[int, int],
) -> PoseTraining: # type: ignore[return]
"""Creates an NWB PoseTraining object from a Labels object.
Args:
labels: A Labels object.
skeletons_list: A list of NWB skeletons.
video_info: A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
video_image_series: A dictionary mapping SLEAP Videos to corresponding NWB
ImageSeries.
lf_ind_to_series_ind: A dictionary mapping labeled frame indices to indices
within the ImageSeries.
Returns:
A PoseTraining object.
"""
training_frame_list = []
skeleton_instances_list = []
source_video_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
for lf_ind, labeled_frame in enumerate(labels.labeled_frames):
instance_counter = 0
for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
skeleton_instance = instance_to_skeleton_instance(
Expand All @@ -159,18 +187,16 @@ def labels_to_pose_training(
training_frame_skeleton_instances = SkeletonInstances(
skeleton_instances=skeleton_instances_list
)
training_frame_video_index = labeled_frame.frame_idx
source_video = video_image_series[labeled_frame.video]

image_series = video_info[2]
source_video = image_series
if source_video not in source_video_list:
source_video_list.append(source_video)
training_frame = TrainingFrame(
name=f"training_frame_{i}",
name=f"training_frame_{lf_ind}",
annotator="N/A",
skeleton_instances=training_frame_skeleton_instances,
source_video=source_video,
source_video_frame_index=training_frame_video_index,
source_video_frame_index=lf_ind_to_series_ind[lf_ind],
)
training_frame_list.append(training_frame)

Expand Down Expand Up @@ -399,12 +425,12 @@ def read_nwb(path: str) -> Labels:
frame_inds = np.searchsorted(
timestamps, get_timestamps(pose_estimation_series)
)
tracks_numpy[frame_inds, track_idx, node_idx, :] = (
pose_estimation_series.data[:]
)
confidence[frame_inds, track_idx, node_idx] = (
pose_estimation_series.confidence[:]
)
tracks_numpy[
frame_inds, track_idx, node_idx, :
] = pose_estimation_series.data[:]
confidence[
frame_inds, track_idx, node_idx
] = pose_estimation_series.confidence[:]

video_tracks[Path(pose_estimation.original_videos[0]).as_posix()] = (
tracks_numpy,
Expand Down Expand Up @@ -640,6 +666,7 @@ def append_nwb_training(
pose_estimation_metadata: Optional[dict] = None,
frame_inds: Optional[list[int]] = None,
frame_path: Optional[str] = None,
image_format: str = "png",
) -> NWBFile:
"""Append training data from a Labels object to an in-memory NWB file.
Expand All @@ -650,6 +677,7 @@ def append_nwb_training(
frame_inds: The indices of the frames to write. If None, all frames are written.
frame_path: The path to save the frames. If None, the path is the video
filename without the extension.
image_format: The format of the image to write. Default is "png".
Returns:
An in-memory NWB file with the PoseTraining data appended.
Expand Down Expand Up @@ -682,10 +710,18 @@ def append_nwb_training(
]
skeletons = Skeletons(skeletons=skeletons_list)
nwb_processing_module.add(skeletons)
video_info = write_video_to_path(
labels.videos[0], frame_inds, frame_path=frame_path
lf_ind_to_img_path = save_frame_images(
labels,
save_path=frame_path,
image_format=image_format,
frame_inds=frame_inds,
)
image_series_list, lf_ind_to_series_ind = make_image_series(
labels, lf_ind_to_img_path
)
pose_training = labels_to_pose_training(
labels, skeletons_list, image_series_list, lf_ind_to_series_ind
)
pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
nwb_processing_module.add(pose_training)

_ = nwbfile.create_device(
Expand Down

0 comments on commit 62dc540

Please sign in to comment.