diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 84ec6b8b..e0407346 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -38,7 +38,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.12"
+          python-version: "3.9"
 
       - name: Install dependencies
         run: |
@@ -54,15 +54,12 @@
 
   # Tests with pytest
   tests:
-    timeout-minutes: 15
+    timeout-minutes: 40
     strategy:
       fail-fast: false
       matrix:
         os: ["ubuntu-latest", "windows-latest", "macos-14"]
-        python: ["3.8", "3.12"]
-        exclude:
-          - os: "macos-14"
-            python: "3.8"
+        python: ["3.9"]
 
     name: Tests (${{ matrix.os }}, Python ${{ matrix.python }})
     runs-on: ${{ matrix.os }}
diff --git a/README.md b/README.md
index 83e2b781..39522b5f 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,11 @@ labels = sio.load_file("predictions.slp")
 sio.save_file(labels, "predictions.nwb")
 # Or:
 # labels.save("predictions.nwb")
+
+# Save to an NWB file and convert SLEAP training data to NWB training data:
+frame_inds = [i for i in range(20)]
+sio.save_file(labels, "predictions.nwb", as_training=True, frame_inds=frame_inds)
+# This will save the first 20 frames of the video as individual images
 ```
 
 ### Convert labels to raw arrays
diff --git a/pyproject.toml b/pyproject.toml
index 8292f6c5..c7d9cd1c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "attrs",
     "h5py>=3.8.0",
     "pynwb",
-    "ndx-pose",
+    "ndx-pose @ git+https://github.com/rly/ndx-pose@a847ad4be75e60ef9e413b8cbfc99c616fc9fd05",
     "pandas",
     "simplejson",
     "imageio",
diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py
index 7fd702f7..260760fa 100644
--- a/sleap_io/io/main.py
+++ b/sleap_io/io/main.py
@@ -6,6 +6,8 @@ from typing import Optional, Union
 from pathlib import Path
 
+from pynwb import NWBHDF5IO
+
 
 def load_slp(filename: str) -> Labels:
     """Load a SLEAP dataset.
 
@@ -59,21 +61,45 @@ def load_nwb(filename: str) -> Labels:
     return nwb.read_nwb(filename)
 
 
-def save_nwb(labels: Labels, filename: str, append: bool = True):
+def save_nwb(
+    labels: Labels,
+    filename: str,
+    as_training: bool = False,
+    append: bool = True,
+    frame_inds: Optional[list[int]] = None,
+    frame_path: Optional[str] = None,
+):
     """Save a SLEAP dataset to NWB format.
 
     Args:
         labels: A SLEAP `Labels` object (see `load_slp`).
         filename: Path to NWB file to save to. Must end in `.nwb`.
+        as_training: If `True`, save the dataset as a training dataset.
         append: If `True` (the default), append to existing NWB file. File will be
             created if it does not exist.
+        frame_inds: Optional list of frame indices to save. If None, all frames
+            will be saved.
+        frame_path: The path to save the frames. If None, the path is the video
+            filename without the extension.
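+
+    Example (a minimal sketch; assumes `labels` was loaded with `load_slp` and the
+    referenced video file is readable so frames can be exported):
+
+        from sleap_io.io.main import load_slp, save_nwb
+
+        labels = load_slp("predictions.slp")
+        # Default: save predictions only, appending if the file already exists.
+        save_nwb(labels, "predictions.nwb")
+        # Save as training data, exporting the first 20 frames as images.
+        save_nwb(labels, "labels.nwb", as_training=True, frame_inds=list(range(20)))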
 
-    See also: nwb.write_nwb, nwb.append_nwb
+    See also: nwb.write_nwb, nwb.append_nwb, nwb.append_nwb_training
     """
     if append and Path(filename).exists():
-        nwb.append_nwb(labels, filename)
+        nwb.append_nwb(
+            labels,
+            filename,
+            as_training=as_training,
+            frame_inds=frame_inds,
+            frame_path=frame_path,
+        )
     else:
-        nwb.write_nwb(labels, filename)
+        nwb.write_nwb(
+            labels,
+            filename,
+            as_training=as_training,
+            frame_inds=frame_inds,
+            frame_path=frame_path,
+        )
 
 
 def load_labelstudio(
@@ -190,6 +216,8 @@ def load_file(
         return load_jabs(filename, **kwargs)
     elif format == "video":
         return load_video(filename, **kwargs)
+    else:
+        raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")
 
 
 def save_file(
@@ -219,8 +247,10 @@ def save_file(
     if format == "slp":
         save_slp(labels, filename, **kwargs)
-    elif format == "nwb":
-        save_nwb(labels, filename, **kwargs)
+    elif format in ("nwb", "nwb_predictions"):
+        save_nwb(labels, filename, False)
+    elif format == "nwb_training":
+        save_nwb(labels, filename, True, frame_inds=kwargs.pop("frame_inds", None))
+
     elif format == "labelstudio":
         save_labelstudio(labels, filename, **kwargs)
     elif format == "jabs":
diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py
index 9314294a..4e922a52 100644
--- a/sleap_io/io/nwb.py
+++ b/sleap_io/io/nwb.py
@@ -1,34 +1,321 @@
 """Functions to write and read from the neurodata without borders (NWB) format."""
+from __future__ import annotations
 from copy import deepcopy
 from typing import List, Optional, Union
 from pathlib import Path
 import datetime
 import uuid
 import re
+import sys
+import os
+import imageio.v3 as iio
 
-import pandas as pd  # type: ignore[import]
+import pandas as pd
 import numpy as np
 
+try:
+    import cv2
+except ImportError:
+    pass
+
 try:
     from numpy.typing import ArrayLike
 except ImportError:
     ArrayLike = np.ndarray
 
-from pynwb import NWBFile, NWBHDF5IO, ProcessingModule  # type: ignore[import]
-from ndx_pose import PoseEstimationSeries, PoseEstimation  # type: ignore[import]
+
+from hdmf.utils import LabelledDict
+from hdmf.build.errors import OrphanContainerBuildError
+
+from pynwb import NWBFile, NWBHDF5IO, ProcessingModule
+from pynwb.file import Subject
+from pynwb.image import ImageSeries
+
+from ndx_pose import (
+    PoseEstimationSeries,
+    PoseEstimation,
+    Skeleton as NWBSkeleton,
+    Skeletons,
+    SkeletonInstance,
+    SkeletonInstances,
+    TrainingFrame,
+    TrainingFrames,
+    PoseTraining,
+    SourceVideos,
+)
 
 from sleap_io import (
     Labels,
     Video,
     LabeledFrame,
     Track,
-    Skeleton,
+    Skeleton as SLEAPSkeleton,
     Instance,
     PredictedInstance,
+    Edge,
+    Node,
 )
 from sleap_io.io.utils import convert_predictions_to_dataframe
 
 
+def pose_training_to_labels(pose_training: PoseTraining) -> Labels:  # type: ignore[return]
+    """Creates a Labels object from an NWB PoseTraining object.
+
+    Args:
+        pose_training: An NWB PoseTraining object.
+
+    Returns:
+        A Labels object.
+ """ + labeled_frames = [] + skeletons = {} + training_frames = pose_training.training_frames.training_frames.values() + for training_frame in training_frames: + source_video = training_frame.source_video + if source_video.format == "external" and len(source_video.external_file) == 1: + video = Video(source_video.external_file[0]) + else: + raise NotImplementedError("Only single-file external videos are supported.") + + frame_idx = training_frame.source_video_frame_index + instances = [] + for instance in training_frame.skeleton_instances.skeleton_instances.values(): + if instance.skeleton.name not in skeletons: + skeletons[instance.skeleton.name] = nwb_skeleton_to_sleap( + instance.skeleton + ) + skeleton = skeletons[instance.skeleton.name] + instances.append( + Instance.from_numpy( + points=instance.node_locations[:], skeleton=skeleton + ) + ) # `track` field is not stored in `SkeletonInstance` objects + labeled_frames.append( + LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + ) + return Labels(labeled_frames=labeled_frames) + + +def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: ignore[return] + """Converts an NWB skeleton to a SLEAP skeleton. + + Args: + skeleton: An NWB skeleton. + + Returns: + A SLEAP skeleton. + """ + nodes = [Node(name=node) for node in skeleton.nodes] + edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges] + return SLEAPSkeleton( + nodes=nodes, + edges=edges, + name=skeleton.name, + ) + + +def labels_to_pose_training( + labels: Labels, + skeletons_list: list[NWBSkeleton], # type: ignore[return] + video_info: tuple[dict[int, str], Video, ImageSeries], +) -> PoseTraining: # type: ignore[return] + """Creates an NWB PoseTraining object from a Labels object. + + Args: + labels: A Labels object. + skeletons_list: A list of NWB skeletons. + video_info: A tuple containing a dictionary mapping frame indices to file paths, + the video, and the `ImageSeries`. + + Returns: + A PoseTraining object. + """ + training_frame_list = [] + skeleton_instances_list = [] + source_video_list = [] + for i, labeled_frame in enumerate(labels.labeled_frames): + for instance, skeleton in zip(labeled_frame.instances, skeletons_list): + skeleton_instance = instance_to_skeleton_instance(instance, skeleton) + skeleton_instances_list.append(skeleton_instance) + + training_frame_skeleton_instances = SkeletonInstances( + skeleton_instances=skeleton_instances_list + ) + training_frame_video_index = labeled_frame.frame_idx + + image_series = video_info[2] + source_video = image_series + if source_video not in source_video_list: + source_video_list.append(source_video) + training_frame = TrainingFrame( + name=f"training_frame_{i}", + annotator="N/A", + skeleton_instances=training_frame_skeleton_instances, + source_video=source_video, + source_video_frame_index=training_frame_video_index, + ) + training_frame_list.append(training_frame) + + training_frames = TrainingFrames(training_frames=training_frame_list) + source_videos = SourceVideos(image_series=source_video_list) + pose_training = PoseTraining( + training_frames=training_frames, + source_videos=source_videos, + ) + return pose_training + + +def slp_skeleton_to_nwb( + skeleton: SLEAPSkeleton, subject: Optional[Subject] = None +) -> NWBSkeleton: # type: ignore[return] + """Converts SLEAP skeleton to NWB skeleton. + + Args: + skeleton: A SLEAP skeleton. + subject: An NWB subject. + + Returns: + An NWB skeleton. 
+ """ + if subject is None: + subject = Subject(species="No specified species", subject_id="No specified id") + nwb_edges = [] + skeleton_edges = dict(enumerate(skeleton.nodes)) + for i, source in skeleton_edges.items(): + for destination in list(skeleton_edges.values())[i:]: + if Edge(source, destination) in skeleton.edges: + nwb_edges.append([i, list(skeleton_edges.values()).index(destination)]) + + return NWBSkeleton( + name=skeleton.name, + nodes=skeleton.node_names, + edges=np.array(nwb_edges, dtype=np.uint8), + subject=subject, + ) + + +def instance_to_skeleton_instance( + instance: Instance, skeleton: NWBSkeleton # type: ignore[return] +) -> SkeletonInstance: # type: ignore[return] + """Converts a SLEAP Instance to an NWB SkeletonInstance. + + Args: + instance: A SLEAP Instance. + skeleton: An NWB Skeleton. + + Returns: + An NWB SkeletonInstance. + """ + points_list = list(instance.points.values()) + node_locs = [[point.x, point.y] for point in points_list] + np_node_locations = np.array(node_locs) + return SkeletonInstance( + name=f"skeleton_instance_{id(instance)}", + id=np.uint64(id(instance)), + # TODO add a counter in the loop to track the number of instances + # instead of using id + node_locations=np_node_locations, + node_visibility=[point.visible for point in instance.points.values()], + skeleton=skeleton, + ) + + +def videos_to_source_videos(videos: list[Video]) -> SourceVideos: # type: ignore[return] + """Converts a list of SLEAP Videos to NWB SourceVideos. + + Args: + videos: A list of SLEAP Videos. + + Returns: + An NWB SourceVideos object. + """ + source_videos = [] + for i, video in enumerate(videos): + image_series = ImageSeries( + name=f"video_{i}", + description="N/A", + unit="NA", + format="external", + external_file=[video.filename], + dimension=[video.backend.img_shape[0], video.backend.img_shape[1]], + starting_frame=[0], + rate=30.0, # TODO - change to `video.backend.fps` when available + ) + source_videos.append(image_series) + return SourceVideos(image_series=source_videos) + + +def write_video_to_path( + video: Video, + frame_inds: Optional[list[int]] = None, + image_format: str = "png", + frame_path: Optional[str] = None, +) -> tuple[dict[int, str], Video, ImageSeries]: + """Write individual frames of a video to a path. + + Args: + video: The video to write. + frame_inds: The indices of the frames to write. If None, all frames are written. + image_format: The format of the image to write. Default is .png + frame_path: The directory to save the frames to. If None, the path is the video + filename without the extension. + + Returns: + A tuple containing a dictionary mapping frame indices to file paths, + the video, and the `ImageSeries`. + """ + index_data = {} + if frame_inds is None: + frame_inds = list(range(video.backend.num_frames)) + + if isinstance(video.filename, list): + save_path = video.filename[0].split(".")[0] + else: + save_path = video.filename.split(".")[0] + + if frame_path is not None: + save_path = frame_path + + try: + os.makedirs(save_path, exist_ok=True) + except PermissionError: + filename_with_extension = video.filename.split("/")[-1] + filename = filename_with_extension.split(".")[0] + save_path = input("Permission denied. Enter a new path:") + "/" + filename + os.makedirs(save_path, exist_ok=True) + + if "cv2" in sys.modules: + for frame_idx in frame_inds: + try: + frame = video[frame_idx] + except FileNotFoundError: + video_filename = input("Video not found. 
Enter the video filename:") + video = Video.from_filename(video_filename) + frame = video[frame_idx] + frame_path = f"{save_path}/frame_{frame_idx}.{image_format}" + index_data[frame_idx] = frame_path + cv2.imwrite(frame_path, frame) + else: + for frame_idx in frame_inds: + try: + frame = video[frame_idx] + except FileNotFoundError: + video_filename = input("Video not found. Enter the filename:") + video = Video.from_filename(video_filename) + frame = video[frame_idx] + frame_path = f"{save_path}/frame_{frame_idx}.{image_format}" + index_data[frame_idx] = frame_path + iio.imwrite(frame_path, frame) + + image_series = ImageSeries( + name="video", + external_file=os.listdir(save_path), + starting_frame=[0 for _ in range(len(os.listdir(save_path)))], + rate=30.0, # TODO - change to `video.backend.fps` when available + ) + return index_data, video, image_series + + def get_timestamps(series: PoseEstimationSeries) -> np.ndarray: """Return a vector of timestamps for a `PoseEstimationSeries`.""" if series.timestamps is not None: @@ -59,7 +346,10 @@ def read_nwb(path: str) -> Labels: track_keys: List[str] = list(test_processing_module.fields["data_interfaces"]) # Get track - test_pose_estimation: PoseEstimation = test_processing_module[track_keys[0]] + for key in track_keys: + if isinstance(test_processing_module[key], PoseEstimation): + test_pose_estimation = test_processing_module[key] + break node_names = test_pose_estimation.nodes[:] edge_inds = test_pose_estimation.edges[:] @@ -72,7 +362,14 @@ def read_nwb(path: str) -> Labels: timestamps = np.empty(()) for track_key in _track_keys: for node_name in node_names: - pose_estimation_series = processing_module[track_key][node_name] + try: + pose_estimation_series = processing_module[track_key][node_name] + except TypeError: + continue + except KeyError: + pose_estimation_series = processing_module[ + "track=untracked" + ].pose_estimation_series[node_name] timestamps = np.union1d( timestamps, get_timestamps(pose_estimation_series) ) @@ -85,10 +382,20 @@ def read_nwb(path: str) -> Labels: tracks_numpy = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, np.float32) confidence = np.full((n_frames, n_tracks, n_nodes), np.nan, np.float32) for track_idx, track_key in enumerate(_track_keys): - pose_estimation = processing_module[track_key] + try: + pose_estimation = processing_module[track_key] + if not isinstance(pose_estimation, PoseEstimation): + raise KeyError + except KeyError: + pose_estimation = processing_module["track=untracked"] for node_idx, node_name in enumerate(node_names): - pose_estimation_series = pose_estimation[node_name] + try: + pose_estimation_series = pose_estimation[node_name] + except KeyError: + pose_estimation_series = pose_estimation.pose_estimation_series[ + node_name + ] frame_inds = np.searchsorted( timestamps, get_timestamps(pose_estimation_series) ) @@ -106,7 +413,7 @@ def read_nwb(path: str) -> Labels: ) # Create skeleton - skeleton = Skeleton( + skeleton = SLEAPSkeleton( nodes=node_names, edges=edge_inds, ) @@ -126,15 +433,24 @@ def read_nwb(path: str) -> Labels: ): if np.isnan(inst_pts).all(): continue - insts.append( - PredictedInstance.from_numpy( - points=inst_pts, # (n_nodes, 2) - point_scores=inst_confs, # (n_nodes,) - instance_score=inst_confs.mean(), # () - skeleton=skeleton, - track=track if is_tracked else None, + try: + insts.append( + Instance.from_numpy( + points=inst_pts, # (n_nodes, 2) + point_scores=inst_confs, # (n_nodes,) + instance_score=inst_confs.mean(), # () + skeleton=skeleton, + track=track if 
is_tracked else None, + ) + ) + except TypeError: + insts.append( + Instance.from_numpy( + points=inst_pts, + skeleton=skeleton, + track=track if is_tracked else None, + ) ) - ) if len(insts) > 0: lfs.append( LabeledFrame(video=video, frame_idx=frame_idx, instances=insts) @@ -149,6 +465,9 @@ def write_nwb( nwbfile_path: str, nwb_file_kwargs: Optional[dict] = None, pose_estimation_metadata: Optional[dict] = None, + as_training: bool = False, + frame_inds: Optional[list[int]] = None, + frame_path: Optional[str] = None, ): """Write labels to an nwb file and save it to the nwbfile_path given. @@ -174,11 +493,16 @@ def write_nwb( or the sampling rate with key`video_sample_rate`. e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps) - or pose_estimation_metadata["video_sample_rate] = 15 # In Hz + or pose_estimation_metadata["video_sample_rate"] = 15 # In Hz 2) The other use of this dictionary is to ovewrite sleap-io default arguments for the PoseEstimation container. see https://github.com/rly/ndx-pose for a full list or arguments. + + as_training: If `True`, append the data as training data. + frame_inds: The indices of the frames to write. If None, all frames are written. + frame_path: The path to save the frames. If None, the path is the video + filename without the extension. """ nwb_file_kwargs = nwb_file_kwargs or dict() @@ -198,10 +522,36 @@ def write_nwb( ) nwbfile = NWBFile(**nwb_file_kwargs) - nwbfile = append_nwb_data(labels, nwbfile, pose_estimation_metadata) + if as_training: + nwbfile = append_nwb_training( + labels, + nwbfile, + pose_estimation_metadata, + frame_inds, + frame_path=frame_path, + ) + else: + nwbfile = append_nwb_data(labels, nwbfile, pose_estimation_metadata) with NWBHDF5IO(str(nwbfile_path), "w") as io: - io.write(nwbfile) + try: + io.write(nwbfile) + except OrphanContainerBuildError: + processing_module = nwbfile.processing[ + f"SLEAP_VIDEO_000_{Path(labels.videos[0].filename).stem}" + ] + try: + pose_estimation = processing_module["track=untracked"] + skeletons = [pose_estimation.skeleton] + except KeyError: + skeletons = [] + for i in range(len(labels.tracks)): + pose_estimation = processing_module[f"track=track_{i}"] + skeleton = pose_estimation.skeleton + skeletons.append(skeleton) if skeleton.parent is None else ... 
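+            # The skeletons collected above are not yet owned by any container,
+            # which is what raised OrphanContainerBuildError in the first place.
+            # Group them in a `Skeletons` container, attach it to the processing
+            # module, and retry the write.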
+            skeletons = Skeletons(skeletons=skeletons)
+            processing_module.add(skeletons)
+            io.write(nwbfile)
 
 
 def append_nwb_data(
@@ -239,6 +589,7 @@
     default_metadata["source_software_version"] = sleap_version
 
     labels_data_df = convert_predictions_to_dataframe(labels)
+    cameras = []
 
     # For every video create a processing module
     for video_index, video in enumerate(labels.videos):
@@ -248,9 +599,16 @@
             processing_module_name, nwbfile
         )
 
+        camera = nwbfile.create_device(
+            name=f"camera {video_index}",
+            description=f"Camera used to record video {video_index}",
+            manufacturer="No specified manufacturer",
+        )
+        cameras.append(camera)
+
         # Propagate video metadata
-        default_metadata["original_videos"] = [f"{video.filename}"]  # type: ignore
-        default_metadata["labeled_videos"] = [f"{video.filename}"]  # type: ignore
+        default_metadata["original_videos"] = [f"{video.filename}"]
+        default_metadata["labeled_videos"] = [f"{video.filename}"]
 
         # Overwrite default with the user provided metadata
         default_metadata.update(pose_estimation_metadata)
@@ -269,14 +627,123 @@
                 track_name,
                 video,
                 default_metadata,
+                nwbfile,
             )
             nwb_processing_module.add(pose_estimation_container)
 
     return nwbfile
 
 
+def append_nwb_training(
+    labels: Labels,
+    nwbfile: NWBFile,
+    pose_estimation_metadata: Optional[dict] = None,
+    frame_inds: Optional[list[int]] = None,
+    frame_path: Optional[str] = None,
+) -> NWBFile:
+    """Append training data from a Labels object to an in-memory NWB file.
+
+    Args:
+        labels: A general labels object.
+        nwbfile: An in-memory NWB file.
+        pose_estimation_metadata: Metadata for pose estimation.
+        frame_inds: The indices of the frames to write. If None, all frames are written.
+        frame_path: The path to save the frames. If None, the path is the video
+            filename without the extension.
+
+    Returns:
+        An in-memory NWB file with the PoseTraining data appended.
+    """
+    pose_estimation_metadata = pose_estimation_metadata or dict()
+    provenance = labels.provenance
+    default_metadata = dict(scorer=str(provenance))
+    sleap_version = provenance.get("sleap_version", None)
+    default_metadata["source_software_version"] = sleap_version
+
+    subject = Subject(subject_id="No specified id", species="No specified species")
+    nwbfile.subject = subject
+
+    for i, video in enumerate(labels.videos):
+        video_path = (
+            Path(video.filename)
+            if isinstance(video.filename, str)
+            else video.filename[i]
+        )
+        processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
+        nwb_processing_module = get_processing_module_for_video(
+            processing_module_name, nwbfile
+        )
+        default_metadata["original_videos"] = [f"{video.filename}"]
+        default_metadata["labeled_videos"] = [f"{video.filename}"]
+        default_metadata.update(pose_estimation_metadata)
+
+        skeletons_list = [
+            slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels.skeletons
+        ]
+        skeletons = Skeletons(skeletons=skeletons_list)
+        nwb_processing_module.add(skeletons)
+        video_info = write_video_to_path(
+            labels.videos[0], frame_inds, frame_path=frame_path
+        )
+        pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
+        nwb_processing_module.add(pose_training)
+
+        confidence_definition = "Softmax output of the deep neural network"
+        reference_frame = (
+            "The coordinates are in (x, y) relative to the top-left of the image. "
+            "Coordinates refer to the midpoint of the pixel. "
+            "That is, the midpoint of the top-left pixel is at (0, 0), whereas "
+            "the top-left corner of that same pixel is at (-0.5, -0.5)."
+        )
+        pose_estimation_series_list = []
+        for node in skeletons_list[0].nodes:
+            pose_estimation_series = PoseEstimationSeries(
+                name=node,
+                description=f"Marker placed on {node}",
+                data=np.random.rand(100, 2),
+                unit="pixels",
+                reference_frame=reference_frame,
+                timestamps=np.linspace(0, 10, num=100),
+                confidence=np.random.rand(100),
+                confidence_definition=confidence_definition,
+            )
+            pose_estimation_series_list.append(pose_estimation_series)
+
+        camera = nwbfile.create_device(
+            name=f"camera {i}",
+            description=f"Camera used to record video {i}",
+            manufacturer="No specified manufacturer",
+        )
+        try:
+            dimensions = np.array([[video.backend.shape[1], video.backend.shape[2]]])
+        except AttributeError:
+            dimensions = np.array([[400, 400]])
+
+        pose_estimation = PoseEstimation(
+            name="pose_estimation",
+            pose_estimation_series=pose_estimation_series_list,
+            description="Estimated positions of the nodes in the video",
+            original_videos=[video.filename for video in labels.videos],
+            labeled_videos=[video.filename for video in labels.videos],
+            dimensions=dimensions,
+            devices=[camera],
+            scorer="No specified scorer",
+            source_software="SLEAP",
+            source_software_version=sleap_version,
+            skeleton=skeletons_list[0],
+        )
+        nwb_processing_module.add(pose_estimation)
+
+    return nwbfile
+
+
 def append_nwb(
-    labels: Labels, filename: str, pose_estimation_metadata: Optional[dict] = None
+    labels: Labels,
+    filename: str,
+    pose_estimation_metadata: Optional[dict] = None,
+    frame_inds: Optional[list[int]] = None,
+    frame_path: Optional[str] = None,
+    as_training: Optional[bool] = None,
 ):
     """Append a SLEAP `Labels` object to an existing NWB data file.
 
@@ -285,14 +752,21 @@ def append_nwb(
         filename: The path to the NWB file.
         pose_estimation_metadata: Metadata for pose estimation. See `append_nwb_data`
             for details.
+        as_training: If `True`, append the data as training data.
+        frame_inds: The indices of the frames to write. If None, all frames are written.
+        frame_path: The path to save the frames. If None, the path is the video
+            filename without the extension.
 
     See also: append_nwb_data
     """
     with NWBHDF5IO(filename, mode="a", load_namespaces=True) as io:
         nwb_file = io.read()
-        nwb_file = append_nwb_data(
-            labels, nwb_file, pose_estimation_metadata=pose_estimation_metadata
-        )
+        if as_training:
+            nwb_file = append_nwb_training(
+                labels, nwb_file, pose_estimation_metadata, frame_inds, frame_path
+            )
+        else:
+            nwb_file = append_nwb_data(labels, nwb_file, pose_estimation_metadata)
         io.write(nwb_file)
 
 
@@ -327,6 +801,7 @@ def build_pose_estimation_container_for_track(
     track_name: str,
     video: Video,
    pose_estimation_metadata: dict,
+    nwbfile: NWBFile,
 ) -> PoseEstimation:
     """Create a PoseEstimation container for a track.
 
@@ -336,6 +811,8 @@ def build_pose_estimation_container_for_track(
         labels (Labels): A general labels object
         track_name (str): The name of the track in labels.tracks
         video (Video): The video to which data belongs to
+        pose_estimation_metadata (dict): Metadata for the pose estimation.
+        nwbfile (NWBFile): The NWB file whose first device is linked to the container.
 
     Returns:
         PoseEstimation: A PoseEstimation multicontainer where the time series
@@ -385,6 +862,7 @@ def build_pose_estimation_container_for_track(
         nodes=skeleton.node_names,
         edges=np.array(skeleton.edge_inds).astype("uint64"),
         source_software="SLEAP",
+        devices=[list(nwbfile.devices.values())[0]],
         # dimensions=np.array([[video.backend.height, video.backend.width]]),
     )
 
@@ -395,7 +873,7 @@ def build_pose_estimation_container_for_track(
 
 
 def build_track_pose_estimation_list(
-    track_data_df: pd.DataFrame, timestamps: ArrayLike
+    track_data_df: pd.DataFrame, timestamps: ArrayLike  # type: ignore[return]
 ) -> List[PoseEstimationSeries]:
     """Build a list of PoseEstimationSeries from tracks.
 
diff --git a/tests/conftest.py b/tests/conftest.py
index a78a8793..7d1bbf3b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,3 +5,4 @@
 from tests.fixtures.labelstudio import *
 from tests.fixtures.videos import *
 from tests.fixtures.jabs import *
+from tests.fixtures.nwb import *
diff --git a/tests/data/nwb/labels.v002.nwb b/tests/data/nwb/labels.v002.nwb
new file mode 100644
index 00000000..72d1c980
Binary files /dev/null and b/tests/data/nwb/labels.v002.nwb differ
diff --git a/tests/data/nwb/minimal_instance.nwb b/tests/data/nwb/minimal_instance.nwb
new file mode 100644
index 00000000..f9af3e29
Binary files /dev/null and b/tests/data/nwb/minimal_instance.nwb differ
diff --git a/tests/data/nwb/minimal_instance.pkg.nwb b/tests/data/nwb/minimal_instance.pkg.nwb
new file mode 100644
index 00000000..d997bc4f
Binary files /dev/null and b/tests/data/nwb/minimal_instance.pkg.nwb differ
diff --git a/tests/data/nwb/typical.nwb b/tests/data/nwb/typical.nwb
new file mode 100644
index 00000000..00c6a28d
Binary files /dev/null and b/tests/data/nwb/typical.nwb differ
diff --git a/tests/fixtures/nwb.py b/tests/fixtures/nwb.py
new file mode 100644
index 00000000..1ae90b34
--- /dev/null
+++ b/tests/fixtures/nwb.py
@@ -0,0 +1,27 @@
+"""Fixtures that return paths to `.nwb` files."""
+
+import pytest
+
+
+@pytest.fixture
+def minimal_instance_nwb():
+    """NWB file with a single instance."""
+    return "tests/data/nwb/minimal_instance.nwb"
+
+
+@pytest.fixture
+def minimal_instance_pkg_nwb():
+    """NWB .pkg file with a single instance."""
+    return "tests/data/nwb/minimal_instance.pkg.nwb"
+
+
+@pytest.fixture
+def labels_v002_nwb():
+    """NWB file with labels saved as a dataset."""
+    return "tests/data/nwb/labels.v002.nwb"
+
+
+@pytest.fixture
+def typical_nwb():
+    """Typical NWB file."""
+    return "tests/data/nwb/typical.nwb"
diff --git a/tests/io/test_main.py b/tests/io/test_main.py
index 882c3295..0d7fb325 100644
--- a/tests/io/test_main.py
+++ b/tests/io/test_main.py
@@ -22,21 +22,18 @@ def test_load_slp(slp_typical):
     assert type(load_file(slp_typical)) == Labels
 
 
-def test_nwb(tmp_path, slp_typical):
+def test_nwb(tmp_path, slp_typical, slp_predictions_with_provenance):
     labels = load_slp(slp_typical)
-    save_nwb(labels, tmp_path / "test_nwb.nwb")
+    save_nwb(labels, tmp_path / "test_nwb.nwb", False)
     loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
     assert type(loaded_labels) == Labels
     assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels
     assert len(loaded_labels) == len(labels)
 
-    labels2 = load_slp(slp_typical)
-    labels2.videos[0].filename = "test"
-    save_nwb(labels2, tmp_path / "test_nwb.nwb", append=True)
-    loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
-    assert type(loaded_labels) == Labels
-    assert len(loaded_labels) == (len(labels) + len(labels2))
-    assert len(loaded_labels.videos) == 2
+
+def test_nwb_training(tmp_path, slp_typical):
+    labels = load_slp(slp_typical)
+    save_nwb(labels, tmp_path / "test_nwb.nwb", True)
 
 
 def test_labelstudio(tmp_path, slp_typical):
diff --git a/tests/io/test_nwb.py b/tests/io/test_nwb.py
index 5963dd92..7e553ff2 100644
--- a/tests/io/test_nwb.py
+++ b/tests/io/test_nwb.py
@@ -4,9 +4,19 @@
 import numpy as np
 
 from pynwb import NWBFile, NWBHDF5IO
+from pynwb.file import Subject
 
 from sleap_io import load_slp
-from sleap_io.io.nwb import write_nwb, append_nwb_data, get_timestamps
+from sleap_io.io.nwb import (
+    Video,
+    write_video_to_path,
+    labels_to_pose_training,
+    pose_training_to_labels,
+    slp_skeleton_to_nwb,
+    append_nwb_data,
+    write_nwb,
+    get_timestamps,
+)
 
 
 @pytest.fixture
@@ -24,6 +34,39 @@ def nwbfile():
     return nwbfile
 
 
+def test_video_to_path():
+    video = Video(filename="tests/data/videos/centered_pair_low_quality.mp4")
+    video_info = write_video_to_path(video, frame_inds=[i for i in range(30, 50)])
+    index_data, _, _ = video_info
+    assert list(index_data.keys()) == [i for i in range(30, 50)]
+    image_0_name = list(index_data.values())[0]
+    assert image_0_name == "tests/data/videos/centered_pair_low_quality/frame_30.png"
+
+
+def test_slp_to_nwb_conversion():
+    labels_original = load_slp("tests/data/slp/minimal_instance.pkg.slp")
+    subject = Subject(subject_id="test_subject", species="test_species")
+    nwb_skeletons = [
+        slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels_original.skeletons
+    ]
+    video_info = write_video_to_path(labels_original.video)
+    pose_training = labels_to_pose_training(labels_original, nwb_skeletons, video_info)
+    labels_converted = pose_training_to_labels(pose_training)
+    assert len(labels_original.labeled_frames) == len(labels_converted.labeled_frames)
+
+    original_instance = labels_original.labeled_frames[0].instances[0]
+    converted_instance = labels_converted.labeled_frames[0].instances[0]
+    assert np.array_equal(original_instance.numpy(), converted_instance.numpy())
+
+    slp_skeleton = labels_original.skeletons[0]
+    nwb_skeleton = slp_skeleton_to_nwb(slp_skeleton, subject)
+    assert len(nwb_skeleton.nodes) == len(slp_skeleton.nodes)
+    assert len(nwb_skeleton.edges) == len(slp_skeleton.edges)
+
+    training_frames_len = len(pose_training.training_frames.training_frames)
+    assert training_frames_len == len(labels_original.labeled_frames)
+
+
 def test_typical_case_append(nwbfile, slp_typical):
     labels = load_slp(slp_typical)
     nwbfile = append_nwb_data(labels, nwbfile)
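End-to-end, the new `nwb_training` format in `save_file` dispatches to `save_nwb(..., as_training=True)`, which exports the selected frames as individual images and writes `Skeletons` and `PoseTraining` groups. A rough usage sketch (file names and frame indices are placeholders):

```python
import sleap_io as sio

labels = sio.load_file("predictions.slp")

# Export the first 20 frames as images and save the labels as NWB training data.
sio.save_file(labels, "labels.nwb", format="nwb_training", frame_inds=list(range(20)))

# The "nwb" (or "nwb_predictions") format still writes predictions only.
sio.save_file(labels, "predictions.nwb", format="nwb")
```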