Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance training dataset format support #24

Merged
merged 7 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
## ndx-pose 0.2.0 (Upcoming)

### Breaking changes
- Removed "nodes" and "edges" fields from `PoseEstimation` neurodata type. Create a `Skeleton` object and pass
it to the `skeleton` keyword argument of `PoseEstimation.__init__` instead. @rly (#7)
- Removed the `nodes` and `edges` fields from `PoseEstimation` neurodata type. To specify these,
create a `Skeleton` object with those values, create a `Skeletons` object and pass the `Skeleton`
object to that, and add the `Skeletons` object to your "behavior" processing module. @rly (#7, #24)

### Major changes
- Added support for storing training data in the new `PoseTraining` neurodata type and other new types.
@roomrys, @CBroz1, @rly, @talmo, @eberrigan (#7, #21, #24)

### Minor changes
- Made `PoseEstimation.confidence` optional. @h-mayorquin (#11)
Expand Down
360 changes: 269 additions & 91 deletions README.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# pinned dependencies to reproduce a working development environment
hdmf==3.11.0
pynwb==2.5.0
hdmf==3.12.2
pynwb==2.6.0
31 changes: 21 additions & 10 deletions spec/ndx-pose.extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ groups:
doc: Array of pairs of indices corresponding to edges between nodes. Index values
correspond to row indices of the 'nodes' dataset. Index values use 0-indexing.
quantity: '?'
links:
- target_type: Subject
doc: The Subject object in the NWB file, if this Skeleton corresponds to the Subject.
quantity: '?'
- neurodata_type_def: PoseEstimationSeries
neurodata_type_inc: SpatialSeries
doc: Estimated position (x, y) or (x, y, z) of a body part over time.
Expand Down Expand Up @@ -142,9 +146,9 @@ groups:
then `source_video` is required.
required: false
groups:
- name: skeleton_instance
neurodata_type_inc: SkeletonInstance
doc: Position data for a single instance of a skeleton in a single training frame.
- name: skeleton_instances
neurodata_type_inc: SkeletonInstances
doc: Position data for all instances of a skeleton in a single training frame.
links:
- name: source_video
target_type: ImageSeries
Expand All @@ -160,7 +164,7 @@ groups:
neurodata_type_inc: NWBDataInterface
default_name: skeleton_instance
doc: Group that holds ground-truth pose data for a single instance of a skeleton
in a single frame. This is meant to be used within a TrainingFrame.
in a single frame.
attributes:
- name: id
dtype: uint8
Expand Down Expand Up @@ -200,6 +204,16 @@ groups:
- neurodata_type_inc: TrainingFrame
doc: Ground-truth position data for all instances of a skeleton in a single frame.
quantity: '*'
- neurodata_type_def: SkeletonInstances
neurodata_type_inc: NWBDataInterface
default_name: skeleton_instances
doc: Organizational group to hold skeleton instances. This is meant to be used within
a TrainingFrame.
groups:
- neurodata_type_inc: SkeletonInstance
doc: Ground-truth position data for a single instance of a skeleton in a single
training frame.
quantity: '*'
- neurodata_type_def: SourceVideos
neurodata_type_inc: NWBDataInterface
default_name: source_videos
Expand All @@ -210,7 +224,7 @@ groups:
quantity: '*'
- neurodata_type_def: Skeletons
neurodata_type_inc: NWBDataInterface
default_name: skeletons
default_name: Skeletons
doc: Organizational group to hold skeletons.
groups:
- neurodata_type_inc: Skeleton
Expand All @@ -219,12 +233,9 @@ groups:
- neurodata_type_def: PoseTraining
neurodata_type_inc: NWBDataInterface
default_name: PoseTraining
doc: Group that holds images, ground-truth annotations, and metadata for training
a pose estimator.
doc: Group that holds source videos and ground-truth annotations for training a
pose estimator.
groups:
- name: skeletons
neurodata_type_inc: Skeletons
doc: Organizational group to hold skeletons.
- name: training_frames
neurodata_type_inc: TrainingFrames
doc: Organizational group to hold training frames.
Expand Down
6 changes: 6 additions & 0 deletions spec/ndx-pose.namespace.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@ namespaces:
- Alexander Mathis
- Liezl Maree
- Chris Brozdowski
- Heberto Mayorquin
- Talmo Pereira
- Elizabeth Berrigan
contact:
- rly@lbl.gov
- bdichter@lbl.gov
- alexander.mathis@epfl.ch
- lmaree@salk.edu
- cbroz@datajoint.com
- h.mayorquin@gmail.com
- talmo@salk.edu
- eberrigan@salk.edu
doc: NWB extension to store pose estimation data
name: ndx-pose
schema:
Expand Down
5 changes: 4 additions & 1 deletion src/pynwb/ndx_pose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from pynwb import load_namespaces, get_class

# Set path of the namespace.yaml file to the expected install location
ndx_pose_specpath = os.path.join(os.path.dirname(__file__), "spec", "ndx-pose.namespace.yaml")
ndx_pose_specpath = os.path.join(
os.path.dirname(__file__), "spec", "ndx-pose.namespace.yaml"
)

# If the extension has not been installed yet but we are running directly from
# the git repo
Expand Down Expand Up @@ -30,6 +32,7 @@
TrainingFrame,
TrainingFrames,
SkeletonInstance,
SkeletonInstances,
SourceVideos,
PoseTraining,
) # noqa: E402, F401
75 changes: 56 additions & 19 deletions src/pynwb/ndx_pose/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Skeleton = get_class("Skeleton", "ndx-pose")
Skeletons = get_class("Skeletons", "ndx-pose")
SkeletonInstance = get_class("SkeletonInstance", "ndx-pose")
SkeletonInstances = get_class("SkeletonInstances", "ndx-pose")
TrainingFrame = get_class("TrainingFrame", "ndx-pose")
TrainingFrames = get_class("TrainingFrames", "ndx-pose")
SourceVideos = get_class("SourceVideos", "ndx-pose")
Expand All @@ -30,7 +31,9 @@ class PoseEstimationSeries(SpatialSeries):
{
"name": "name",
"type": str,
"doc": ("Name of this PoseEstimationSeries, usually the name of a body part."),
"doc": (
"Name of this PoseEstimationSeries, usually the name of a body part."
),
},
{
"name": "data",
Expand All @@ -47,7 +50,9 @@ class PoseEstimationSeries(SpatialSeries):
"name": "confidence",
"type": ("array_data", "data"),
"shape": (None,),
"doc": ("Confidence or likelihood of the estimated positions, scaled to be between 0 and 1."),
"doc": (
"Confidence or likelihood of the estimated positions, scaled to be between 0 and 1."
),
"default": None,
},
{
Expand All @@ -64,7 +69,8 @@ class PoseEstimationSeries(SpatialSeries):
"name": "confidence_definition",
"type": str,
"doc": (
"Description of how the confidence was computed, e.g., " "'Softmax output of the deep neural network'."
"Description of how the confidence was computed, e.g., "
"'Softmax output of the deep neural network'."
),
"default": None,
},
Expand All @@ -85,7 +91,9 @@ class PoseEstimationSeries(SpatialSeries):
)
def __init__(self, **kwargs):
"""Construct a new PoseEstimationSeries representing pose estimates for a particular body part."""
confidence, confidence_definition = popargs("confidence", "confidence_definition", kwargs)
confidence, confidence_definition = popargs(
"confidence", "confidence_definition", kwargs
)
super().__init__(**kwargs)
self.confidence = confidence
self.confidence_definition = confidence_definition
Expand Down Expand Up @@ -160,7 +168,9 @@ class PoseEstimation(MultiContainerInterface):
"name": "labeled_videos",
"type": ("array_data", "data"),
"shape": (None,),
"doc": ("Paths to the labeled video files. The number of files should equal the number of camera devices."),
"doc": (
"Paths to the labeled video files. The number of files should equal the number of camera devices."
),
"default": None,
},
{
Expand All @@ -185,7 +195,9 @@ class PoseEstimation(MultiContainerInterface):
{
"name": "source_software",
"type": str,
"doc": ("Name of the software tool used. Specifying the version attribute is strongly encouraged."),
"doc": (
"Name of the software tool used. Specifying the version attribute is strongly encouraged."
),
"default": None,
},
{
Expand Down Expand Up @@ -226,11 +238,14 @@ def __init__(self, **kwargs):
nodes, edges, skeleton = popargs("nodes", "edges", "skeleton", kwargs)
if nodes is not None or edges is not None:
if skeleton is not None:
raise ValueError("Cannot specify both 'nodes' and 'edges' and 'skeleton'.")
raise ValueError(
"Cannot specify both 'nodes' and 'edges' and 'skeleton'."
)
skeleton = Skeleton(name="subject", nodes=nodes, edges=edges)
warnings.warn(
"The 'nodes' and 'edges' arguments are deprecated. Please use the 'skeleton' argument instead.",
DeprecationWarning, stacklevel=2
DeprecationWarning,
stacklevel=2,
)

# devices must be added to the NWBFile before being linked to from a PoseEstimation object.
Expand All @@ -243,10 +258,16 @@ def __init__(self, **kwargs):
"All devices linked to from a PoseEstimation object must be added to the NWBFile first."
)

pose_estimation_series, description = popargs("pose_estimation_series", "description", kwargs)
original_videos, labeled_videos = popargs("original_videos", "labeled_videos", kwargs)
pose_estimation_series, description = popargs(
"pose_estimation_series", "description", kwargs
)
original_videos, labeled_videos = popargs(
"original_videos", "labeled_videos", kwargs
)
dimensions, scorer = popargs("dimensions", "scorer", kwargs)
source_software, source_software_version = popargs("source_software", "source_software_version", kwargs)
source_software, source_software_version = popargs(
"source_software", "source_software_version", kwargs
)
super().__init__(**kwargs)
self.pose_estimation_series = pose_estimation_series
self.description = description
Expand All @@ -266,25 +287,41 @@ def __init__(self, **kwargs):
# TODO validate that the nodes correspond to the names of the pose estimation series objects

# validate that len(original_videos) == len(labeled_videos) == len(dimensions) == len(cameras)
if original_videos is not None and (devices is None or len(original_videos) != len(devices)):
raise ValueError("The number of original videos should equal the number of camera devices.")
if labeled_videos is not None and (devices is None or len(labeled_videos) != len(devices)):
raise ValueError("The number of labeled videos should equal the number of camera devices.")
if dimensions is not None and (devices is None or len(dimensions) != len(devices)):
raise ValueError("The number of dimensions should equal the number of camera devices.")
if original_videos is not None and (
devices is None or len(original_videos) != len(devices)
):
raise ValueError(
"The number of original videos should equal the number of camera devices."
)
if labeled_videos is not None and (
devices is None or len(labeled_videos) != len(devices)
):
raise ValueError(
"The number of labeled videos should equal the number of camera devices."
)
if dimensions is not None and (
devices is None or len(dimensions) != len(devices)
):
raise ValueError(
"The number of dimensions should equal the number of camera devices."
)

@property
def nodes(self):
return self.skeleton.nodes

@nodes.setter
def nodes(self, value):
raise ValueError("'nodes' is deprecated. Please use the 'skeleton' field instead.")
raise ValueError(
"'nodes' is deprecated. Please use the 'skeleton' field instead."
)

@property
def edges(self):
return self.skeleton.edges

@edges.setter
def edges(self, value):
raise ValueError("'edges' is deprecated. Please use the 'skeleton' field instead.")
raise ValueError(
"'edges' is deprecated. Please use the 'skeleton' field instead."
)
Loading