diff --git a/environment.yml b/environment.yml index d883e016..657319f4 100644 --- a/environment.yml +++ b/environment.yml @@ -24,7 +24,7 @@ dependencies: - pydocstyle - toml - twine - - build + - python-build - pip - pip: - "--editable=.[dev]" diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py index 4e922a52..68ad466f 100644 --- a/sleap_io/io/nwb.py +++ b/sleap_io/io/nwb.py @@ -24,7 +24,6 @@ except ImportError: ArrayLike = np.ndarray -from hdmf.utils import LabelledDict from hdmf.build.errors import OrphanContainerBuildError from pynwb import NWBFile, NWBHDF5IO, ProcessingModule @@ -72,10 +71,7 @@ def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ign training_frames = pose_training.training_frames.training_frames.values() for training_frame in training_frames: source_video = training_frame.source_video - if source_video.format == "external" and len(source_video.external_file) == 1: - video = Video(source_video.external_file[0]) - else: - raise NotImplementedError("Only single-file external videos are supported.") + video = Video(source_video.external_file) frame_idx = training_frame.source_video_frame_index instances = [] @@ -89,7 +85,7 @@ def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ign Instance.from_numpy( points=instance.node_locations[:], skeleton=skeleton ) - ) # `track` field is not stored in `SkeletonInstance` objects + ) labeled_frames.append( LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) ) @@ -212,8 +208,6 @@ def instance_to_skeleton_instance( return SkeletonInstance( name=f"skeleton_instance_{id(instance)}", id=np.uint64(id(instance)), - # TODO add a counter in the loop to track the number of instances - # instead of using id node_locations=np_node_locations, node_visibility=[point.visible for point in instance.points.values()], skeleton=skeleton, @@ -620,7 +614,7 @@ def append_nwb_data( .unique() ) - for track_index, track_name in enumerate(name_of_tracks_in_video): + for track_name in name_of_tracks_in_video: pose_estimation_container = build_pose_estimation_container_for_track( labels_data_df, labels, @@ -688,52 +682,11 @@ def append_nwb_training( pose_training = labels_to_pose_training(labels, skeletons_list, video_info) nwb_processing_module.add(pose_training) - confidence_definition = "Softmax output of the deep neural network" - reference_frame = ( - "The coordinates are in (x, y) relative to the top-left of the image. " - "Coordinates refer to the midpoint of the pixel. " - "That is, t the midpoint of the top-left pixel is at (0, 0), whereas " - "the top-left corner of that same pixel is at (-0.5, -0.5)." - ) - pose_estimation_series_list = [] - for node in skeletons_list[0].nodes: - pose_estimation_series = PoseEstimationSeries( - name=node, - description=f"Marker placed on {node}", - data=np.random.rand(100, 2), - unit="pixels", - reference_frame=reference_frame, - timestamps=np.linspace(0, 10, num=100), - confidence=np.random.rand(100), - confidence_definition=confidence_definition, - ) - pose_estimation_series_list.append(pose_estimation_series) - - camera = nwbfile.create_device( + _ = nwbfile.create_device( name=f"camera {i}", description=f"Camera used to record video {i}", manufacturer="No specified manufacturer", ) - try: - dimensions = np.array([[video.backend.shape[1], video.backend.shape[2]]]) - except AttributeError: - dimensions = np.array([[400, 400]]) - - pose_estimation = PoseEstimation( - name="pose_estimation", - pose_estimation_series=pose_estimation_series_list, - description="Estimated positions of the nodes in the video", - original_videos=[video.filename for video in labels.videos], - labeled_videos=[video.filename for video in labels.videos], - dimensions=dimensions, - devices=[camera], - scorer="No specified scorer", - source_software="SLEAP", - source_software_version=sleap_version, - skeleton=skeletons_list[0], - ) - nwb_processing_module.add(pose_estimation) - return nwbfile @@ -767,6 +720,7 @@ def append_nwb( ) else: nwb_file = append_nwb_data(labels, nwb_file, pose_estimation_metadata) + io.write(nwb_file)