Skip to content

Commit

Permalink
added SkeletonInstance counter
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Sep 8, 2024
1 parent 141d1dd commit 67a62c2
Showing 1 changed file with 124 additions and 118 deletions.
242 changes: 124 additions & 118 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,57 +57,75 @@
from sleap_io.io.utils import convert_predictions_to_dataframe


def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ignore[return]
"""Creates a Labels object from an NWB PoseTraining object.
def write_video_to_path(
video: Video,
frame_inds: Optional[list[int]] = None,
image_format: str = "png",
frame_path: Optional[str] = None,
) -> tuple[dict[int, str], Video, ImageSeries]:
"""Write individual frames of a video to a path.
Args:
pose_training: An NWB PoseTraining object.
video: The video to write.
frame_inds: The indices of the frames to write. If None, all frames are written.
image_format: The format of the image to write. Default is .png
frame_path: The directory to save the frames to. If None, the path is the video
filename without the extension.
Returns:
A Labels object.
A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
"""
labeled_frames = []
skeletons = {}
training_frames = pose_training.training_frames.training_frames.values()
for training_frame in training_frames:
source_video = training_frame.source_video
video = Video(source_video.external_file)
index_data = {}
if frame_inds is None:
frame_inds = list(range(video.backend.num_frames))

frame_idx = training_frame.source_video_frame_index
instances = []
for instance in training_frame.skeleton_instances.skeleton_instances.values():
if instance.skeleton.name not in skeletons:
skeletons[instance.skeleton.name] = nwb_skeleton_to_sleap(
instance.skeleton
)
skeleton = skeletons[instance.skeleton.name]
instances.append(
Instance.from_numpy(
points=instance.node_locations[:], skeleton=skeleton
)
)
labeled_frames.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
)
return Labels(labeled_frames=labeled_frames)
if isinstance(video.filename, list):
save_path = video.filename[0].split(".")[0]

Check warning on line 84 in sleap_io/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/io/nwb.py#L84

Added line #L84 was not covered by tests
else:
save_path = video.filename.split(".")[0]

if frame_path is not None:
save_path = frame_path

Check warning on line 89 in sleap_io/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/io/nwb.py#L89

Added line #L89 was not covered by tests

def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: ignore[return]
"""Converts an NWB skeleton to a SLEAP skeleton.
try:
os.makedirs(save_path, exist_ok=True)
except PermissionError:
filename_with_extension = video.filename.split("/")[-1]
filename = filename_with_extension.split(".")[0]
save_path = input("Permission denied. Enter a new path:") + "/" + filename
os.makedirs(save_path, exist_ok=True)

Check warning on line 97 in sleap_io/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/io/nwb.py#L93-L97

Added lines #L93 - L97 were not covered by tests

Args:
skeleton: An NWB skeleton.
if "cv2" in sys.modules:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the video filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]

Check warning on line 106 in sleap_io/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/io/nwb.py#L103-L106

Added lines #L103 - L106 were not covered by tests
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
cv2.imwrite(frame_path, frame)
else:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
iio.imwrite(frame_path, frame)

Check warning on line 120 in sleap_io/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/io/nwb.py#L111-L120

Added lines #L111 - L120 were not covered by tests

Returns:
A SLEAP skeleton.
"""
nodes = [Node(name=node) for node in skeleton.nodes]
edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
return SLEAPSkeleton(
nodes=nodes,
edges=edges,
name=skeleton.name,
image_series = ImageSeries(
name="video",
external_file=os.listdir(save_path),
starting_frame=[0 for _ in range(len(os.listdir(save_path)))],
rate=30.0, # TODO - change to `video.backend.fps` when available
)
return index_data, video, image_series


def labels_to_pose_training(
Expand All @@ -130,9 +148,13 @@ def labels_to_pose_training(
skeleton_instances_list = []
source_video_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
instance_counter = 0
for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
skeleton_instance = instance_to_skeleton_instance(
instance, skeleton, instance_counter
)
skeleton_instances_list.append(skeleton_instance)
instance_counter += 1

training_frame_skeleton_instances = SkeletonInstances(
skeleton_instances=skeleton_instances_list
Expand Down Expand Up @@ -161,6 +183,59 @@ def labels_to_pose_training(
return pose_training


def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ignore[return]
"""Creates a Labels object from an NWB PoseTraining object.
Args:
pose_training: An NWB PoseTraining object.
Returns:
A Labels object.
"""
labeled_frames = []
skeletons = {}
training_frames = pose_training.training_frames.training_frames.values()
for training_frame in training_frames:
source_video = training_frame.source_video
video = Video(source_video.external_file)

frame_idx = training_frame.source_video_frame_index
instances = []
for instance in training_frame.skeleton_instances.skeleton_instances.values():
if instance.skeleton.name not in skeletons:
skeletons[instance.skeleton.name] = nwb_skeleton_to_sleap(
instance.skeleton
)
skeleton = skeletons[instance.skeleton.name]
instances.append(
Instance.from_numpy(
points=instance.node_locations[:], skeleton=skeleton
)
)
labeled_frames.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
)
return Labels(labeled_frames=labeled_frames)


def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: ignore[return]
"""Converts an NWB skeleton to a SLEAP skeleton.
Args:
skeleton: An NWB skeleton.
Returns:
A SLEAP skeleton.
"""
nodes = [Node(name=node) for node in skeleton.nodes]
edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
return SLEAPSkeleton(
nodes=nodes,
edges=edges,
name=skeleton.name,
)


def slp_skeleton_to_nwb(
skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton: # type: ignore[return]
Expand Down Expand Up @@ -191,24 +266,26 @@ def slp_skeleton_to_nwb(


def instance_to_skeleton_instance(
instance: Instance, skeleton: NWBSkeleton # type: ignore[return]
instance: Instance,
skeleton: NWBSkeleton, # type: ignore[return]
counter: int,
) -> SkeletonInstance: # type: ignore[return]
"""Converts a SLEAP Instance to an NWB SkeletonInstance.
Args:
instance: A SLEAP Instance.
skeleton: An NWB Skeleton.
counter: An integer counter.
Returns:
An NWB SkeletonInstance.
"""
points_list = list(instance.points.values())
node_locs = [[point.x, point.y] for point in points_list]
np_node_locations = np.array(node_locs)
node_locations = np.array([[point.x, point.y] for point in points_list])
return SkeletonInstance(
name=f"skeleton_instance_{id(instance)}",
id=np.uint64(id(instance)),
node_locations=np_node_locations,
name=f"skeleton_instance_{counter}",
id=np.uint64(counter),
node_locations=node_locations,
node_visibility=[point.visible for point in instance.points.values()],
skeleton=skeleton,
)
Expand Down Expand Up @@ -239,77 +316,6 @@ def videos_to_source_videos(videos: list[Video]) -> SourceVideos: # type: ignor
return SourceVideos(image_series=source_videos)

Check warning on line 316 in sleap_io/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/io/nwb.py#L315-L316

Added lines #L315 - L316 were not covered by tests


def write_video_to_path(
video: Video,
frame_inds: Optional[list[int]] = None,
image_format: str = "png",
frame_path: Optional[str] = None,
) -> tuple[dict[int, str], Video, ImageSeries]:
"""Write individual frames of a video to a path.
Args:
video: The video to write.
frame_inds: The indices of the frames to write. If None, all frames are written.
image_format: The format of the image to write. Default is .png
frame_path: The directory to save the frames to. If None, the path is the video
filename without the extension.
Returns:
A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
"""
index_data = {}
if frame_inds is None:
frame_inds = list(range(video.backend.num_frames))

if isinstance(video.filename, list):
save_path = video.filename[0].split(".")[0]
else:
save_path = video.filename.split(".")[0]

if frame_path is not None:
save_path = frame_path

try:
os.makedirs(save_path, exist_ok=True)
except PermissionError:
filename_with_extension = video.filename.split("/")[-1]
filename = filename_with_extension.split(".")[0]
save_path = input("Permission denied. Enter a new path:") + "/" + filename
os.makedirs(save_path, exist_ok=True)

if "cv2" in sys.modules:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the video filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
cv2.imwrite(frame_path, frame)
else:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
iio.imwrite(frame_path, frame)

image_series = ImageSeries(
name="video",
external_file=os.listdir(save_path),
starting_frame=[0 for _ in range(len(os.listdir(save_path)))],
rate=30.0, # TODO - change to `video.backend.fps` when available
)
return index_data, video, image_series


def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
"""Return a vector of timestamps for a `PoseEstimationSeries`."""
if series.timestamps is not None:
Expand Down

0 comments on commit 67a62c2

Please sign in to comment.