Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SkeletonInstances in TrainingFrame #3

Merged
merged 7 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions spec/ndx-pose.extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ groups:
then `source_video` is required.
required: false
groups:
- name: skeleton_instance
neurodata_type_inc: SkeletonInstance
doc: Position data for a single instance of a skeleton in a single training frame.
- name: skeleton_instances
neurodata_type_inc: SkeletonInstances
doc: Position data for all instances of a skeleton in a single training frame.
links:
- name: source_video
target_type: ImageSeries
Expand All @@ -159,8 +159,8 @@ groups:
- neurodata_type_def: SkeletonInstance
neurodata_type_inc: NWBDataInterface
default_name: skeleton_instance
doc: Group that holds ground-truth pose data for a single instance of a skeleton
in a single frame. This is meant to be used within a TrainingFrame.
doc: 'Group that holds ground-truth pose data for a single instance of a skeleton
in a single frame. '
attributes:
- name: id
dtype: uint8
Expand Down Expand Up @@ -200,6 +200,16 @@ groups:
- neurodata_type_inc: TrainingFrame
doc: Ground-truth position data for all instances of a skeleton in a single frame.
quantity: '*'
- neurodata_type_def: SkeletonInstances
neurodata_type_inc: NWBDataInterface
default_name: skeleton_instances
doc: Organizational group to hold skeleton instances.This is meant to be used within
a TrainingFrame.
groups:
- neurodata_type_inc: SkeletonInstance
doc: Ground-truth position data for a single instance of a skeleton in a single
training frame.
quantity: '*'
- neurodata_type_def: SourceVideos
neurodata_type_inc: NWBDataInterface
default_name: source_videos
Expand Down
5 changes: 4 additions & 1 deletion src/pynwb/ndx_pose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from pynwb import load_namespaces, get_class

# Set path of the namespace.yaml file to the expected install location
ndx_pose_specpath = os.path.join(os.path.dirname(__file__), "spec", "ndx-pose.namespace.yaml")
ndx_pose_specpath = os.path.join(
os.path.dirname(__file__), "spec", "ndx-pose.namespace.yaml"
)

# If the extension has not been installed yet but we are running directly from
# the git repo
Expand Down Expand Up @@ -30,6 +32,7 @@
TrainingFrame,
TrainingFrames,
SkeletonInstance,
SkeletonInstances,
SourceVideos,
PoseTraining,
) # noqa: E402, F401
75 changes: 56 additions & 19 deletions src/pynwb/ndx_pose/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Skeleton = get_class("Skeleton", "ndx-pose")
Skeletons = get_class("Skeletons", "ndx-pose")
SkeletonInstance = get_class("SkeletonInstance", "ndx-pose")
SkeletonInstances = get_class("SkeletonInstances", "ndx-pose")
TrainingFrame = get_class("TrainingFrame", "ndx-pose")
TrainingFrames = get_class("TrainingFrames", "ndx-pose")
SourceVideos = get_class("SourceVideos", "ndx-pose")
Expand All @@ -30,7 +31,9 @@ class PoseEstimationSeries(SpatialSeries):
{
"name": "name",
"type": str,
"doc": ("Name of this PoseEstimationSeries, usually the name of a body part."),
"doc": (
"Name of this PoseEstimationSeries, usually the name of a body part."
),
},
{
"name": "data",
Expand All @@ -47,7 +50,9 @@ class PoseEstimationSeries(SpatialSeries):
"name": "confidence",
"type": ("array_data", "data"),
"shape": (None,),
"doc": ("Confidence or likelihood of the estimated positions, scaled to be between 0 and 1."),
"doc": (
"Confidence or likelihood of the estimated positions, scaled to be between 0 and 1."
),
"default": None,
},
{
Expand All @@ -64,7 +69,8 @@ class PoseEstimationSeries(SpatialSeries):
"name": "confidence_definition",
"type": str,
"doc": (
"Description of how the confidence was computed, e.g., " "'Softmax output of the deep neural network'."
"Description of how the confidence was computed, e.g., "
"'Softmax output of the deep neural network'."
),
"default": None,
},
Expand All @@ -85,7 +91,9 @@ class PoseEstimationSeries(SpatialSeries):
)
def __init__(self, **kwargs):
"""Construct a new PoseEstimationSeries representing pose estimates for a particular body part."""
confidence, confidence_definition = popargs("confidence", "confidence_definition", kwargs)
confidence, confidence_definition = popargs(
"confidence", "confidence_definition", kwargs
)
super().__init__(**kwargs)
self.confidence = confidence
self.confidence_definition = confidence_definition
Expand Down Expand Up @@ -160,7 +168,9 @@ class PoseEstimation(MultiContainerInterface):
"name": "labeled_videos",
"type": ("array_data", "data"),
"shape": (None,),
"doc": ("Paths to the labeled video files. The number of files should equal the number of camera devices."),
"doc": (
"Paths to the labeled video files. The number of files should equal the number of camera devices."
),
"default": None,
},
{
Expand All @@ -185,7 +195,9 @@ class PoseEstimation(MultiContainerInterface):
{
"name": "source_software",
"type": str,
"doc": ("Name of the software tool used. Specifying the version attribute is strongly encouraged."),
"doc": (
"Name of the software tool used. Specifying the version attribute is strongly encouraged."
),
"default": None,
},
{
Expand Down Expand Up @@ -226,11 +238,14 @@ def __init__(self, **kwargs):
nodes, edges, skeleton = popargs("nodes", "edges", "skeleton", kwargs)
if nodes is not None or edges is not None:
if skeleton is not None:
raise ValueError("Cannot specify both 'nodes' and 'edges' and 'skeleton'.")
raise ValueError(
"Cannot specify both 'nodes' and 'edges' and 'skeleton'."
)
skeleton = Skeleton(name="subject", nodes=nodes, edges=edges)
warnings.warn(
"The 'nodes' and 'edges' arguments are deprecated. Please use the 'skeleton' argument instead.",
DeprecationWarning, stacklevel=2
DeprecationWarning,
stacklevel=2,
)

# devices must be added to the NWBFile before being linked to from a PoseEstimation object.
Expand All @@ -243,10 +258,16 @@ def __init__(self, **kwargs):
"All devices linked to from a PoseEstimation object must be added to the NWBFile first."
)

pose_estimation_series, description = popargs("pose_estimation_series", "description", kwargs)
original_videos, labeled_videos = popargs("original_videos", "labeled_videos", kwargs)
pose_estimation_series, description = popargs(
"pose_estimation_series", "description", kwargs
)
original_videos, labeled_videos = popargs(
"original_videos", "labeled_videos", kwargs
)
dimensions, scorer = popargs("dimensions", "scorer", kwargs)
source_software, source_software_version = popargs("source_software", "source_software_version", kwargs)
source_software, source_software_version = popargs(
"source_software", "source_software_version", kwargs
)
super().__init__(**kwargs)
self.pose_estimation_series = pose_estimation_series
self.description = description
Expand All @@ -266,25 +287,41 @@ def __init__(self, **kwargs):
# TODO validate that the nodes correspond to the names of the pose estimation series objects

# validate that len(original_videos) == len(labeled_videos) == len(dimensions) == len(cameras)
if original_videos is not None and (devices is None or len(original_videos) != len(devices)):
raise ValueError("The number of original videos should equal the number of camera devices.")
if labeled_videos is not None and (devices is None or len(labeled_videos) != len(devices)):
raise ValueError("The number of labeled videos should equal the number of camera devices.")
if dimensions is not None and (devices is None or len(dimensions) != len(devices)):
raise ValueError("The number of dimensions should equal the number of camera devices.")
if original_videos is not None and (
devices is None or len(original_videos) != len(devices)
):
raise ValueError(
"The number of original videos should equal the number of camera devices."
)
if labeled_videos is not None and (
devices is None or len(labeled_videos) != len(devices)
):
raise ValueError(
"The number of labeled videos should equal the number of camera devices."
)
if dimensions is not None and (
devices is None or len(dimensions) != len(devices)
):
raise ValueError(
"The number of dimensions should equal the number of camera devices."
)

@property
def nodes(self):
return self.skeleton.nodes

@nodes.setter
def nodes(self, value):
raise ValueError("'nodes' is deprecated. Please use the 'skeleton' field instead.")
raise ValueError(
"'nodes' is deprecated. Please use the 'skeleton' field instead."
)

@property
def edges(self):
return self.skeleton.edges

@edges.setter
def edges(self, value):
raise ValueError("'edges' is deprecated. Please use the 'skeleton' field instead.")
raise ValueError(
"'edges' is deprecated. Please use the 'skeleton' field instead."
)
54 changes: 47 additions & 7 deletions src/pynwb/ndx_pose/testing/mock/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
from pynwb.testing.mock.utils import name_generator
from pynwb.testing.mock.device import mock_Device

from ...pose import PoseEstimationSeries, Skeleton, PoseEstimation, SkeletonInstance, TrainingFrame, Skeletons, PoseTraining
from ...pose import (
PoseEstimationSeries,
Skeleton,
PoseEstimation,
SkeletonInstance,
SkeletonInstances,
TrainingFrame,
Skeletons,
PoseTraining,
)


def mock_PoseEstimationSeries(
Expand Down Expand Up @@ -91,7 +100,9 @@ def mock_PoseEstimation(
NWBFile should be provided so that the skeleton can be added to the NWBFile in a PoseTraining object.
"""
skeleton = skeleton or mock_Skeleton()
pose_estimation_series = pose_estimation_series or [mock_PoseEstimationSeries(name=name) for name in skeleton.nodes]
pose_estimation_series = pose_estimation_series or [
mock_PoseEstimationSeries(name=name) for name in skeleton.nodes
]
pe = PoseEstimation(
pose_estimation_series=pose_estimation_series,
description=description,
Expand All @@ -110,16 +121,20 @@ def mock_PoseEstimation(
pose_training = PoseTraining(skeletons=skeletons)

if "behavior" not in nwbfile.processing:
behavior_pm = nwbfile.create_processing_module(name="behavior", description="processed behavioral data")
behavior_pm = nwbfile.create_processing_module(
name="behavior", description="processed behavioral data"
)
else:
behavior_pm = nwbfile.processing["behavior"]
behavior_pm.add(pe)
behavior_pm.add(pose_training)

return pe


def mock_SkeletonInstance(
*,
name: Optional[str] = None,
id: Optional[np.uint] = np.uint(10),
node_locations: Optional[Any] = None,
node_visibility: list = None,
Expand All @@ -138,18 +153,35 @@ def mock_SkeletonInstance(
edges=np.array([[0, 1]], dtype="uint8"),
)
if node_locations is None:
node_locations = np.arange(num_nodes * 2, dtype=np.float64).reshape((num_nodes, 2))
node_locations = np.arange(num_nodes * 2, dtype=np.float64).reshape(
(num_nodes, 2)
)

if name is None:
name = skeleton.name + "_instance_" + str(id)
if node_visibility is None:
node_visibility = np.ones(num_nodes, dtype="bool")
skeleton_instance = SkeletonInstance(
name=name,
id=id,
node_locations=node_locations,
node_visibility=node_visibility,
skeleton=skeleton,
)

return skeleton_instance


def mock_SkeletonInstances(skeleton_instances=None):
if skeleton_instances is None:
skeleton_instances = [mock_SkeletonInstance()]
if not isinstance(skeleton_instances, list):
skeleton_instances = [skeleton_instances]
return SkeletonInstances(
skeleton_instances=skeleton_instances,
)


def mock_source_video(
*,
name: Optional[str] = None,
Expand All @@ -172,20 +204,28 @@ def mock_source_frame(
):
return RGBImage(name=name, data=np.random.rand(640, 480, 3).astype("uint8"))

def mock_source_frame(
*,
name: Optional[str] = None,
):
return RGBImage(name=name, data=np.random.rand(640, 480, 3).astype("uint8"))


def mock_TrainingFrame(
*,
name: Optional[str] = None,
annotator: Optional[str] = "Awesome Possum",
skeleton_instance: SkeletonInstance = None,
skeleton_instances: SkeletonInstances = None,
source_video: ImageSeries = None,
source_frame: Image = None,
source_video_frame_index: np.uint = np.uint(10),
):
training_frame = TrainingFrame(
name=name or name_generator("TrainingFrame"),
annotator=annotator,
skeleton_instance=skeleton_instance or mock_SkeletonInstance(),
source_video=source_video or (mock_source_video() if source_frame is None else None),
skeleton_instances=skeleton_instances or mock_SkeletonInstances(),
source_video=source_video
or (mock_source_video() if source_frame is None else None),
source_frame=source_frame,
source_video_frame_index=source_video_frame_index,
)
Expand Down
Loading