Skip to content

Commit

Permalink
code quality update, removed PoseEstimation handling
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Sep 5, 2024
1 parent c3b286e commit dedc06d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 52 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies:
- pydocstyle
- toml
- twine
- build
- python-build
- pip
- pip:
- "--editable=.[dev]"
56 changes: 5 additions & 51 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
except ImportError:
ArrayLike = np.ndarray

from hdmf.utils import LabelledDict
from hdmf.build.errors import OrphanContainerBuildError

from pynwb import NWBFile, NWBHDF5IO, ProcessingModule
Expand Down Expand Up @@ -72,10 +71,7 @@ def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ign
training_frames = pose_training.training_frames.training_frames.values()
for training_frame in training_frames:
source_video = training_frame.source_video
if source_video.format == "external" and len(source_video.external_file) == 1:
video = Video(source_video.external_file[0])
else:
raise NotImplementedError("Only single-file external videos are supported.")
video = Video(source_video.external_file)

frame_idx = training_frame.source_video_frame_index
instances = []
Expand All @@ -89,7 +85,7 @@ def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ign
Instance.from_numpy(
points=instance.node_locations[:], skeleton=skeleton
)
) # `track` field is not stored in `SkeletonInstance` objects
)
labeled_frames.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
)
Expand Down Expand Up @@ -212,8 +208,6 @@ def instance_to_skeleton_instance(
return SkeletonInstance(
name=f"skeleton_instance_{id(instance)}",
id=np.uint64(id(instance)),
# TODO add a counter in the loop to track the number of instances
# instead of using id
node_locations=np_node_locations,
node_visibility=[point.visible for point in instance.points.values()],
skeleton=skeleton,
Expand Down Expand Up @@ -620,7 +614,7 @@ def append_nwb_data(
.unique()
)

for track_index, track_name in enumerate(name_of_tracks_in_video):
for track_name in name_of_tracks_in_video:
pose_estimation_container = build_pose_estimation_container_for_track(
labels_data_df,
labels,
Expand Down Expand Up @@ -688,52 +682,11 @@ def append_nwb_training(
pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
nwb_processing_module.add(pose_training)

confidence_definition = "Softmax output of the deep neural network"
reference_frame = (
"The coordinates are in (x, y) relative to the top-left of the image. "
"Coordinates refer to the midpoint of the pixel. "
"That is, t the midpoint of the top-left pixel is at (0, 0), whereas "
"the top-left corner of that same pixel is at (-0.5, -0.5)."
)
pose_estimation_series_list = []
for node in skeletons_list[0].nodes:
pose_estimation_series = PoseEstimationSeries(
name=node,
description=f"Marker placed on {node}",
data=np.random.rand(100, 2),
unit="pixels",
reference_frame=reference_frame,
timestamps=np.linspace(0, 10, num=100),
confidence=np.random.rand(100),
confidence_definition=confidence_definition,
)
pose_estimation_series_list.append(pose_estimation_series)

camera = nwbfile.create_device(
_ = nwbfile.create_device(
name=f"camera {i}",
description=f"Camera used to record video {i}",
manufacturer="No specified manufacturer",
)
try:
dimensions = np.array([[video.backend.shape[1], video.backend.shape[2]]])
except AttributeError:
dimensions = np.array([[400, 400]])

pose_estimation = PoseEstimation(
name="pose_estimation",
pose_estimation_series=pose_estimation_series_list,
description="Estimated positions of the nodes in the video",
original_videos=[video.filename for video in labels.videos],
labeled_videos=[video.filename for video in labels.videos],
dimensions=dimensions,
devices=[camera],
scorer="No specified scorer",
source_software="SLEAP",
source_software_version=sleap_version,
skeleton=skeletons_list[0],
)
nwb_processing_module.add(pose_estimation)

return nwbfile


Expand Down Expand Up @@ -767,6 +720,7 @@ def append_nwb(
)
else:
nwb_file = append_nwb_data(labels, nwb_file, pose_estimation_metadata)

io.write(nwb_file)


Expand Down

0 comments on commit dedc06d

Please sign in to comment.