
Adding support for more .nwb training data to SLEAP #104

Open
wants to merge 92 commits into main

Conversation


@keyaloding keyaloding commented Jul 12, 2024

Description

NWB (Neurodata Without Borders) data is currently only partially supported in sleap-io: the PoseEstimation and PoseEstimationSeries structures are handled, but TrainingFrame, TrainingFrames, PoseTraining, and SourceVideos are not. These structures correspond to SLEAP data structures as outlined in rly/ndx-pose#24. This PR adds support for them by allowing the user to export SLEAP training data as NWB training data.
I have also updated the README with an example of how to use this feature.
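For orientation, the SLEAP-to-NWB structure pairings involved in this PR can be summarized as a simple lookup. This is an informal summary following the description above and rly/ndx-pose#24, not code from sleap-io:

```python
# Informal SLEAP -> NWB (ndx-pose) structure correspondence, as described
# in this PR. A summary table for the reader, not an actual API.
SLEAP_TO_NWB = {
    "Labels": "PoseTraining",          # full training dataset
    "LabeledFrame": "TrainingFrame",   # one annotated frame
    "Instance": "SkeletonInstance",    # one animal's pose in a frame
    "Skeleton": "Skeleton",            # node/edge graph
    "Video": "SourceVideos",           # backing media (as ImageSeries)
}

def nwb_equivalent(sleap_name: str) -> str:
    """Look up the NWB counterpart of a SLEAP structure name."""
    return SLEAP_TO_NWB[sleap_name]
```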

Types of changes

  • Bugfix
  • New feature
  • Refactor / Code style update (no logical changes)
  • Build / CI changes
  • Documentation Update
  • Other (explain)

Does this address any currently open issues?

#100, #86, rly/ndx-pose#29

Summary by CodeRabbit

  • New Features

    • Enhanced conversion capabilities between SLEAP and NWB formats with several new functions.
    • Improved data handling in the file-saving process, allowing for more flexible appending of training data.
    • Introduced functionality for saving labels to NWB files with options for training datasets and frame extraction.
    • Added new tests to validate data conversion integrity.
  • Bug Fixes

    • Improved robustness of data handling in file operations, ensuring consistency and reliability.
  • Refactor

    • Revised existing functions to optimize their performance and compatibility with NWB data structures.


coderabbitai bot commented Jul 12, 2024

Walkthrough

The recent updates enhance interoperability between SLEAP and NWB formats by introducing new functions for handling pose training and skeleton data. These changes improve data management, incorporate better error handling, and refine the loading and saving processes, ensuring seamless integration for users working with labeled datasets.

Changes

File Change Summary
environment.yml Replaced build dependency with python-build.
pyproject.toml Updated ndx-pose dependency to a specific commit hash from GitHub for improved version control.
sleap_io/io/nwb.py Introduced new functions for converting pose training data and skeleton structures; refined existing functions for better metadata management and data handling.
tests/io/test_main.py Added test_nwb_training and modified test_nwb to accommodate new saving behaviors.
tests/io/test_nwb.py Added test_video_to_path and test_slp_to_nwb_conversion to validate new conversion functionalities and data integrity.
README.md Updated the example for saving labels to NWB files to include new parameters for training data.
docs/index.md Added comments clarifying the new functionality for saving labels and converting SLEAP training data.
sleap_io/io/main.py Modified save_nwb and save_file functions to handle new parameters for training datasets and frame extraction.

Possibly related issues

  • Support more NWB training data types #100: The changes support additional NWB training data types, including PoseTraining, which aligns with the objectives of expanding data handling capabilities as discussed in the issue.

Poem

In the meadow where code does play,
SLEAP and NWB dance through the day.
Functions bloom like springtime flowers,
Data flows in newfound powers.
With each change, a leap we take,
In the world of code, joy we make! 🌼🐇



@keyaloding keyaloding marked this pull request as ready for review July 12, 2024 16:43

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Outside diff range, codebase verification and nitpick comments (1)
sleap_io/io/main.py (1)

50-66: Ensure load_nwb function calls match the new signature

The load_nwb function calls in tests/io/test_main.py do not match the new signature. The as_training parameter is missing in these calls.

  • tests/io/test_main.py, lines where load_nwb is called:
    • loaded_labels = load_nwb(tmp_path / "test_nwb.nwb") (twice)

Please update these calls to include the as_training parameter.

Analysis chain

LGTM! But verify the function usage in the codebase.

The code changes are approved.

However, ensure that all function calls to load_nwb match the new signature.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all function calls to `load_nwb` match the new signature.

# Test: Search for the function usage. Expect: Only occurances of the new signature.
rg --type python -A 5 $'load_nwb'

Length of output: 63


Script:

#!/bin/bash
# Description: Verify all function calls to `load_nwb` match the new signature.

# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --type py -A 5 $'load_nwb'

Length of output: 1954


Script:

#!/bin/bash
# Description: Verify that the function calls to `load_nwb` in tests/io/test_main.py match the new signature.

# Test: Extract lines of code where `load_nwb` is called in tests/io/test_main.py.
rg 'load_nwb' tests/io/test_main.py -A 5 -B 5

Length of output: 908

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 95b9c81 and f2fd6f5.

Files selected for processing (3)
  • pyproject.toml (1 hunks)
  • sleap_io/io/main.py (3 hunks)
  • sleap_io/io/nwb.py (9 hunks)
Files skipped from review due to trivial changes (1)
  • pyproject.toml
Additional comments not posted (7)
sleap_io/io/main.py (2)

69-93: Verify the logic for handling as_training argument.

The logic for handling the as_training argument is correct and well-implemented. The function correctly differentiates between saving as a training dataset and saving as a regular dataset based on the value of as_training.
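The dispatch described here can be illustrated with a self-contained sketch. The reader functions below are stubs standing in for `nwb.read_nwb` and `nwb.read_nwb_training` (which return `Labels` objects in sleap-io); only the branching logic is the point:

```python
from typing import Optional

# Stub readers standing in for sleap_io.io.nwb.read_nwb and
# read_nwb_training; the real ones parse .nwb files.
def read_nwb(filename: str) -> str:
    return f"labels from {filename}"

def read_nwb_training(filename: str) -> str:
    return f"training labels from {filename}"

def load_nwb(filename: str, as_training: Optional[bool] = None) -> str:
    # A single truthiness check covers both None and False, which is why
    # a separate `is None` branch would be redundant.
    if as_training:
        return read_nwb_training(filename)
    return read_nwb(filename)
```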


69-93: LGTM! But verify the function usage in the codebase.

The code changes are approved.

However, ensure that all function calls to save_nwb match the new signature.

Verification successful

Function calls to save_nwb match the new signature.

The function calls to save_nwb in the following files match the new signature:

  • sleap_io/io/main.py
  • tests/io/test_main.py

No further issues were found.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all function calls to `save_nwb` match the new signature.

# Test: Search for the function usage. Expect: Only occurances of the new signature.
rg --type python -A 5 $'save_nwb'

Length of output: 63


Script:

#!/bin/bash
# Description: Verify all function calls to `save_nwb` match the new signature.

# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --glob '*.py' -A 5 $'save_nwb'

Length of output: 2089


Script:

#!/bin/bash
# Description: Extract lines of code where `save_nwb` is used to verify the function signature.

# Extract lines from sleap_io/__init__.py
rg --glob 'sleap_io/__init__.py' -A 5 $'save_nwb'

# Extract lines from sleap_io/io/main.py
rg --glob 'sleap_io/io/main.py' -A 5 $'save_nwb'

# Extract lines from tests/io/test_main.py
rg --glob 'tests/io/test_main.py' -A 5 $'save_nwb'

Length of output: 2191

sleap_io/io/nwb.py (5)

92-144: LGTM!

The function correctly converts a SLEAP Labels object to an NWB PoseTraining object.


147-169: LGTM!

The function correctly converts a SLEAP Skeleton to an NWB Skeleton.


172-191: LGTM!

The function correctly converts a SLEAP Instance to an NWB SkeletonInstance.


194-216: LGTM!

The function correctly converts a list of SLEAP Videos to NWB SourceVideos.


394-411: LGTM!

The function correctly handles the writing of PoseTraining data to an NWB file, including metadata handling.

Comment on lines 50 to 66
def load_nwb(filename: str, as_training: Optional[bool]=None) -> Labels:
    """Load an NWB dataset as a SLEAP `Labels` object.
    Args:
        filename: Path to a NWB file (`.nwb`).
        as_training: If `True`, load the dataset as a training dataset.
    Returns:
        The dataset as a `Labels` object.
    """
-   return nwb.read_nwb(filename)
+   if as_training is None:
+       return
+
+   if as_training:
+       return nwb.read_nwb_training(filename)
+   else:
+       return nwb.read_nwb(filename)


Verify the logic for handling as_training argument.

The logic for handling the as_training argument is correct, but the initial check for as_training is None is redundant. The function should directly check for the value of as_training without the initial if condition.

-  if as_training is None:
-    return 
-  
  if as_training:
    return nwb.read_nwb_training(filename)
  else:
    return nwb.read_nwb(filename)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

def load_nwb(filename: str, as_training: Optional[bool]=None) -> Labels:
    """Load an NWB dataset as a SLEAP `Labels` object.
    Args:
        filename: Path to a NWB file (`.nwb`).
        as_training: If `True`, load the dataset as a training dataset.
    Returns:
        The dataset as a `Labels` object.
    """
    if as_training:
        return nwb.read_nwb_training(filename)
    else:
        return nwb.read_nwb(filename)

Comment on lines 48 to 71
def pose_training_to_labels(pose_training: PoseTraining) -> Labels:  # type: ignore[return]
    """Creates a Labels object from an NWB PoseTraining object.
    Args:
        pose_training: An NWB PoseTraining object.
    Returns:
        A Labels object.
    """
    labeled_frames = []
    for training_frame in pose_training.training_frames:
        video = Video(filename=f"{training_frame.source_videos}")
        frame_idx = training_frame  # TODO
        instances = [
            Instance.from_numpy(
                points=instance.node_locations,
                skeleton=nwb_skeleton_to_sleap(instance.skeleton),
            )
            for instance in training_frame.skeleton_instances
        ]
        labeled_frames.append(
            LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
        )
    return Labels(labeled_frames=labeled_frames)


Incomplete functionality in pose_training_to_labels.

The function contains a TODO comment indicating incomplete functionality for determining frame_idx from training_frame.

-  frame_idx = training_frame # TODO
+  frame_idx = training_frame.frame_index  # Assuming `frame_index` is an attribute of `TrainingFrame`

Do you want me to assist in completing this functionality or open a GitHub issue to track this task?

Committable suggestion was skipped due to low confidence.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between f2fd6f5 and 6c062ec.

Files selected for processing (1)
  • sleap_io/io/nwb.py (9 hunks)
Additional comments not posted (8)
sleap_io/io/nwb.py (8)

18-19: Update imports to include ImageSeries.

The addition of ImageSeries from pynwb.image is appropriate given the new functionalities related to handling image data. This aligns with the PR's objective of enhancing NWB data handling.


39-43: Review the renaming of Skeleton to SLEAPSkeleton.

Renaming Skeleton to SLEAPSkeleton helps avoid naming conflicts with Skeleton from ndx_pose. This is a good practice to prevent ambiguity in large codebases.


172-191: Review the function instance_to_skeleton_instance.

Converts a SLEAP Instance to an NWB SkeletonInstance. The function correctly handles the conversion of points to node locations and ensures the visibility of nodes is maintained. The use of name_generator for the instance name and the setting of an arbitrary ID (np.uint(10)) should be reviewed for appropriateness and potential conflicts.

Consider verifying the use of name_generator and the rationale behind the fixed ID setting.


48-71: Review the function pose_training_to_labels.

This function converts PoseTraining objects to Labels. The loop iterates over training_frames, creating LabeledFrame instances. Each instance is created using data from training_frame, including video filename and frame index. The use of Instance.from_numpy suggests that the conversion handles the data appropriately.

However, ensure that the source_video_frame_index and other attributes are correctly mapped and exist in the NWB structure as expected.

Verification successful

The attributes source_video_frame_index, source_video, and skeleton_instances are correctly mapped and exist in the NWB structure as expected.

  • The attributes are used appropriately within the NWB and SLEAP structures, ensuring the conversion in pose_training_to_labels is handled correctly.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the existence and correct mapping of attributes in NWB structure.

# Test: Search for the usage of `source_video_frame_index` and related attributes in the codebase. Expect: Proper usage and definitions.
rg --type python -A 5 $'source_video_frame_index'

Length of output: 79


Script:

#!/bin/bash
# Description: Verify the existence and correct mapping of attributes in NWB structure.

# Test: Search for the usage of `source_video_frame_index` and related attributes in the codebase. Expect: Proper usage and definitions.
rg --type py -A 5 'source_video_frame_index'
rg --type py -A 5 'source_video'
rg --type py -A 5 'skeleton_instances'

Length of output: 12535


74-89: Review the function nwb_skeleton_to_sleap.

The function converts an NWB skeleton to a SLEAP skeleton. The conversion of nodes and edges is straightforward and uses list comprehensions. The function appears to handle the data correctly according to the provided structure and types.

Ensure that the nodes and edges attributes are correctly utilized in the NWB and SLEAP data structures.

Verification successful

The nodes and edges attributes are correctly utilized in the NWB and SLEAP data structures.

The search results confirm extensive usage and testing of these attributes throughout the codebase, indicating that the function nwb_skeleton_to_sleap handles them correctly.

  • Verified in:
    • tests/io/test_nwb.py
    • tests/model/test_skeleton.py
    • sleap_io/io/nwb.py
    • sleap_io/io/slp.py
    • sleap_io/model/skeleton.py
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the correct utilization of `nodes` and `edges` attributes in the data structures.

# Test: Search for the usage of `nodes` and `edges` in the codebase. Expect: Correct utilization in NWB and SLEAP structures.
rg --type python -A 5 $'nodes|edges'

Length of output: 68


Script:

#!/bin/bash
# Description: Verify the correct utilization of `nodes` and `edges` attributes in the data structures.

# Test: Search for the usage of `nodes` and `edges` in the codebase. Expect: Correct utilization in NWB and SLEAP structures.
rg -t py -A 5 'nodes|edges'

Length of output: 44681
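The node/edge conversion described above can be sketched with minimal stand-in types. The dataclasses below are illustrative, not the real ndx-pose or sleap-io classes; attribute names (`nodes` as strings, `edges` as index pairs) follow this review's description:

```python
from dataclasses import dataclass

@dataclass
class NwbSkeleton:
    name: str
    nodes: list   # list of node-name strings
    edges: list   # list of (src_index, dst_index) pairs

@dataclass
class SleapSkeleton:
    node_names: list
    edge_inds: list

def nwb_skeleton_to_sleap(skel: NwbSkeleton) -> SleapSkeleton:
    # Node names copy over directly via list comprehensions, as the
    # review notes; edges remain index pairs into the node list.
    return SleapSkeleton(
        node_names=[str(n) for n in skel.nodes],
        edge_inds=[(int(a), int(b)) for a, b in skel.edges],
    )
```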


92-144: Review the function labels_to_pose_training.

This function creates an NWB PoseTraining object from a Labels object. It iterates over labeled frames, extracting instances and converting them using instance_to_skeleton_instance. The creation of TrainingFrame and SourceVideos is handled well, with attention to detail in metadata and video handling.

However, the function uses a custom name_generator which might introduce non-determinism or conflicts. Verify the implementation of this generator to ensure it produces unique and appropriate names without conflicts.


194-216: Review the function videos_to_source_videos.

This function converts a list of SLEAP Videos to NWB SourceVideos. The handling of video metadata and the creation of ImageSeries objects are done correctly. The function ensures that each video is represented as an ImageSeries, preserving essential metadata like dimensions and rate.

Ensure that the video conversion process does not lose any critical information and that the ImageSeries objects are correctly utilized in the NWB structure.

Verification successful

Review the function videos_to_source_videos.

The function converts a list of SLEAP Videos to NWB SourceVideos, ensuring that each video is represented as an ImageSeries with essential metadata like dimensions and rate. The presence of test cases in tests/io/test_nwb.py that verify the dimensions and rate attributes confirms that the video conversion process is correctly handled and no critical information is lost.

  • Relevant Test Cases:
    • tests/io/test_nwb.py: Contains tests verifying dimensions, rate, and other related attributes.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the correct utilization and information preservation in video conversion.

# Test: Search for the usage of `ImageSeries` and related attributes in the codebase. Expect: Correct utilization and information preservation.
rg --type python -A 5 $'ImageSeries|dimension|rate'

Length of output: 83


Script:

#!/bin/bash
# Description: Verify the correct utilization and information preservation in video conversion.

# Test: Search for the usage of `ImageSeries` and related attributes in the codebase. Expect: Correct utilization and information preservation.
rg -A 5 'ImageSeries|dimension|rate'

Length of output: 29675
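The metadata each video is expected to carry (dimensions and rate, per the review and the tests it cites) can be sketched with a plain dict standing in for `pynwb.image.ImageSeries`; the helper name and width/height/fps parameters are illustrative:

```python
# Plain-dict stand-in for an ImageSeries referencing an external video
# file; keys mirror the metadata the review says must be preserved.
def video_to_image_series(filename: str, width: int, height: int, fps: float) -> dict:
    return {
        "name": filename,
        "external_file": [filename],   # video kept on disk, not embedded
        "format": "external",
        "dimension": [width, height],  # frame size in pixels
        "rate": float(fps),            # sampling rate in Hz
    }
```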


147-169: Review the function slp_skeleton_to_nwb.

Converts a SLEAP skeleton to an NWB skeleton. The function handles the conversion of nodes and edges, ensuring that the data is correctly transformed. The use of a dictionary comprehension for skeleton_edges and the subsequent edge conversion is efficient.

Ensure that the conversion logic correctly preserves the relationships and data integrity between the original and converted structures.

Verification successful

Verification successful for slp_skeleton_to_nwb function.

The slp_skeleton_to_nwb function is correctly integrated into the instance_to_skeleton_instance function, ensuring that the relationships and data integrity are preserved during the conversion process.

  • sleap_io/io/nwb.py: Verified usage and context of slp_skeleton_to_nwb within instance_to_skeleton_instance.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.

# Test: Search for the usage and results of skeleton conversion in the codebase. Expect: Correct data integrity and relationship preservation.
rg --type python -A 5 $'skeleton_edges|nwb_edges'

Length of output: 81


Script:

#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.

# Test: Search for the usage and results of skeleton conversion in the codebase. Expect: Correct data integrity and relationship preservation.
rg -A 5 'skeleton_edges|nwb_edges'

Length of output: 1120


Script:

#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.

# Test: Search for the usage of the slp_skeleton_to_nwb function in the codebase. Expect: Correct data integrity and relationship preservation.
rg 'slp_skeleton_to_nwb'

Length of output: 205


Script:

#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.

# Test: Search for the context around the call to slp_skeleton_to_nwb(instance.skeleton) in sleap_io/io/nwb.py. Expect: Correct data integrity and relationship preservation.
rg -A 10 -B 10 'slp_skeleton_to_nwb(instance.skeleton)' sleap_io/io/nwb.py

Length of output: 76


Script:

#!/bin/bash
# Description: Verify the data integrity and relationship preservation in skeleton conversion.

# Test: Search for the function call slp_skeleton_to_nwb in sleap_io/io/nwb.py with broader context. Expect: Correct data integrity and relationship preservation.
rg -A 10 -B 10 'slp_skeleton_to_nwb' sleap_io/io/nwb.py

Length of output: 1351
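The dict-comprehension edge mapping this comment praises can be shown in isolation. The function name and inputs below are illustrative, not the sleap-io implementation: node names map to indices, then each named edge becomes an index pair:

```python
# Build a name -> index lookup with a dict comprehension, then express
# each (src_name, dst_name) edge as a pair of node indices.
def edges_to_indices(node_names, edges_by_name):
    index_of = {name: i for i, name in enumerate(node_names)}
    return [(index_of[a], index_of[b]) for a, b in edges_by_name]
```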

Comment on lines 394 to 411
def write_nwb_training(
    pose_training: PoseTraining,  # type: ignore[return]
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict],
    pose_estimation_metadata: Optional[dict] = None,
):
    """Writes data from a `PoseTraining` object to an NWB file.
    Args:
        pose_training: A `PoseTraining` object.
        nwbfile_path: The path where the nwb file is to be written.
    """
    nwb_file_kwargs = nwb_file_kwargs or {}

    nwbfile = NWBFile(**nwb_file_kwargs)
    nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
    with NWBHDF5IO(str(nwbfile_path), "w") as io:
        io.write(nwbfile)



Review the function write_nwb_training.

This function writes data from a PoseTraining object to an NWB file. The function setup and the handling of optional metadata are done correctly. The use of append_nwb_data to integrate the training data into the NWB file is appropriate.

However, the function lacks error handling for potential issues during the file writing process. Consider adding error handling mechanisms to improve robustness.

+  try:
       nwbfile = NWBFile(**nwb_file_kwargs)
       nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
       with NWBHDF5IO(str(nwbfile_path), "w") as io:
           io.write(nwbfile)
+  except Exception as e:
+      # Handle potential file writing errors
+      print(f"Error writing NWB file: {e}")
Suggested change

def write_nwb_training(
    pose_training: PoseTraining,  # type: ignore[return]
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict],
    pose_estimation_metadata: Optional[dict] = None,
):
    """Writes data from a `PoseTraining` object to an NWB file.
    Args:
        pose_training: A `PoseTraining` object.
        nwbfile_path: The path where the nwb file is to be written.
    """
    nwb_file_kwargs = nwb_file_kwargs or {}
    try:
        nwbfile = NWBFile(**nwb_file_kwargs)
        nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
        with NWBHDF5IO(str(nwbfile_path), "w") as io:
            io.write(nwbfile)
    except Exception as e:
        # Handle potential file writing errors
        print(f"Error writing NWB file: {e}")
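Note that the bot's suggestion swallows the failure with a `print`. A variant that keeps the error visible while adding context can be sketched generically; `safe_write` and its `writer` callback are illustrative names, not sleap-io or pynwb API:

```python
# Wrap a writer callback so failures carry the target path but the
# original traceback is preserved via exception chaining.
def safe_write(path: str, writer) -> None:
    try:
        writer(path)
    except OSError as e:
        raise RuntimeError(f"Error writing NWB file {path!r}") from e
```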


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Outside diff range, codebase verification and nitpick comments (9)
tests/io/test_nwb.py (9)

Line range hint 45-67: Review of test_typical_case_append

This function effectively tests the appending of SLEAP data to NWB files, ensuring that the structure and naming conventions are maintained. It might be beneficial to extract some of the repeated logic (like processing module extraction and naming) into helper functions for better readability and maintainability.


Line range hint 69-91: Review of test_typical_case_append_with_metadata_propagation

This test function correctly ensures that metadata is propagated during the data appending process. The use of explicit metadata values for testing is a good practice. However, consider parameterizing these metadata values to facilitate easier updates and expansions of the test cases in the future.


Line range hint 93-106: Review of test_provenance_writing

The function effectively tests the propagation of provenance information, ensuring data traceability. The structure of the test is clear, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.


Line range hint 108-130: Review of test_default_metadata_overwriting

This test function correctly checks that default metadata values can be overwritten, allowing for customization. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.


Line range hint 132-160: Review of test_complex_case_append

This function effectively handles the appending of complex case data, ensuring correct structure and naming within the NWB file. The test is comprehensive, but consider extracting some of the logic into helper functions for better readability and maintainability.


Line range hint 162-194: Review of test_complex_case_append_with_timestamps_metadata

This test function correctly ensures that timestamps metadata is propagated and used accurately. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.


Line range hint 196-203: Review of test_assertion_with_no_predicted_instance

This test function correctly handles the scenario where no predicted instances are found, ensuring robust error handling. The use of explicit error matching is a good practice. Consider adding more detailed comments to explain the significance of this test for future maintainers.


Line range hint 205-216: Review of test_typical_case_write

This test function ensures that typical case data is correctly written to an NWB file. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.


Line range hint 218-235: Review of test_get_timestamps

This test function correctly retrieves timestamps from data series, ensuring accurate data handling. The test is well-structured, but consider adding more detailed comments to explain the significance of each assertion for future maintainers.

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 6c062ec and f7d8a0c.

Files selected for processing (1)
  • tests/io/test_nwb.py (2 hunks)
Additional comments not posted (2)
tests/io/test_nwb.py (2)

9-15: Review of new imports and functions

The imports and new function declarations for pose_training_to_labels and labels_to_pose_training are correctly placed and follow the Python convention. However, ensure these functions are used appropriately across the codebase and that their implementation is aligned with the project's standards.


33-43: Review of test_nwb_slp_conversion

This test function correctly checks the integrity of data through conversions. It's important to ensure that not only the count of instances and labeled frames are maintained, but also that the content of these structures is unchanged. Consider adding assertions for deeper properties of the instances if not already covered elsewhere.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between f7d8a0c and 23a5a83.

Files selected for processing (1)
  • tests/io/test_nwb.py (2 hunks)
Files skipped from review as they are similar to previous changes (1)
  • tests/io/test_nwb.py

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 23a5a83 and cf4bcf7.

Files selected for processing (1)
  • sleap_io/io/nwb.py (9 hunks)

Comment on lines 530 to 540
def append_nwb_training(pose_training: PoseTraining, nwbfile_path: str) -> NWBFile:  # type: ignore[return]
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    raise NotImplementedError
Copy link
Contributor


Implement the append_nwb_training function.

Currently, the function raises NotImplementedError. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

- raise NotImplementedError
+ # TODO: Implement or document this function
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def append_nwb_training(pose_training: PoseTraining, nwbfile_path: str) -> NWBFile:  # type: ignore[return]
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    raise NotImplementedError
def append_nwb_training(pose_training: PoseTraining, nwbfile_path: str) -> NWBFile:  # type: ignore[return]
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    # TODO: Implement or document this function

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between cf4bcf7 and 8948e92.

Files selected for processing (1)
  • sleap_io/io/nwb.py (9 hunks)
Files skipped from review as they are similar to previous changes (1)
  • sleap_io/io/nwb.py

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 8948e92 and be6ccbc.

Files selected for processing (2)
  • save_test.ipynb (1 hunks)
  • sleap_io/io/main.py (3 hunks)
Additional comments not posted (1)
sleap_io/io/main.py (1)

50-63: Remove redundant check for as_training.

The previous review comment suggested removing the redundant check for as_training is None, which has been addressed. The logic can be simplified further by returning directly based on the as_training flag.

def load_nwb(filename: str, as_training: Optional[bool]=None) -> Labels:
    """Load an NWB dataset as a SLEAP `Labels` object.

    Args:
        filename: Path to a NWB file (`.nwb`).
        as_training: If `True`, load the dataset as a training dataset.

    Returns:
        The dataset as a `Labels` object.
    """
-    if as_training:
-        return nwb.read_nwb_training(filename)
-    else:
-        return nwb.read_nwb(filename)
+    return nwb.read_nwb_training(filename) if as_training else nwb.read_nwb(filename)

Comment on lines 66 to 89
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if append and Path(filename).exists():
        nwb.append_nwb(labels, filename)
    else:
        nwb.write_nwb(labels, filename)
    if as_training:
        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
        if append and Path(filename).exists():
            nwb.append_nwb_training(pose_training, filename, **kwargs)
        else:
            nwb.write_nwb_training(pose_training, filename, **kwargs)
    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)
Copy link
Contributor


Refactor conditional logic in save_nwb.

The function save_nwb handles saving of NWB datasets with an optional training mode. The logic can be refactored to reduce the depth of conditionals and improve readability.

def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
-    if as_training:
-        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
-        if append and Path(filename).exists():
-            nwb.append_nwb_training(pose_training, filename, **kwargs)
-        else:
-            nwb.write_nwb_training(pose_training, filename, **kwargs)
-
-    else:
-        if append and Path(filename).exists():
-            nwb.append_nwb(labels, filename, **kwargs)
-        else:
-            nwb.write_nwb(labels, filename)
+    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+    action = nwb.append_nwb_training if as_training else nwb.append_nwb
+    write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+    data = func(labels, **kwargs)
+    if append and Path(filename).exists():
+        action(data, filename, **kwargs)
+    else:
+        write(data, filename, **kwargs)
Committable suggestion


Suggested change
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if append and Path(filename).exists():
        nwb.append_nwb(labels, filename)
    else:
        nwb.write_nwb(labels, filename)
    if as_training:
        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
        if append and Path(filename).exists():
            nwb.append_nwb_training(pose_training, filename, **kwargs)
        else:
            nwb.write_nwb_training(pose_training, filename, **kwargs)
    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
    action = nwb.append_nwb_training if as_training else nwb.append_nwb
    write = nwb.write_nwb_training if as_training else nwb.write_nwb

    data = func(labels, **kwargs)
    if append and Path(filename).exists():
        action(data, filename, **kwargs)
    else:
        write(data, filename, **kwargs)
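The dispatch pattern in the suggestion above can be exercised in isolation with a self-contained toy. The stand-in functions below are hypothetical placeholders for `nwb.write_nwb`/`nwb.append_nwb` and their `*_training` counterparts, not the sleap-io API:

```python
from pathlib import Path

# Hypothetical stand-ins for the write/append functions; each one
# just reports which code path was taken.
def write(data, filename): return f"write:{filename}"
def write_training(data, filename): return f"write_training:{filename}"
def append(data, filename): return f"append:{filename}"
def append_training(data, filename): return f"append_training:{filename}"

def save_demo(data, filename, as_training=False, do_append=True):
    # Select the writer and appender once, then pick one based on
    # whether the target file already exists.
    writer = write_training if as_training else write
    appender = append_training if as_training else append
    op = appender if (do_append and Path(filename).exists()) else writer
    return op(data, filename)

# A nonexistent path always falls through to the writer.
print(save_demo("labels", "/nonexistent/out.nwb", as_training=True))
# → write_training:/nonexistent/out.nwb
```

Selecting the functions up front keeps the append-vs-write decision in one place, which is the readability gain the refactor is after.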

save_test.ipynb Outdated
Comment on lines 14 to 42
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'tests/data/slp/minimal_instance.pkg.slp'",
"output_type": "error",
"traceback": [
"---------------------------------------------------------------------------",
"KeyError                                  Traceback (most recent call last)",
"Cell In[3], line 4: labels_original.save(\"minimal_instance.pkg.nwb\", format=\"nwb_training\")",
"File ~/salk/io_fork/sleap_io/model/labels.py:372, in Labels.save: save_file(self, filename, format=format, embed=embed, **kwargs)",
"File ~/salk/io_fork/sleap_io/io/main.py:241, in save_file",
"File ~/salk/io_fork/sleap_io/io/main.py:92, in save_nwb",
"File ~/salk/io_fork/sleap_io/io/nwb.py:388, in write_nwb",
"File ~/salk/io_fork/sleap_io/io/nwb.py:471, in append_nwb_data: labels_data_df = convert_predictions_to_dataframe(labels)",
"File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/pandas/core/frame.py:4102, in DataFrame.__getitem__",
"File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/pandas/core/indexes/range.py:417, in RangeIndex.get_loc: raise KeyError(key)",
"KeyError: 'tests/data/slp/minimal_instance.pkg.slp'"
]
}
],
"source": [
"import sleap_io as sio\n",
"\n",
"labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n",
"labels_original.save(\"minimal_instance.pkg.nwb\", format=\"nwb_training\")\n",
"labels_loaded = sio.load_nwb(\"minimal_instance.pkg.nwb\")"
Copy link
Contributor


Handle file loading and saving errors more gracefully.

The notebook cell attempts to load and save files, but it results in a KeyError. This suggests that error handling needs to be improved to provide more informative messages or to ensure that the file paths are correct.

try:
    labels_original = sio.load_slp("tests/data/slp/minimal_instance.pkg.slp")
    labels_original.save("minimal_instance.pkg.nwb", format="nwb_training")
    labels_loaded = sio.load_nwb("minimal_instance.pkg.nwb")
except Exception as e:
    print(f"Error occurred: {e}")

save_test.ipynb Outdated
Comment on lines 46 to 90
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"minimal_instance.pkg.nwb_images/img_0.png\n"
]
},
{
"ename": "ValueError",
"evalue": "Can't write images with one color channel.",
"output_type": "error",
"traceback": [
"---------------------------------------------------------------------------",
"ValueError                                Traceback (most recent call last)",
"Cell In[2], line 16: iio.imwrite(img_path, lf.image)",
"File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/v3.py:147, in imwrite: encoded = img_file.write(image, **kwargs)",
"File ~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/plugins/pillow.py:433, in PillowPlugin.write: raise ValueError(\"Can't write images with one color channel.\")",
"ValueError: Can't write images with one color channel."
]
}
],
"source": [
"import imageio.v3 as iio\n",
"from pathlib import Path\n",
"import sleap_io as sio\n",
"\n",
"save_path = Path(\"minimal_instance.pkg.nwb_images\")\n",
"try:\n",
" save_path.mkdir(parents=True, exist_ok=True)\n",
"except Exception as e:\n",
" print(f\"An error {e} occurred. The directory could not be created.\")\n",
"img_paths = []\n",
"\n",
"labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n",
"for i, lf in enumerate(labels_original):\n",
" img_path = save_path / f\"img_{i}.png\"\n",
" print(img_path)\n",
" iio.imwrite(img_path, lf.image)\n",
" img_paths.append(img_path)\n",
"print(img_paths)"
]
Copy link
Contributor


Improve error handling for image writing.

This cell attempts to write images, but fails due to an unsupported single color channel. This issue should be caught and handled, or the notebook should ensure that only supported image formats are processed.

try:
    for i, lf in enumerate(labels_original):
        img_path = save_path / f"img_{i}.png"
        print(img_path)
        iio.imwrite(img_path, lf.image)
        img_paths.append(img_path)
except ValueError as e:
    print(f"Failed to write image: {e}")
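Beyond catching the error, the underlying cause, a `(height, width, 1)` grayscale array, can be fixed before writing by dropping the singleton channel axis. This is a sketch; `prepare_image_for_write` is a hypothetical helper, not part of sleap-io:

```python
import numpy as np

def prepare_image_for_write(img: np.ndarray) -> np.ndarray:
    """Drop a trailing singleton channel axis.

    Pillow-backed writers reject (H, W, 1) arrays but accept plain
    (H, W) grayscale, so squeeze the last axis when it has size 1.
    """
    if img.ndim == 3 and img.shape[-1] == 1:
        return img[:, :, 0]
    return img

gray = np.zeros((4, 4, 1), dtype=np.uint8)
print(prepare_image_for_write(gray).shape)  # (4, 4)
```

Passing `prepare_image_for_write(lf.image)` to `iio.imwrite` would avoid the ValueError entirely instead of merely reporting it.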

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between be6ccbc and 0aabd86.

Files selected for processing (3)
  • save_test.ipynb (1 hunks)
  • sleap_io/io/main.py (3 hunks)
  • sleap_io/io/nwb.py (10 hunks)
Files skipped from review as they are similar to previous changes (1)
  • save_test.ipynb
Additional context used
Ruff
sleap_io/io/main.py

76-76: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)

sleap_io/io/nwb.py

519-519: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)


520-520: Local variable provenance is assigned to but never used

Remove assignment to unused variable provenance

(F841)

Comment on lines 50 to 55
def load_nwb(filename: str, as_training: Optional[bool] = None) -> Labels:
    """Load an NWB dataset as a SLEAP `Labels` object.

    Args:
        filename: Path to a NWB file (`.nwb`).
        as_training: If `True`, load the dataset as a training dataset.
Copy link
Contributor


Simplify conditional logic in load_nwb.

The function now includes an as_training parameter to load datasets as training data. However, the logic can be simplified to avoid redundancy.

-  if as_training is None:
-    return nwb.read_nwb(filename)
-  
  if as_training:
    return nwb.read_nwb_training(filename)
  else:
    return nwb.read_nwb(filename)

Committable suggestion was skipped due to low confidence.

Comment on lines 63 to 86
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if append and Path(filename).exists():
        nwb.append_nwb(labels, filename)
    else:
        nwb.write_nwb(labels, filename)
    if as_training:
        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename)
        else:
            nwb.write_nwb(labels, filename)

    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)
Copy link
Contributor


Refactor save_nwb to reduce complexity and address previous comments.

The function has been updated to handle the as_training parameter. However, the current implementation is complex and can be simplified. Additionally, address the previous comment about reducing the depth of conditionals.

-    if as_training:
-        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
-        if append and Path(filename).exists():
-            nwb.append_nwb(labels, filename)
-        else:
-            nwb.write_nwb(labels, filename)
-
-    else:
-        if append and Path(filename).exists():
-            nwb.append_nwb(labels, filename, **kwargs)
-        else:
-            nwb.write_nwb(labels, filename)
+    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+    action = nwb.append_nwb_training if as_training else nwb.append_nwb
+    write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+    data = func(labels, **kwargs)
+    if append and Path(filename).exists():
+        action(data, filename, **kwargs)
+    else:
+        write(data, filename, **kwargs)
Committable suggestion


Suggested change
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    if append and Path(filename).exists():
        nwb.append_nwb(labels, filename)
    else:
        nwb.write_nwb(labels, filename)
    if as_training:
        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename)
        else:
            nwb.write_nwb(labels, filename)
    else:
        if append and Path(filename).exists():
            nwb.append_nwb(labels, filename, **kwargs)
        else:
            nwb.write_nwb(labels, filename)
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be
            created if it does not exist.

    See also: nwb.write_nwb, nwb.append_nwb
    """
    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
    action = nwb.append_nwb_training if as_training else nwb.append_nwb
    write = nwb.write_nwb_training if as_training else nwb.write_nwb
    data = func(labels, **kwargs)
    if append and Path(filename).exists():
        action(data, filename, **kwargs)
    else:
        write(data, filename, **kwargs)
Tools
Ruff

76-76: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

Outside diff range, codebase verification and nitpick comments (8)
sleap_io/io/nwb.py (8)

Line range hint 238-357: Break down read_nwb and add error handling.

The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def read_nwb(path: str) -> Labels:
+    if not isinstance(path, str) or not path.endswith(".nwb"):
+        raise ValueError("Invalid path: Expected a .nwb file.")

Line range hint 359-414: Add error handling and type annotations to write_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")

Line range hint 417-491: Break down append_nwb_data and add error handling.

The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def append_nwb_data(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
+    if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected a Labels object and an NWBFile object.")

Line range hint 510-536: Add error handling and type annotations to append_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb(
    labels: Labels, filename: str, pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None
):
+    if not isinstance(labels, Labels) or not isinstance(filename, str):
+        raise ValueError("Invalid input: Expected a Labels object and a filename string.")

Line range hint 539-552: Add error handling and type annotations to get_processing_module_for_video.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def get_processing_module_for_video(
    processing_module_name: str, nwbfile: NWBFile
) -> ProcessingModule:
+    if not isinstance(processing_module_name, str) or not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected a processing module name string and an NWBFile object.")

Line range hint 555-636: Break down build_pose_estimation_container_for_track and add error handling.

The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def build_pose_estimation_container_for_track(
    labels_data_df: pd.DataFrame,
    labels: Labels,
    track_name: str,
    video: Video,
    pose_estimation_metadata: dict,
) -> PoseEstimation:
+    if not isinstance(labels_data_df, pd.DataFrame) or not isinstance(labels, Labels) or not isinstance(track_name, str) or not isinstance(video, Video):
+        raise ValueError("Invalid input: Expected a DataFrame, Labels object, track name string, and Video object.")

Line range hint 638-686: Add error handling and type annotations to build_track_pose_estimation_list.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike
) -> List[PoseEstimationSeries]:
+    if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, np.ndarray):
+        raise ValueError("Invalid input: Expected a DataFrame and an ndarray.")

Line range hint 238-246: Add error handling and type annotations to get_timestamps.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
+    if series is None or not hasattr(series, 'timestamps') or not hasattr(series, 'data'):
+        raise ValueError("Invalid series: Missing required data.")
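The guard above can be combined with the rate fallback that NWB `TimeSeries`-style objects support. `PoseSeries` below is a minimal hypothetical stand-in for `PoseEstimationSeries`, not the real ndx-pose type; the timestamps-or-rate logic is a sketch of the usual NWB convention.

```python
from dataclasses import dataclass
from typing import Optional
import numpy as np

@dataclass
class PoseSeries:
    """Minimal stand-in for PoseEstimationSeries (not the real ndx-pose type)."""
    data: np.ndarray
    timestamps: Optional[np.ndarray] = None
    rate: Optional[float] = None
    starting_time: float = 0.0

def get_timestamps(series) -> np.ndarray:
    # Guard first, as the review suggests, then prefer explicit timestamps
    # and fall back to reconstructing them from rate + starting_time.
    if series is None or not hasattr(series, "timestamps") or not hasattr(series, "data"):
        raise ValueError("Invalid series: Missing required data.")
    if series.timestamps is not None:
        return np.asarray(series.timestamps)
    return series.starting_time + np.arange(len(series.data)) / series.rate

series = PoseSeries(data=np.zeros((4, 2)), rate=2.0)
print(get_timestamps(series))
```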
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 0aabd86 and 85d47ab.

Files selected for processing (2)
  • sleap_io/io/main.py (3 hunks)
  • sleap_io/io/nwb.py (11 hunks)
Files skipped from review as they are similar to previous changes (1)
  • sleap_io/io/main.py
Additional context used
Ruff
sleap_io/io/nwb.py

504-504: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)


505-505: Local variable provenance is assigned to but never used

Remove assignment to unused variable provenance

(F841)

Comment on lines 494 to 507
def append_nwb_training(labels: Labels, nwbfile_path: str) -> NWBFile: # type: ignore[return]
"""Append a PoseTraining object to an existing NWB data file.
Args:
pose_training: A PoseTraining object.
nwbfile_path: The path to the NWB file.
Returns:
An in-memory NWB file with the PoseTraining data appended.
"""
pose_training = labels_to_pose_training(labels)
provenance = labels.provenance
raise NotImplementedError


Implement or document the append_nwb_training function.

Currently, the function raises NotImplementedError. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

- raise NotImplementedError
+ # TODO: Implement or document this function
Tools
Ruff

504-504: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)


505-505: Local variable provenance is assigned to but never used

Remove assignment to unused variable provenance

(F841)

Comment on lines 75 to 90
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: ignore[return]
"""Converts an NWB skeleton to a SLEAP skeleton.
Args:
skeleton: An NWB skeleton.
Returns:
A SLEAP skeleton.
"""
nodes = [Node(name=node) for node in skeleton.nodes]
edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
return SLEAPSkeleton(
nodes=nodes,
edges=edges,
name=skeleton.name,
)

Add error handling and type annotations to nwb_skeleton_to_sleap.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+    if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+        raise ValueError("Invalid skeleton: Missing required data.")

Comment on lines 220 to 236
def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
"""Write a SLEAP package to an NWB file.
Args:
filename: The path to the SLEAP package.
labels: The SLEAP Labels object.
"""
assert filename.endswith(".pkg.slp")

path = filename.split(".slp")[0]
save_path = Path(path + ".nwb_images")
img_paths = []
for i, labeled_frame in enumerate(labels.labeled_frames):
img_path = save_path / f"frame_{i}.png"
imwrite(img_path, labeled_frame.image)
img_paths.append(img_path)


Add error handling and type annotations to sleap_pkg_to_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
+    if not filename.endswith(".pkg.slp"):
+        raise ValueError("Invalid filename: Expected a .pkg.slp file.")

Committable suggestion was skipped due to low confidence.

Comment on lines 195 to 217
def videos_to_source_videos(videos: List[Video]) -> SourceVideos: # type: ignore[return]
"""Converts a list of SLEAP Videos to NWB SourceVideos.
Args:
videos: A list of SLEAP Videos.
Returns:
An NWB SourceVideos object.
"""
source_videos = []
for i, video in enumerate(videos):
image_series = ImageSeries(
name=f"video_{i}",
description="Video file",
unit="NA",
format="external",
external_file=[video.filename],
dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
starting_frame=[0],
rate=30.0, # TODO - change to `video.backend.fps` when available
)
source_videos.append(image_series)
return SourceVideos(image_series=source_videos)

Add error handling and type annotations to videos_to_source_videos.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def videos_to_source_videos(videos: List[Video]) -> SourceVideos:
+    if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+        raise ValueError("Invalid input: Expected a list of Video objects.")
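The validation can be paired with a defensive fallback for the hard-coded `rate=30.0` flagged in the TODO. `VideoStub` and `BackendStub` are hypothetical stand-ins modeling only the attributes `videos_to_source_videos` touches; the real sleap-io `Video` class is not imported here.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-ins for sleap_io Video and its backend.
@dataclass
class BackendStub:
    img_shape: tuple
    fps: Optional[float] = None

@dataclass
class VideoStub:
    filename: str
    backend: BackendStub

def validate_videos(videos: List[VideoStub]) -> None:
    if not isinstance(videos, list) or not all(isinstance(v, VideoStub) for v in videos):
        raise ValueError("Invalid input: Expected a list of Video objects.")

def video_rate(video: VideoStub, default_fps: float = 30.0) -> float:
    # Prefer a backend-reported fps; fall back to the hard-coded default the
    # current code uses, so the TODO can be retired without breaking callers.
    fps = getattr(video.backend, "fps", None)
    return float(fps) if fps else default_fps

videos = [VideoStub("a.mp4", BackendStub((480, 640))),
          VideoStub("b.mp4", BackendStub((480, 640), fps=25.0))]
validate_videos(videos)
print([video_rate(v) for v in videos])  # [30.0, 25.0]
```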

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 10

Outside diff range, codebase verification and nitpick comments (1)
sleap_io/io/nwb.py (1)

Line range hint 641-678:
Add error handling and type annotations to build_track_pose_estimation_list.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike  # type: ignore[return]
) -> List[PoseEstimationSeries]:
+    if not isinstance(track_data_df, pd.DataFrame):
+        raise ValueError("Invalid input: Expected a pandas DataFrame.")
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 85d47ab and 1b8e08b.

Files selected for processing (2)
  • sleap_io/io/main.py (3 hunks)
  • sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

505-505: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)


506-506: Local variable provenance is assigned to but never used

Remove assignment to unused variable provenance

(F841)

Comment on lines 63 to 91
def save_nwb(
labels: Labels,
filename: str,
as_training: bool = None,
append: bool = True,
**kwargs,
):
"""Save a SLEAP dataset to NWB format.
Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to NWB file to save to. Must end in `.nwb`.
as_training: If `True`, save the dataset as a training dataset.
append: If `True` (the default), append to existing NWB file. File will be
created if it does not exist.
See also: nwb.write_nwb, nwb.append_nwb
"""
if append and Path(filename).exists():
nwb.append_nwb(labels, filename)
if as_training:
if append and Path(filename).exists():
nwb.append_nwb_training(labels, filename)
else:
nwb.write_nwb(labels, filename, None, None, True)

else:
nwb.write_nwb(labels, filename)
if append and Path(filename).exists():
nwb.append_nwb(labels, filename, **kwargs)
else:
nwb.write_nwb(labels, filename)

Refactor save_nwb to reduce complexity and address previous comments.

The function has been updated to handle the as_training parameter. However, the implementation still nests conditionals more deeply than necessary; simplifying it would also address the earlier comment about reducing conditional depth.

-    if as_training:
-        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
-        if append and Path(filename).exists():
-            nwb.append_nwb_training(pose_training, filename, **kwargs)
-        else:
-            nwb.write_nwb_training(pose_training, filename, **kwargs)
-
-    else:
-        if append and Path(filename).exists():
-            nwb.append_nwb(labels, filename, **kwargs)
-        else:
-            nwb.write_nwb(labels, filename)
+    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+    action = nwb.append_nwb_training if as_training else nwb.append_nwb
+    write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+    data = func(labels, **kwargs)
+    if append and Path(filename).exists():
+        action(data, filename, **kwargs)
+    else:
+        write(data, filename, **kwargs)

Comment on lines 195 to 218
def videos_to_source_videos(videos: List[Video]) -> SourceVideos: # type: ignore[return]
"""Converts a list of SLEAP Videos to NWB SourceVideos.
Args:
videos: A list of SLEAP Videos.
Returns:
An NWB SourceVideos object.
"""
source_videos = []
for i, video in enumerate(videos):
image_series = ImageSeries(
name=f"video_{i}",
description="Video file",
unit="NA",
format="external",
external_file=[video.filename],
dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
starting_frame=[0],
rate=30.0, # TODO - change to `video.backend.fps` when available
)
source_videos.append(image_series)
return SourceVideos(image_series=source_videos)


Add error handling and type annotations to videos_to_source_videos.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def videos_to_source_videos(videos: List[Video]) -> SourceVideos:
+    if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+        raise ValueError("Invalid input: Expected a list of Video objects.")

Committable suggestion was skipped due to low confidence.

Comment on lines 93 to 145
def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining: # type: ignore[return]
"""Creates an NWB PoseTraining object from a Labels object.
Args:
labels: A Labels object.
filename: The filename of the source video.
Returns:
A PoseTraining object.
"""
training_frame_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
training_frame_name = name_generator("training_frame")
training_frame_annotator = f"{training_frame_name}{i}"
skeleton_instances_list = []
for instance in labeled_frame.instances:
if isinstance(instance, PredictedInstance):
continue
skeleton_instance = instance_to_skeleton_instance(instance)
skeleton_instances_list.append(skeleton_instance)

training_frame_skeleton_instances = SkeletonInstances(
skeleton_instances=skeleton_instances_list
)
training_frame_video = labeled_frame.video
training_frame_video_index = labeled_frame.frame_idx
training_frame = TrainingFrame(
name=training_frame_name,
annotator=training_frame_annotator,
skeleton_instances=training_frame_skeleton_instances,
source_video=ImageSeries(
name=training_frame_name,
description=training_frame_annotator,
unit="NA",
format="external",
external_file=[training_frame_video.filename],
dimension=[
training_frame_video.shape[1],
training_frame_video.shape[2],
],
starting_frame=[0],
rate=30.0,
),
source_video_frame_index=training_frame_video_index,
)
training_frame_list.append(training_frame)

training_frames = TrainingFrames(training_frames=training_frame_list)
pose_training = PoseTraining(
training_frames=training_frames,
source_videos=videos_to_source_videos(labels.videos),
)
return pose_training

Break down labels_to_pose_training and add error handling.

The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining:
+    if labels is None:
+        raise ValueError("Labels object cannot be None.")
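One way to break the per-frame work out of the loop, as the comment suggests: a small helper builds each training frame, and the top-level function reduces to validation plus a comprehension. All classes here are lightweight hypothetical stand-ins, not the real sleap-io or ndx-pose types.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FrameStub:
    """Stand-in for LabeledFrame: frame index plus user-labeled instances."""
    frame_idx: int
    instances: List[str]

@dataclass
class TrainingFrameStub:
    """Stand-in for an ndx-pose TrainingFrame."""
    name: str
    source_video_frame_index: int
    n_instances: int

def build_training_frame(i: int, frame: FrameStub) -> TrainingFrameStub:
    # Per-frame construction lives here, keeping the caller short.
    if frame is None:
        raise ValueError("Labeled frame cannot be None.")
    return TrainingFrameStub(
        name=f"training_frame_{i}",
        source_video_frame_index=frame.frame_idx,
        n_instances=len(frame.instances),
    )

def labels_to_training_frames(frames: List[FrameStub]) -> List[TrainingFrameStub]:
    if frames is None:
        raise ValueError("Labels object cannot be None.")
    return [build_training_frame(i, f) for i, f in enumerate(frames)]

tfs = labels_to_training_frames([FrameStub(0, ["a"]), FrameStub(5, ["a", "b"])])
print([(t.name, t.source_video_frame_index) for t in tfs])
```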

Comment on lines 220 to 236
def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
"""Write a SLEAP package to an NWB file.
Args:
filename: The path to the SLEAP package.
labels: The SLEAP Labels object.
"""
assert filename.endswith(".pkg.slp")

path = filename.split(".slp")[0]
save_path = Path(path + ".nwb_images")
img_paths = []
for i, labeled_frame in enumerate(labels.labeled_frames):
img_path = save_path / f"frame_{i}.png"
imwrite(img_path, labeled_frame.image)
img_paths.append(img_path)


Add error handling and type annotations to sleap_pkg_to_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
+    if not filename.endswith(".pkg.slp"):
+        raise ValueError("Invalid filename: Expected a .pkg.slp file.")

Committable suggestion was skipped due to low confidence.

Comment on lines 493 to 508
def append_nwb_training(
labels: Labels, nwbfile_path: str, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
"""Append a PoseTraining object to an existing NWB data file.
Args:
pose_training: A PoseTraining object.
nwbfile_path: The path to the NWB file.
Returns:
An in-memory NWB file with the PoseTraining data appended.
"""
pose_training = labels_to_pose_training(labels)
provenance = labels.provenance
raise NotImplementedError


Implement or document the append_nwb_training function.

Currently, the function raises NotImplementedError. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

- raise NotImplementedError
+ # TODO: Implement or document this function
Tools
Ruff

505-505: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)


506-506: Local variable provenance is assigned to but never used

Remove assignment to unused variable provenance

(F841)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Outside diff range, codebase verification and nitpick comments (3)
sleap_io/io/nwb.py (3)

Line range hint 358-415:
Review the function write_nwb.

This function writes data from a PoseTraining object to an NWB file. The function setup and the handling of optional metadata are done correctly. The use of append_nwb_data to integrate the training data into the NWB file is appropriate.

However, the function lacks error handling for potential issues during the file writing process. Consider adding error handling mechanisms to improve robustness.

+  try:
       nwbfile = NWBFile(**nwb_file_kwargs)
       nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)
       with NWBHDF5IO(str(nwbfile_path), "w") as io:
           io.write(nwbfile)
+  except Exception as e:
+      # Handle potential file writing errors
+      print(f"Error writing NWB file: {e}")

Line range hint 522-551:
Add error handling to append_nwb.

The function lacks error handling for potential issues during the file writing process. Consider adding error handling mechanisms to improve robustness.

+  try:
       if as_training:
           with NWBHDF5IO(filename, mode="a", load_namespaces=True) as io:
               nwb_file = io.read()
               nwb_file = append_nwb_training(
                   labels, nwb_file, pose_estimation_metadata=pose_estimation_metadata
               )
               io.write(nwb_file)
       else:
           with NWBHDF5IO(filename, mode="a", load_namespaces=True) as io:
               nwb_file = io.read()
               nwb_file = append_nwb_data(
                   labels, nwb_file, pose_estimation_metadata=pose_estimation_metadata
               )
               io.write(nwb_file)
+  except Exception as e:
+      # Handle potential file writing errors
+      print(f"Error appending to NWB file: {e}")

Line range hint 652-694:
Add error handling and type annotations to build_track_pose_estimation_list.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike  # type: ignore[return]
) -> List[PoseEstimationSeries]:
+    if track_data_df is None or timestamps is None:
+        raise ValueError("Invalid input: track_data_df and timestamps cannot be None.")
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 1b8e08b and aab5a79.

Files selected for processing (1)
  • sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

511-511: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)


518-518: Local variable labels_data_df is assigned to but never used

Remove assignment to unused variable labels_data_df

(F841)

Additional comments not posted (1)
sleap_io/io/nwb.py (1)

49-72: Incomplete functionality in pose_training_to_labels.

The function contains a TODO comment indicating incomplete functionality for determining frame_idx from training_frame.

-  frame_idx = training_frame # TODO
+  frame_idx = training_frame.frame_index  # Assuming `frame_index` is an attribute of `TrainingFrame`

Do you want me to assist in completing this functionality or open a GitHub issue to track this task?

Likely invalid or redundant comment.

Comment on lines 494 to 519
def append_nwb_training(
    labels: Labels, nwbfile_path: str, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version
    pose_training = labels_to_pose_training(labels)

    for lf in labels.labeled_frames:
        if lf.has_predicted_instances:
            labels_data_df = convert_predictions_to_dataframe(labels)
            break
    else:
        labels_data_df = pd.DataFrame()
    raise NotImplementedError
Contributor

Implement or document the append_nwb_training function.

Currently, the function raises NotImplementedError. This needs to be addressed either by implementing the function or by documenting it as a known limitation.

- raise NotImplementedError
+ # TODO: Implement or document this function
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

def append_nwb_training(
    labels: Labels, nwbfile_path: str, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    """Append a PoseTraining object to an existing NWB data file.

    Args:
        pose_training: A PoseTraining object.
        nwbfile_path: The path to the NWB file.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version
    pose_training = labels_to_pose_training(labels)

    for lf in labels.labeled_frames:
        if lf.has_predicted_instances:
            labels_data_df = convert_predictions_to_dataframe(labels)
            break
    else:
        labels_data_df = pd.DataFrame()
    # TODO: Implement or document this function
Tools
Ruff

511-511: Local variable pose_training is assigned to but never used

Remove assignment to unused variable pose_training

(F841)


518-518: Local variable labels_data_df is assigned to but never used

Remove assignment to unused variable labels_data_df

(F841)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Outside diff range, codebase verification and nitpick comments (6)
sleap_io/io/nwb.py (6)

Line range hint 361-414: Add error handling and type annotations to write_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")

Line range hint 446-494: Add error handling and type annotations to append_nwb_data.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb_data(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
+    if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected Labels and NWBFile objects.")

Line range hint 543-571: Add error handling and type annotations to append_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb(
    labels: Labels,
    filename: str,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+    if not isinstance(labels, Labels) or not isinstance(filename, str):
+        raise ValueError("Invalid input: Expected Labels object and filename string.")

Line range hint 573-591: Add type annotations to get_processing_module_for_video.

The function lacks type annotations for better readability and maintainability.

def get_processing_module_for_video(
    processing_module_name: str, nwbfile: NWBFile
) -> ProcessingModule:

Line range hint 593-671: Add error handling and type annotations to build_pose_estimation_container_for_track.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_pose_estimation_container_for_track(
    labels_data_df: pd.DataFrame,
    labels: Labels,
    track_name: str,
    video: Video,
    pose_estimation_metadata: dict,
) -> PoseEstimation:
+    if not isinstance(labels_data_df, pd.DataFrame) or not isinstance(labels, Labels) or not isinstance(track_name, str) or not isinstance(video, Video) or not isinstance(pose_estimation_metadata, dict):
+        raise ValueError("Invalid input: Expected appropriate data types for inputs.")

Line range hint 673-721: Add error handling and type annotations to build_track_pose_estimation_list.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike
) -> List[PoseEstimationSeries]:
+    if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, np.ndarray):
+        raise ValueError("Invalid input: Expected a pandas DataFrame and a numpy array.")
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between aab5a79 and 66edcd6.

Files selected for processing (1)
  • sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

529-529: Local variable camera is assigned to but never used

Remove assignment to unused variable camera

(F841)


538-538: Local variable labels_data_df is assigned to but never used

Remove assignment to unused variable labels_data_df

(F841)

Additional comments not posted (1)
sleap_io/io/nwb.py (1)

51-74: Add error handling and type annotations to pose_training_to_labels.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def pose_training_to_labels(pose_training: PoseTraining) -> Labels:
+    if not isinstance(pose_training, PoseTraining):
+        raise ValueError("Invalid input: Expected a PoseTraining object.")

Likely invalid or redundant comment.

Comment on lines 496 to 539
def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject
    pose_training = labels_to_pose_training(labels)

    behavior_pm = nwbfile.create_processing_module(
        name="behavior",
        description="Behavioral data",
    )
    behavior_pm.add(pose_training)

    skeletons_list = [slp_skeleton_to_nwb(skeleton) for skeleton in labels.skeletons]
    skeletons = Skeletons(skeletons=skeletons_list)
    behavior_pm.add(skeletons)

    camera = nwbfile.create_device(
        name="camera",
        description="Camera used to record the video",
        manufacturer="No specified manufacturer",
    )

    for lf in labels.labeled_frames:
        if lf.has_predicted_instances:
            labels_data_df = convert_predictions_to_dataframe(labels)
            break
    else:
        labels_data_df = pd.DataFrame()
    return nwbfile
Contributor

Add error handling and type annotations to append_nwb_training.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
+    if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected Labels and NWBFile objects.")
Suggested change

def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
    if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
        raise ValueError("Invalid input: Expected Labels and NWBFile objects.")
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version
    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject
    pose_training = labels_to_pose_training(labels)
    behavior_pm = nwbfile.create_processing_module(
        name="behavior",
        description="Behavioral data",
    )
    behavior_pm.add(pose_training)
    skeletons_list = [slp_skeleton_to_nwb(skeleton) for skeleton in labels.skeletons]
    skeletons = Skeletons(skeletons=skeletons_list)
    behavior_pm.add(skeletons)
    camera = nwbfile.create_device(
        name="camera",
        description="Camera used to record the video",
        manufacturer="No specified manufacturer",
    )
    for lf in labels.labeled_frames:
        if lf.has_predicted_instances:
            labels_data_df = convert_predictions_to_dataframe(labels)
            break
    else:
        labels_data_df = pd.DataFrame()
    return nwbfile
Tools
Ruff

529-529: Local variable camera is assigned to but never used

Remove assignment to unused variable camera

(F841)


538-538: Local variable labels_data_df is assigned to but never used

Remove assignment to unused variable labels_data_df

(F841)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Outside diff range, codebase verification and nitpick comments (5)
sleap_io/io/nwb.py (5)

Line range hint 361-414:
Add error handling and type annotations to write_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+    if not isinstance(labels, Labels) or not isinstance(nwbfile_path, str):
+        raise ValueError("Invalid input: Expected Labels object and a string path.")

Line range hint 417-493:
Add error handling and type annotations to append_nwb_data.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb_data(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
+    if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected Labels and NWBFile objects.")

Line range hint 576-604:
Add error handling and type annotations to append_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb(
    labels: Labels,
    filename: str,
    pose_estimation_metadata: Optional[dict] = None,
    as_training: Optional[bool] = None,
):
+    if not isinstance(labels, Labels) or not isinstance(filename, str):
+        raise ValueError("Invalid input: Expected Labels object and a string filename.")

Line range hint 606-690:
Add error handling and type annotations to build_pose_estimation_container_for_track.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_pose_estimation_container_for_track(
    labels_data_df: pd.DataFrame,
    labels: Labels,
    track_name: str,
    video: Video,
    pose_estimation_metadata: dict,
) -> PoseEstimation:
+    if not isinstance(labels_data_df, pd.DataFrame) or not isinstance(labels, Labels) or not isinstance(track_name, str) or not isinstance(video, Video):
+        raise ValueError("Invalid input: Expected DataFrame, Labels object, string track name, and Video object.")

Line range hint 706-749:
Add error handling and type annotations to build_track_pose_estimation_list.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike
) -> List[PoseEstimationSeries]:
+    if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, np.ndarray):
+        raise ValueError("Invalid input: Expected DataFrame and ndarray.")
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 66edcd6 and 1b70d5c.

Files selected for processing (1)
  • sleap_io/io/nwb.py (11 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

551-551: Local variable pose_estimation is assigned to but never used

Remove assignment to unused variable pose_estimation

(F841)


571-571: Local variable labels_data_df is assigned to but never used

Remove assignment to unused variable labels_data_df

(F841)

Additional comments not posted (7)
sleap_io/io/nwb.py (7)

77-93: Add error handling and type annotations to nwb_skeleton_to_sleap.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+    if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+        raise ValueError("Invalid skeleton: Missing required data.")

95-147: Break down labels_to_pose_training and add error handling.

The function is long and performs multiple tasks. Consider breaking it down into smaller functions and adding error handling for invalid inputs.

def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining:
+    if labels is None:
+        raise ValueError("Labels object cannot be None.")

150-172: Add error handling and type annotations to slp_skeleton_to_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton:
+    if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+        raise ValueError("Invalid skeleton: Missing required data.")

175-194: Add error handling and type annotations to instance_to_skeleton_instance.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def instance_to_skeleton_instance(instance: Instance) -> SkeletonInstance:
+    if instance is None or not hasattr(instance, 'skeleton') or not hasattr(instance, 'points'):
+        raise ValueError("Invalid instance: Missing required data.")

197-219: Add error handling and type annotations to videos_to_source_videos.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def videos_to_source_videos(videos: List[Video]) -> SourceVideos:
+    if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+        raise ValueError("Invalid input: Expected a list of Video objects.")

222-238: Add error handling and type annotations to sleap_pkg_to_nwb.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def sleap_pkg_to_nwb(filename: str, labels: Labels, **kwargs):
+    if not filename.endswith(".pkg.slp"):
+        raise ValueError("Invalid filename: Expected a .pkg.slp file.")

496-572: Add error handling and type annotations to append_nwb_training.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def append_nwb_training(
    labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
) -> NWBFile:
+    if not isinstance(labels, Labels) or not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected Labels and NWBFile objects.")
Tools
Ruff

551-551: Local variable pose_estimation is assigned to but never used

Remove assignment to unused variable pose_estimation

(F841)


571-571: Local variable labels_data_df is assigned to but never used

Remove assignment to unused variable labels_data_df

(F841)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between e932340 and 6f3aabe.

Files selected for processing (1)
  • sleap_io/io/nwb.py (15 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)
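For context, the pattern Ruff's SIM105 rule points at can be sketched as follows. This is a generic illustration of the refactor, not the module's actual import block:

```python
import os
from contextlib import suppress

# Equivalent to:
#   try:
#       import cv2
#   except ImportError:
#       pass
# but states the intent ("this import is optional") in a single statement.
with suppress(ImportError):
    import cv2  # optional dependency; callers must check availability before use

# suppress() works for any exception type, e.g. ignoring a missing file:
with suppress(FileNotFoundError):
    os.remove("nonexistent.tmp")
```

Note that `suppress` silently swallows only the listed exception types; anything else still propagates, which keeps the behavior identical to the original try-except-pass.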

Additional comments not posted (13)
sleap_io/io/nwb.py (13)

61-69: Add error handling for invalid inputs.

The function lacks error handling for invalid pose_training inputs, which can improve robustness and maintainability.


75-78: Handle multiple external files.

The current implementation only supports single-file external videos. Consider implementing logic to handle multiple external files.


82-92: Optimize instance processing.

The loop processing the instances can be optimized by using list comprehensions.


99-115: Add error handling for invalid inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.


117-166: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.


168-194: Add error handling for invalid inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.


197-220: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.


215-216: Complete the TODO comment for instance tracking.

The function contains a TODO comment indicating incomplete functionality for tracking the number of instances.


223-246: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.


242-242: Complete the TODO comment for video.backend.fps.

The function contains a TODO comment indicating the need to change the rate to video.backend.fps when available.


314-314: Complete the TODO comment for video.backend.fps.

The function contains a TODO comment indicating the need to change the rate to video.backend.fps when available.


618-719: Add error handling for invalid inputs.

The function lacks error handling for invalid labels, nwbfile, pose_estimation_metadata, frame_inds, and frame_path inputs, which can improve robustness and maintainability.


857-857: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

Outside diff range, codebase verification and nitpick comments (5)
tests/io/test_main.py (1)

40-42: Add assertions to verify the saved NWB file.

The test_nwb_training function currently lacks assertions to verify the correctness of the operation. Consider adding checks to ensure that the labels are saved and loaded correctly.

loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert isinstance(loaded_labels, Labels)
assert len(loaded_labels) == len(labels)
sleap_io/io/nwb.py (4)

17-20: Use contextlib.suppress for import suppression.

Replace the try-except-pass block with contextlib.suppress to improve readability.

-try:
-    import cv2
-except ImportError:
-    pass
+from contextlib import suppress
+with suppress(ImportError):
+    import cv2
Tools
Ruff

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)


137-139: Optimize instance processing.

The loop processing the instances can be optimized by using list comprehensions.

-        for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
-            skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
-            skeleton_instances_list.append(skeleton_instance)
+        skeleton_instances_list.extend(
+            instance_to_skeleton_instance(instance, skeleton)
+            for instance, skeleton in zip(labeled_frame.instances, skeletons_list)
+        )

183-187: Optimize edge processing logic.

The current implementation iterates over skeleton edges multiple times. Consider optimizing the logic to reduce complexity.

-    skeleton_edges = dict(enumerate(skeleton.nodes))
-    for i, source in skeleton_edges.items():
-        for destination in list(skeleton_edges.values())[i:]:
-            if Edge(source, destination) in skeleton.edges:
-                nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])
+    node_indices = {node: idx for idx, node in enumerate(skeleton.nodes)}
+    nwb_edges = [
+        [node_indices[edge.source], node_indices[edge.destination]]
+        for edge in skeleton.edges
+    ]
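The dictionary-based lookup in that suggestion replaces repeated `list.index` scans (O(n) each) with O(1) hash lookups. A small demonstration on plain data, where `Edge` and the node names are stand-ins for sleap-io's types:

```python
from collections import namedtuple

# Stand-in for sleap-io's Edge (which has .source and .destination nodes).
Edge = namedtuple("Edge", ["source", "destination"])

nodes = ["head", "thorax", "abdomen"]
edges = [Edge("head", "thorax"), Edge("thorax", "abdomen")]

# Build the node -> index map once, then resolve each edge endpoint
# with a single dict lookup instead of scanning the node list.
node_indices = {node: idx for idx, node in enumerate(nodes)}
nwb_edges = [
    [node_indices[edge.source], node_indices[edge.destination]]
    for edge in edges
]

assert nwb_edges == [[0, 1], [1, 2]]
```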

711-712: Remove unused variables processing_module_name and nwb_processing_module.

These variables are assigned but never used, which can lead to confusion and clutter.

-        processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
-        nwb_processing_module = get_processing_module_for_video(
-            processing_module_name, nwbfile
-        )
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 6f3aabe and 7a83046.

Files selected for processing (2)
  • sleap_io/io/nwb.py (17 hunks)
  • tests/io/test_main.py (1 hunks)
Additional context used
Ruff
tests/io/test_main.py

29-29: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


30-30: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


36-36: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Additional comments not posted (7)
tests/io/test_main.py (1)

25-30: Use isinstance() for type checks.

Replace direct type comparisons with isinstance() for better readability and performance.

- assert type(loaded_labels) == Labels
- assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels
+ assert isinstance(loaded_labels, Labels)
+ assert isinstance(load_file(tmp_path / "test_nwb.nwb"), Labels)
Tools
Ruff

29-29: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


30-30: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)
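The practical difference behind E721 is that `type()` equality ignores inheritance while `isinstance()` honors it, so subclass instances pass the check. A minimal illustration (`Labels` here is a stand-in class, not the real `sleap_io.Labels`):

```python
class Labels:  # stand-in for the real class under test
    pass

class SubLabels(Labels):  # hypothetical subclass
    pass

obj = SubLabels()

# type() comparison rejects subclasses:
assert (type(obj) == Labels) is False

# isinstance() accepts them, which is almost always the intended semantics:
assert isinstance(obj, Labels)
```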

sleap_io/io/nwb.py (6)

Line range hint 778-809: LGTM!

The function correctly appends a SLEAP Labels object to an existing NWB data file.


197-220: Add error handling for invalid inputs.

The function lacks error handling for invalid instance and skeleton inputs, which can improve robustness and maintainability.

+    if not isinstance(instance, Instance):
+        raise ValueError("Invalid input: Expected an Instance object.")
+    if not isinstance(skeleton, NWBSkeleton):
+        raise ValueError("Invalid input: Expected an NWBSkeleton object.")

Likely invalid or redundant comment.


915-915: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+    if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, (np.ndarray, list)):
+        raise ValueError("Invalid input: Expected a pandas DataFrame and an array-like object.")

Likely invalid or redundant comment.


676-695: Add error handling for invalid inputs.

The function lacks error handling for invalid labels, nwbfile, pose_estimation_metadata, frame_inds, and frame_path inputs, which can improve robustness and maintainability.

+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")
+    if not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected an NWBFile object.")
+    if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+        raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+    if frame_inds is not None and not isinstance(frame_inds, list):
+        raise ValueError("Invalid input: Expected a list for frame_inds.")
+    if frame_path is not None and not isinstance(frame_path, str):
+        raise ValueError("Invalid input: Expected a string for frame_path.")

Likely invalid or redundant comment.


75-78: Handle multiple external files.

The current implementation only supports single-file external videos. Consider implementing logic to handle multiple external files.

-        if source_video.format == "external" and len(source_video.external_file) == 1:
-            video = Video(source_video.external_file[0])
-        else:
-            raise NotImplementedError("Only single-file external videos are supported.")
+        if source_video.format == "external":
+            if len(source_video.external_file) == 1:
+                video = Video(source_video.external_file[0])
+            else:
+                # Implement logic for handling multiple external files
+                raise NotImplementedError("Multiple external files are not yet supported.")

Likely invalid or redundant comment.


117-132: Add error handling for invalid inputs.

The function lacks error handling for invalid labels, skeletons_list, and video_info inputs, which can improve robustness and maintainability.

+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")
+    if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list):
+        raise ValueError("Invalid input: Expected a list of NWBSkeleton objects.")
+    if not isinstance(video_info, tuple) or len(video_info) != 3:
+        raise ValueError("Invalid input: Expected a tuple with three elements (dict, Video, ImageSeries).")

Likely invalid or redundant comment.

Comment on lines 215 to 216
# TODO add a counter in the loop to track the number of instances
# instead of using id

Complete the TODO comment for instance tracking.

The function contains a TODO comment indicating incomplete functionality for tracking the number of instances.

Do you want me to assist in completing this functionality or open a GitHub issue to track this task?
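A minimal sketch of the counter-based naming the TODO asks for (illustrative stand-ins only, not the sleap_io classes; `instances` stands in for `labeled_frame.instances`):

```python
# Illustrative sketch: name instances with a running per-frame counter
# instead of id(instance), so names are stable and human-readable.
def name_instances(instances):
    """Yield (name, instance) pairs using a stable counter."""
    for counter, instance in enumerate(instances):
        yield f"skeleton_instance_{counter}", instance

names = [name for name, _ in name_instances(["inst_a", "inst_b"])]
```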

Comment on lines +168 to +179
def slp_skeleton_to_nwb(
skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton: # type: ignore[return]
"""Converts SLEAP skeleton to NWB skeleton.
Args:
skeleton: A SLEAP skeleton.
subject: An NWB subject.
Returns:
An NWB skeleton.
"""

Add error handling for invalid inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.

+    if not isinstance(skeleton, SLEAPSkeleton):
+        raise ValueError("Invalid input: Expected a SLEAPSkeleton object.")

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Outside diff range, codebase verification and nitpick comments (2)
tests/io/test_main.py (1)

34-36: Add assertions to verify test outcomes.

The test_nwb_training function lacks assertions to verify that the save_nwb operation has the expected effect. Consider adding checks similar to those in test_nwb to ensure the saved and loaded data match the expected structure and content.
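The round-trip assertion pattern being suggested can be sketched with stand-ins (json and a temp directory stand in for `save_nwb`/`load_nwb`, whose real signatures are not assumed here):

```python
# Stand-in round-trip check showing the assertion pattern suggested for
# test_nwb_training: save, reload, and compare structure and content.
import json
import tempfile
from pathlib import Path

def roundtrip(labels_dict):
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "test_nwb_training.json"
        path.write_text(json.dumps(labels_dict))  # stands in for save_nwb
        return json.loads(path.read_text())       # stands in for load_nwb

saved = {"n_videos": 1, "n_labeled_frames": 5, "n_skeletons": 1}
loaded = roundtrip(saved)
assert loaded == saved            # contents survive the round trip
assert isinstance(loaded, dict)   # isinstance(), not type() == (Ruff E721)
```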

sleap_io/io/nwb.py (1)

Line range hint 787-818: Add error handling for invalid inputs.

Consider adding error handling for invalid inputs to improve robustness and maintainability.

+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")
+    if not isinstance(filename, str):
+        raise ValueError("Invalid input: Expected a string for filename.")
+    if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+        raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+    if frame_inds is not None and not isinstance(frame_inds, list):
+        raise ValueError("Invalid input: Expected a list for frame_inds.")
+    if frame_path is not None and not isinstance(frame_path, str):
+        raise ValueError("Invalid input: Expected a string for frame_path.")
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 7a83046 and 1a7e58b.

Files selected for processing (2)
  • sleap_io/io/nwb.py (17 hunks)
  • tests/io/test_main.py (1 hunks)
Additional context used
Ruff
tests/io/test_main.py

29-29: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


30-30: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Additional comments not posted (11)
tests/io/test_main.py (1)

25-32: Use isinstance() for type checks.

Replace direct type comparisons with isinstance() for better readability and performance.

- assert type(loaded_labels) == Labels
- assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels
+ assert isinstance(loaded_labels, Labels)
+ assert isinstance(load_file(tmp_path / "test_nwb.nwb"), Labels)
Tools
Ruff

29-29: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


30-30: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

sleap_io/io/nwb.py (10)

61-96: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.


75-78: Handle multiple external files.

The current implementation only supports single-file external videos. Consider implementing logic to handle multiple external files.


99-115: Add error handling for invalid inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.


117-166: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.


168-194: Add error handling for invalid inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.


197-220: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.


215-216: Complete the TODO comment for instance tracking.

The function contains a TODO comment indicating incomplete functionality for tracking the number of instances.


223-246: Add error handling for invalid inputs.

The function lacks error handling for invalid videos inputs, which can improve robustness and maintainability.


685-786: Add error handling for invalid inputs.

The function lacks error handling for invalid labels, nwbfile, pose_estimation_metadata, frame_inds, and frame_path inputs, which can improve robustness and maintainability.


924-924: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

codecov bot commented Aug 14, 2024

Codecov Report

Attention: Patch coverage is 78.00000% with 44 lines in your changes missing coverage. Please review.

Project coverage is 94.22%. Comparing base (e002927) to head (67a62c2).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
sleap_io/io/nwb.py 78.53% 41 Missing ⚠️
sleap_io/io/main.py 66.66% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #104      +/-   ##
==========================================
- Coverage   96.01%   94.22%   -1.80%     
==========================================
  Files          17       17              
  Lines        2035     2215     +180     
==========================================
+ Hits         1954     2087     +133     
- Misses         81      128      +47     


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Outside diff range, codebase verification and nitpick comments (1)
sleap_io/io/nwb.py (1)

17-20: Use contextlib.suppress for import suppression.

Replace the try-except-pass block with contextlib.suppress to improve readability.

-try:
-    import cv2
-except ImportError:
-    pass
+from contextlib import suppress
+with suppress(ImportError):
+    import cv2
Tools
Ruff

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 1a7e58b and 90fd397.

Files selected for processing (1)
  • sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Additional comments not posted (6)
sleap_io/io/nwb.py (6)

Line range hint 753-784: LGTM!

The function is well-implemented and handles the as_training parameter appropriately.


Line range hint 818-887: LGTM!

The function is well-implemented and handles metadata and timestamps appropriately.


197-220: Add error handling and complete the TODO comment for instance tracking.

The function lacks error handling for invalid inputs. The TODO comment indicates incomplete functionality for tracking the number of instances.

+    if not isinstance(instance, Instance):
+        raise ValueError("Invalid input: Expected an Instance object.")
+    if not isinstance(skeleton, NWBSkeleton):
+        raise ValueError("Invalid input: Expected an NWBSkeleton object.")

Do you want me to assist in completing the functionality for tracking the number of instances or open a GitHub issue to track this task?

Likely invalid or redundant comment.


890-890: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+    if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, (np.ndarray, list)):
+        raise ValueError("Invalid input: Expected a pandas DataFrame and an array-like object.")

Likely invalid or redundant comment.


168-194: Optimize edge processing logic.

The current implementation iterates over skeleton edges multiple times. Consider optimizing the logic to reduce complexity.

-    skeleton_edges = dict(enumerate(skeleton.nodes))
-    for i, source in skeleton_edges.items():
-        for destination in list(skeleton_edges.values())[i:]:
-            if Edge(source, destination) in skeleton.edges:
-                nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])
+    node_indices = {node: idx for idx, node in enumerate(skeleton.nodes)}
+    nwb_edges = [
+        [node_indices[edge.source], node_indices[edge.destination]]
+        for edge in skeleton.edges
+    ]

Likely invalid or redundant comment.


99-115: Add error handling for invalid skeleton inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.

+    if not isinstance(skeleton, NWBSkeleton):
+        raise ValueError("Invalid input: Expected an NWBSkeleton object.")

Likely invalid or redundant comment.

Comment on lines 117 to 166
def labels_to_pose_training(
labels: Labels,
skeletons_list: list[NWBSkeleton], # type: ignore[return]
video_info: tuple[dict[int, str], Video, ImageSeries],
) -> PoseTraining: # type: ignore[return]
"""Creates an NWB PoseTraining object from a Labels object.
Args:
labels: A Labels object.
skeletons_list: A list of NWB skeletons.
video_info: A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
Returns:
A PoseTraining object.
"""
training_frame_list = []
skeleton_instances_list = []
source_video_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
skeleton_instances_list.append(skeleton_instance)

training_frame_skeleton_instances = SkeletonInstances(
skeleton_instances=skeleton_instances_list
)
training_frame_video_index = labeled_frame.frame_idx

image_series = video_info[2]
source_video = image_series
if source_video not in source_video_list:
source_video_list.append(source_video)
training_frame = TrainingFrame(
name=f"training_frame_{i}",
annotator="N/A",
skeleton_instances=training_frame_skeleton_instances,
source_video=source_video,
source_video_frame_index=training_frame_video_index,
)
training_frame_list.append(training_frame)

training_frames = TrainingFrames(training_frames=training_frame_list)
source_videos = SourceVideos(image_series=source_video_list)
pose_training = PoseTraining(
training_frames=training_frames,
source_videos=source_videos,
)
return pose_training


Add error handling and handle multiple external files.

The function lacks error handling for invalid inputs. Consider implementing logic to handle multiple external files.

+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")
+    if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list):
+        raise ValueError("Invalid input: Expected a list of NWBSkeleton objects.")
+    if not isinstance(video_info, tuple) or len(video_info) != 3:
+        raise ValueError("Invalid input: Expected a tuple with three elements (dict, Video, ImageSeries).")
-        if source_video.format == "external" and len(source_video.external_file) == 1:
-            video = Video(source_video.external_file[0])
-        else:
-            raise NotImplementedError("Only single-file external videos are supported.")
+        if source_video.format == "external":
+            if len(source_video.external_file) == 1:
+                video = Video(source_video.external_file[0])
+            else:
+                # Implement logic for handling multiple external files
+                raise NotImplementedError("Multiple external files are not yet supported.")
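Beyond raising, one possible shape for actual multi-file support (a hypothetical sketch, not the PR's implementation; `make_video` stands in for the `Video` constructor, and list-of-paths support is an assumption about the backend):

```python
# Hypothetical handling of multi-file external videos: treat a multi-entry
# external_file list as an image sequence rather than raising.
def resolve_external_video(fmt, external_file, make_video):
    if fmt != "external":
        raise NotImplementedError("Only external videos are supported.")
    if len(external_file) == 1:
        return make_video(external_file[0])   # single movie file
    return make_video(list(external_file))    # image-sequence style backend

single = resolve_external_video("external", ["movie.mp4"], lambda f: f)
multi = resolve_external_video("external", ["f0.png", "f1.png"], lambda f: f)
```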

Comment on lines 651 to 752
nwbfile: NWBFile,
pose_estimation_metadata: Optional[dict] = None,
frame_inds: Optional[list[int]] = None,
frame_path: Optional[str] = None,
) -> NWBFile:
"""Append training data from a Labels object to an in-memory NWB file.
Args:
labels: A general labels object.
nwbfile: An in-memory NWB file.
pose_estimation_metadata: Metadata for pose estimation.
frame_inds: The indices of the frames to write. If None, all frames are written.
frame_path: The path to save the frames. If None, the path is the video
filename without the extension.
Returns:
An in-memory NWB file with the PoseTraining data appended.
"""
pose_estimation_metadata = pose_estimation_metadata or dict()
provenance = labels.provenance
default_metadata = dict(scorer=str(provenance))
sleap_version = provenance.get("sleap_version", None)
default_metadata["source_software_version"] = sleap_version

subject = Subject(subject_id="No specified id", species="No specified species")
nwbfile.subject = subject

for i, video in enumerate(labels.videos):
video_path = (
Path(video.filename)
if isinstance(video.filename, str)
else video.filename[i]
)
processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
nwb_processing_module = get_processing_module_for_video(
processing_module_name, nwbfile
)
default_metadata["original_videos"] = [f"{video.filename}"]
default_metadata["labeled_videos"] = [f"{video.filename}"]
default_metadata.update(pose_estimation_metadata)

skeletons_list = [
slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels.skeletons
]
skeletons = Skeletons(skeletons=skeletons_list)
nwb_processing_module.add(skeletons)
video_info = write_video_to_path(
labels.videos[0], frame_inds, frame_path=frame_path
)
pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
nwb_processing_module.add(pose_training)

confidence_definition = "Softmax output of the deep neural network"
reference_frame = (
"The coordinates are in (x, y) relative to the top-left of the image. "
"Coordinates refer to the midpoint of the pixel. "
"That is, the midpoint of the top-left pixel is at (0, 0), whereas "
"the top-left corner of that same pixel is at (-0.5, -0.5)."
)
pose_estimation_series_list = []
for node in skeletons_list[0].nodes:
pose_estimation_series = PoseEstimationSeries(
name=node,
description=f"Marker placed on {node}",
data=np.random.rand(100, 2),
unit="pixels",
reference_frame=reference_frame,
timestamps=np.linspace(0, 10, num=100),
confidence=np.random.rand(100),
confidence_definition=confidence_definition,
)
pose_estimation_series_list.append(pose_estimation_series)

camera = nwbfile.create_device(
name=f"camera {i}",
description=f"Camera used to record video {i}",
manufacturer="No specified manufacturer",
)
try:
dimensions = np.array([[video.backend.shape[1], video.backend.shape[2]]])
except AttributeError:
dimensions = np.array([[400, 400]])

pose_estimation = PoseEstimation(
name="pose_estimation",
pose_estimation_series=pose_estimation_series_list,
description="Estimated positions of the nodes in the video",
original_videos=[video.filename for video in labels.videos],
labeled_videos=[video.filename for video in labels.videos],
dimensions=dimensions,
devices=[camera],
scorer="No specified scorer",
source_software="SLEAP",
source_software_version=sleap_version,
skeleton=skeletons_list[0],
)
nwb_processing_module.add(pose_estimation)

return nwbfile


Add error handling for invalid inputs.

The function lacks error handling for invalid labels, nwbfile, pose_estimation_metadata, frame_inds, and frame_path inputs, which can improve robustness and maintainability.

+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")
+    if not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected an NWBFile object.")
+    if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+        raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+    if frame_inds is not None and not isinstance(frame_inds, list):
+        raise ValueError("Invalid input: Expected a list for frame_inds.")
+    if frame_path is not None and not isinstance(frame_path, str):
+        raise ValueError("Invalid input: Expected a string for frame_path.")

@talmo talmo mentioned this pull request Aug 16, 2024
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Outside diff range, codebase verification and nitpick comments (10)
sleap_io/io/nwb.py (10)

61-96: Enhance error handling for video formats.

The function currently raises a NotImplementedError for unsupported video formats. Consider adding more descriptive error handling or support for multiple external files.

- raise NotImplementedError("Only single-file external videos are supported.")
+ raise ValueError("Unsupported video format: Only single-file external videos are supported.")

99-115: Add error handling for invalid skeleton inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.

+ if not isinstance(skeleton, NWBSkeleton):
+     raise ValueError("Invalid input: Expected an NWBSkeleton object.")

117-166: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list):
+     raise ValueError("Invalid input: Expected a list of NWBSkeleton objects.")
+ if not isinstance(video_info, tuple) or len(video_info) != 3:
+     raise ValueError("Invalid input: Expected a tuple with three elements (dict, Video, ImageSeries).")

Optimize source_video handling.

The current implementation assumes source_video is unique for each frame. Consider optimizing the logic to handle cases where source_video might repeat.

- if source_video not in source_video_list:
-     source_video_list.append(source_video)
+ # Deduplicate by video name to avoid a linear membership scan per frame
+ if source_video.name not in seen_video_names:
+     seen_video_names.add(source_video.name)
+     source_video_list.append(source_video)

168-194: Add error handling for invalid skeleton inputs.

The function lacks error handling for invalid skeleton inputs, which can improve robustness and maintainability.

+ if not isinstance(skeleton, SLEAPSkeleton):
+     raise ValueError("Invalid input: Expected a SLEAPSkeleton object.")

Optimize edge processing logic.

The current implementation iterates over skeleton edges multiple times. Consider optimizing the logic to reduce complexity.

- skeleton_edges = dict(enumerate(skeleton.nodes))
- for i, source in skeleton_edges.items():
-     for destination in list(skeleton_edges.values())[i:]:
-         if Edge(source, destination) in skeleton.edges:
-             nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])
+ node_indices = {node: idx for idx, node in enumerate(skeleton.nodes)}
+ nwb_edges = [
+     [node_indices[edge.source], node_indices[edge.destination]]
+     for edge in skeleton.edges
+ ]
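The single-pass mapping above can be sanity-checked with lightweight stand-ins (illustrative dataclasses, not the sleap_io `Node`/`Edge`/`Skeleton` classes):

```python
# Lightweight stand-ins to verify the dict-based edge conversion.
from dataclasses import dataclass

@dataclass(frozen=True)
class Edge:
    source: str
    destination: str

@dataclass
class Skeleton:
    nodes: list
    edges: list

skeleton = Skeleton(
    nodes=["head", "thorax", "abdomen"],
    edges=[Edge("head", "thorax"), Edge("thorax", "abdomen")],
)

# One pass over nodes to build the index, one pass over edges to convert.
node_indices = {node: idx for idx, node in enumerate(skeleton.nodes)}
nwb_edges = [
    [node_indices[e.source], node_indices[e.destination]] for e in skeleton.edges
]
```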

197-220: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+ if not isinstance(instance, Instance):
+     raise ValueError("Invalid input: Expected an Instance object.")
+ if not isinstance(skeleton, NWBSkeleton):
+     raise ValueError("Invalid input: Expected an NWBSkeleton object.")

Improve instance naming logic.

Using id(instance) for naming can lead to non-intuitive names. Consider using a counter or a more descriptive naming scheme.

- name=f"skeleton_instance_{id(instance)}",
+ name=f"skeleton_instance_{instance.index}",  # Assuming `index` is a unique identifier

223-246: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+ if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+     raise ValueError("Invalid input: Expected a list of Video objects.")

Use dynamic rate for video.

The function uses a fixed rate of 30.0. Consider using video.backend.fps when available.

- rate=30.0,  # TODO - change to `video.backend.fps` when available
+ rate=video.backend.fps if hasattr(video.backend, 'fps') else 30.0,

248-317: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+ if not isinstance(video, Video):
+     raise ValueError("Invalid input: Expected a Video object.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+     raise ValueError("Invalid input: Expected a list of frame indices.")
+ if not isinstance(image_format, str):
+     raise ValueError("Invalid input: Expected a string for image_format.")

Use dynamic save path.

The function uses a fixed save path based on the video filename. Consider allowing more flexible path configurations.

- save_path = video.filename.split(".")[0]
+ save_path = frame_path or video.filename.split(".")[0]

637-738: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(nwbfile, NWBFile):
+     raise ValueError("Invalid input: Expected an NWBFile object.")
+ if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+     raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+     raise ValueError("Invalid input: Expected a list for frame_inds.")
+ if frame_path is not None and not isinstance(frame_path, str):
+     raise ValueError("Invalid input: Expected a string for frame_path.")

Use dynamic metadata.

The function uses fixed metadata for the subject. Consider allowing more flexible metadata configurations.

- subject = Subject(subject_id="No specified id", species="No specified species")
+ subject = Subject(subject_id=labels.provenance.get("subject_id", "No specified id"),
+                   species=labels.provenance.get("species", "No specified species"))

Line range hint 739-770: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+ if not isinstance(labels, Labels):
+     raise ValueError("Invalid input: Expected a Labels object.")
+ if not isinstance(filename, str):
+     raise ValueError("Invalid input: Expected a string for filename.")
+ if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+     raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+ if frame_inds is not None and not isinstance(frame_inds, list):
+     raise ValueError("Invalid input: Expected a list for frame_inds.")
+ if frame_path is not None and not isinstance(frame_path, str):
+     raise ValueError("Invalid input: Expected a string for frame_path.")
+ if as_training is not None and not isinstance(as_training, bool):
+     raise ValueError("Invalid input: Expected a boolean for as_training.")

Line range hint 876-914: Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

+ if not isinstance(track_data_df, pd.DataFrame):
+     raise ValueError("Invalid input: Expected a pandas DataFrame.")
+ if not isinstance(timestamps, (np.ndarray, list)):
+     raise ValueError("Invalid input: Expected an array-like object for timestamps.")

Handle non-uniform sampling.

The function assumes uniform sampling for rate calculation. Consider handling non-uniform sampling more explicitly.

- uniform_samples = np.unique(sample_periods.round(5)).size == 1
+ uniform_samples = np.allclose(sample_periods, sample_periods[0], atol=0.0001)
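A self-contained sketch of the tolerance-based check (`is_uniformly_sampled` is a hypothetical helper, not part of sleap-io):

```python
import numpy as np

def is_uniformly_sampled(timestamps, atol=1e-4):
    """Return True when consecutive sample periods agree within `atol` seconds."""
    periods = np.diff(np.asarray(timestamps, dtype=float))
    return bool(np.allclose(periods, periods[0], atol=atol))
```

Unlike rounding to five decimals, this treats two periods as equal whenever they differ by less than `atol`, regardless of where a rounding boundary falls.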
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 90fd397 and d8294b7.

Files selected for processing (1)
  • sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)
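For reference, the SIM105 pattern Ruff is asking for looks like this (a sketch of the optional `cv2` import, not the exact sleap-io code):

```python
import sys
from contextlib import suppress

# Optional dependency: the import is silently skipped when cv2 is absent.
# Behaviorally identical to try/except ImportError/pass, but SIM105-clean.
with suppress(ImportError):
    import cv2  # noqa: F401

HAS_CV2 = "cv2" in sys.modules  # feature flag for later code paths
```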


27-27: hdmf.utils.LabelledDict imported but unused

Remove unused import: hdmf.utils.LabelledDict

(F401)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between d8294b7 and c3b286e.

Files selected for processing (2)
  • pyproject.toml (1 hunks)
  • tests/io/test_nwb.py (2 hunks)
Files skipped from review as they are similar to previous changes (2)
  • pyproject.toml
  • tests/io/test_nwb.py

keyaloding added a commit that referenced this pull request Aug 27, 2024
keyaloding added a commit that referenced this pull request Aug 27, 2024
keyaloding added a commit that referenced this pull request Aug 27, 2024
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between c3b286e and dedc06d.

Files selected for processing (2)
  • environment.yml (1 hunks)
  • sleap_io/io/nwb.py (17 hunks)
Files skipped from review due to trivial changes (1)
  • environment.yml
Additional context used
Ruff
sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Comment on lines 95 to 111
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: ignore[return]
"""Converts an NWB skeleton to a SLEAP skeleton.
Args:
skeleton: An NWB skeleton.
Returns:
A SLEAP skeleton.
"""
nodes = [Node(name=node) for node in skeleton.nodes]
edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
return SLEAPSkeleton(
nodes=nodes,
edges=edges,
name=skeleton.name,
)


Add input validation and error handling in nwb_skeleton_to_sleap.

The function nwb_skeleton_to_sleap converts an NWB skeleton to a SLEAP skeleton. It already carries type annotations, but it performs no runtime validation of its input.

Enhance the function with an explicit type check:

def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
    if not isinstance(skeleton, NWBSkeleton):
        raise TypeError("Expected an NWBSkeleton object")
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(nodes=nodes, edges=edges, name=skeleton.name)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between dedc06d and 141d1dd.

Files selected for processing (1)
  • sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Additional comments not posted (1)
sleap_io/io/nwb.py (1)

17-20: Refactor import handling using contextlib.suppress.

To improve the readability and conciseness of the code, consider using contextlib.suppress for handling the optional import of cv2. This approach is cleaner and reduces the amount of code:

-try:
-    import cv2
-except ImportError:
-    pass
+from contextlib import suppress
+
+with suppress(ImportError):
+    import cv2

Likely invalid or redundant comment.

Tools
Ruff

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Comment on lines 193 to 214
def instance_to_skeleton_instance(
instance: Instance, skeleton: NWBSkeleton # type: ignore[return]
) -> SkeletonInstance: # type: ignore[return]
"""Converts a SLEAP Instance to an NWB SkeletonInstance.
Args:
instance: A SLEAP Instance.
skeleton: An NWB Skeleton.
Returns:
An NWB SkeletonInstance.
"""
points_list = list(instance.points.values())
node_locs = [[point.x, point.y] for point in points_list]
np_node_locations = np.array(node_locs)
return SkeletonInstance(
name=f"skeleton_instance_{id(instance)}",
id=np.uint64(id(instance)),
node_locations=np_node_locations,
node_visibility=[point.visible for point in instance.points.values()],
skeleton=skeleton,
)

Add error handling for input validation in instance_to_skeleton_instance.

This function should include checks to ensure that the inputs instance and skeleton are of the correct types and have the necessary attributes:

def instance_to_skeleton_instance(
    instance: Instance, skeleton: NWBSkeleton  # type: ignore[return]
) -> SkeletonInstance:
+    if not isinstance(instance, Instance):
+        raise TypeError("Expected an Instance object")
+    if not isinstance(skeleton, NWBSkeleton):
+        raise TypeError("Expected an NWBSkeleton object")
    points_list = list(instance.points.values())
    ...

Comment on lines 95 to 110
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: ignore[return]
"""Converts an NWB skeleton to a SLEAP skeleton.
Args:
skeleton: An NWB skeleton.
Returns:
A SLEAP skeleton.
"""
nodes = [Node(name=node) for node in skeleton.nodes]
edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
return SLEAPSkeleton(
nodes=nodes,
edges=edges,
name=skeleton.name,
)

Add error handling for input validation in nwb_skeleton_to_sleap.

This function should include checks to ensure that the input skeleton is of the correct type and has the necessary attributes:

def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+    if not isinstance(skeleton, NWBSkeleton):
+        raise TypeError("Expected an NWBSkeleton object")
+    if not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+        raise ValueError("Invalid NWBSkeleton object: Missing 'nodes' or 'edges'")
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    ...

Comment on lines 242 to 311
def write_video_to_path(
video: Video,
frame_inds: Optional[list[int]] = None,
image_format: str = "png",
frame_path: Optional[str] = None,
) -> tuple[dict[int, str], Video, ImageSeries]:
"""Write individual frames of a video to a path.
Args:
video: The video to write.
frame_inds: The indices of the frames to write. If None, all frames are written.
image_format: The format of the image to write. Default is .png
frame_path: The directory to save the frames to. If None, the path is the video
filename without the extension.
Returns:
A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
"""
index_data = {}
if frame_inds is None:
frame_inds = list(range(video.backend.num_frames))

if isinstance(video.filename, list):
save_path = video.filename[0].split(".")[0]
else:
save_path = video.filename.split(".")[0]

if frame_path is not None:
save_path = frame_path

try:
os.makedirs(save_path, exist_ok=True)
except PermissionError:
filename_with_extension = video.filename.split("/")[-1]
filename = filename_with_extension.split(".")[0]
save_path = input("Permission denied. Enter a new path:") + "/" + filename
os.makedirs(save_path, exist_ok=True)

if "cv2" in sys.modules:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the video filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
cv2.imwrite(frame_path, frame)
else:
for frame_idx in frame_inds:
try:
frame = video[frame_idx]
except FileNotFoundError:
video_filename = input("Video not found. Enter the filename:")
video = Video.from_filename(video_filename)
frame = video[frame_idx]
frame_path = f"{save_path}/frame_{frame_idx}.{image_format}"
index_data[frame_idx] = frame_path
iio.imwrite(frame_path, frame)

image_series = ImageSeries(
name="video",
external_file=os.listdir(save_path),
starting_frame=[0 for _ in range(len(os.listdir(save_path)))],
rate=30.0, # TODO - change to `video.backend.fps` when available
)
return index_data, video, image_series


Refactor and enhance error handling in write_video_to_path.

This function handles file operations and user input, which are error-prone areas. Consider refactoring to separate concerns and improve error recovery:

def write_video_to_path(
    video: Video,
    frame_inds: Optional[list[int]] = None,
    image_format: str = "png",
    frame_path: Optional[str] = None,
) -> tuple[dict[int, str], Video, ImageSeries]:
    index_data = {}
    if frame_inds is None:
        frame_inds = list(range(video.backend.num_frames))
    ...
    try:
        os.makedirs(save_path, exist_ok=True)
    except PermissionError:
+        # Consider logging the error or handling it more gracefully
        filename_with_extension = video.filename.split("/")[-1]
        filename = filename_with_extension.split(".")[0]
        save_path = input("Permission denied. Enter a new path:") + "/" + filename
        os.makedirs(save_path, exist_ok=True)
    ...

Consider adding more robust error handling and possibly separating the user input handling into a different function or method to make the code cleaner and more maintainable.

Committable suggestion was skipped due to low confidence.
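One way to separate the path handling, as suggested above: a sketch of a hypothetical `resolve_save_dir` helper (not part of sleap-io) that raises on failure instead of prompting with `input()`, which keeps it usable in non-interactive runs:

```python
from pathlib import Path
from typing import Optional

def resolve_save_dir(video_filename: str, frame_path: Optional[str] = None) -> Path:
    """Return the directory extracted frames should be written to.

    An explicit `frame_path` always wins; otherwise the video filename minus
    its extension is used. Errors such as PermissionError propagate to the
    caller rather than triggering an interactive prompt.
    """
    save_dir = Path(frame_path) if frame_path else Path(video_filename).with_suffix("")
    save_dir.mkdir(parents=True, exist_ok=True)
    return save_dir
```

For example, `resolve_save_dir("session1/clip.mp4")` creates and returns `Path("session1/clip")`.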

Comment on lines 60 to 92
def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ignore[return]
"""Creates a Labels object from an NWB PoseTraining object.
Args:
pose_training: An NWB PoseTraining object.
Returns:
A Labels object.
"""
labeled_frames = []
skeletons = {}
training_frames = pose_training.training_frames.training_frames.values()
for training_frame in training_frames:
source_video = training_frame.source_video
video = Video(source_video.external_file)

frame_idx = training_frame.source_video_frame_index
instances = []
for instance in training_frame.skeleton_instances.skeleton_instances.values():
if instance.skeleton.name not in skeletons:
skeletons[instance.skeleton.name] = nwb_skeleton_to_sleap(
instance.skeleton
)
skeleton = skeletons[instance.skeleton.name]
instances.append(
Instance.from_numpy(
points=instance.node_locations[:], skeleton=skeleton
)
)
labeled_frames.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
)
return Labels(labeled_frames=labeled_frames)

Add error handling for input validation in pose_training_to_labels.

The function lacks checks for the validity of the pose_training input. Adding error handling for the input type and ensuring it contains the necessary attributes would enhance the robustness of the function:

def pose_training_to_labels(pose_training: PoseTraining) -> Labels:
+    if not isinstance(pose_training, PoseTraining):
+        raise TypeError("Expected a PoseTraining object")
+    if not hasattr(pose_training, 'training_frames') or not isinstance(pose_training.training_frames, TrainingFrames):
+        raise ValueError("Invalid PoseTraining object: Missing 'training_frames'")
    labeled_frames = []
    skeletons = {}
    training_frames = pose_training.training_frames.training_frames.values()
    ...

Comment on lines 113 to 161
def labels_to_pose_training(
labels: Labels,
skeletons_list: list[NWBSkeleton], # type: ignore[return]
video_info: tuple[dict[int, str], Video, ImageSeries],
) -> PoseTraining: # type: ignore[return]
"""Creates an NWB PoseTraining object from a Labels object.
Args:
labels: A Labels object.
skeletons_list: A list of NWB skeletons.
video_info: A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
Returns:
A PoseTraining object.
"""
training_frame_list = []
skeleton_instances_list = []
source_video_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
skeleton_instances_list.append(skeleton_instance)

training_frame_skeleton_instances = SkeletonInstances(
skeleton_instances=skeleton_instances_list
)
training_frame_video_index = labeled_frame.frame_idx

image_series = video_info[2]
source_video = image_series
if source_video not in source_video_list:
source_video_list.append(source_video)
training_frame = TrainingFrame(
name=f"training_frame_{i}",
annotator="N/A",
skeleton_instances=training_frame_skeleton_instances,
source_video=source_video,
source_video_frame_index=training_frame_video_index,
)
training_frame_list.append(training_frame)

training_frames = TrainingFrames(training_frames=training_frame_list)
source_videos = SourceVideos(image_series=source_video_list)
pose_training = PoseTraining(
training_frames=training_frames,
source_videos=source_videos,
)
return pose_training

Improve error handling and refactor labels_to_pose_training for clarity.

This function handles multiple complex operations and would benefit from clearer error handling and possibly some refactoring to improve readability:

def labels_to_pose_training(
    labels: Labels,
    skeletons_list: list[NWBSkeleton],  # type: ignore[return]
    video_info: tuple[dict[int, str], Video, ImageSeries],
) -> PoseTraining:
+    if not isinstance(labels, Labels):
+        raise TypeError("Expected a Labels object")
+    if not isinstance(skeletons_list, list) or not all(isinstance(s, NWBSkeleton) for s in skeletons_list):
+        raise ValueError("Expected a list of NWBSkeleton objects")
+    if not isinstance(video_info, tuple) or len(video_info) != 3:
+        raise ValueError("Expected a tuple containing (dict, Video, ImageSeries)")
    ...

Consider refactoring to separate some of the nested logic into helper functions to improve the function's readability and maintainability.

Committable suggestion was skipped due to low confidence.
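The isinstance checks repeated across these functions could also be collapsed into one small helper; a hypothetical sketch (`require_type` is not part of sleap-io):

```python
def require_type(value, expected_type, name):
    """Raise TypeError with a uniform message when `value` has the wrong type."""
    if not isinstance(value, expected_type):
        raise TypeError(
            f"Invalid input for {name}: expected {expected_type.__name__}, "
            f"got {type(value).__name__}."
        )
    return value
```

Each function body then shrinks to lines like `require_type(labels, Labels, "labels")`.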

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Outside diff range, codebase verification and nitpick comments (2)
sleap_io/io/nwb.py (2)

Line range hint 698-729: Consider simplifying the function.

The function append_nwb handles multiple parameters and data structures, which adds complexity. Consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.


Line range hint 764-825: Consider simplifying the function.

The function build_pose_estimation_container_for_track handles multiple parameters and data structures, which adds complexity. Consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 141d1dd and 67a62c2.

Files selected for processing (1)
  • sleap_io/io/nwb.py (17 hunks)
Additional context used
Ruff
sleap_io/io/nwb.py

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)


157-157: Use enumerate() for index variable instance_counter in for loop

(SIM113)

Additional comments not posted (7)
sleap_io/io/nwb.py (7)

186-218: LGTM!

The function pose_training_to_labels is well-implemented and handles the conversion from PoseTraining to Labels effectively.


221-236: LGTM!

The function nwb_skeleton_to_sleap correctly converts NWB skeleton data to SLEAP format without any apparent issues.


268-291: LGTM!

The function instance_to_skeleton_instance is well-implemented and handles the conversion from SLEAP instance data to NWB format effectively.


294-316: LGTM!

The function videos_to_source_videos correctly converts video data to NWB format without any apparent issues.


131-183: Confirm error handling and consider simplifying.

The function labels_to_pose_training is complex, handling multiple layers of data aggregation. Ensure that the error handling suggested in previous reviews has been incorporated to improve robustness.

Additionally, consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.

Tools
Ruff

157-157: Use enumerate() for index variable instance_counter in for loop

(SIM113)


239-265: Confirm error handling and consider simplifying.

The function slp_skeleton_to_nwb converts SLEAP data to NWB format. Ensure that the error handling suggested in previous reviews has been incorporated to improve robustness.

Additionally, consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.


637-696: Confirm error handling and consider simplifying.

The function append_nwb_training is complex, handling multiple layers of data aggregation. Ensure that the error handling suggested in previous reviews has been incorporated to improve robustness.

Additionally, consider refactoring to simplify the function or breaking it down into smaller, more manageable parts.

sleap_io/io/nwb.py (outdated, resolved)
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 11

🧹 Outside diff range and nitpick comments (5)
README.md (1)

44-48: Excellent addition to the documentation!

The new example for saving labels to an NWB file with training data conversion is a valuable addition that aligns well with the PR objectives. It clearly demonstrates how to use the new as_training and frame_inds parameters.

A few suggestions to further improve this section:

  1. Consider adding a link to more detailed documentation about these new parameters, especially explaining the significance of as_training=True and how users should choose frame_inds.

  2. It might be helpful to clarify the purpose of saving individual frames. For example, you could expand on why this enhances portability and in what scenarios this would be particularly useful.

Here's a suggested expansion of the comment:

 # Save to an NWB file and convert SLEAP training data to NWB training data.
-# Note: This will save the first 3 frames of the video as individual images in a
-# subfolder next to the NWB file for portability.
+# Note: This will save the first 3 frames of the video as individual images in a
+# subfolder next to the NWB file. This enhances portability by allowing users to
+# visualize key frames without needing access to the original video file, which
+# can be particularly useful for sharing datasets or when working with large video files.
docs/index.md (1)

48-52: LGTM! Consider adding a brief explanation of NWB format.

The new code block and comments effectively demonstrate how to save SLEAP training data to an NWB file with additional options. The example is clear and informative.

To improve clarity for users who might not be familiar with the NWB format, consider adding a brief explanation or link to more information about NWB (Neurodata Without Borders) format. For example:

 # Save to NWB file.
 labels.save("predictions.nwb")
+
+# NWB (Neurodata Without Borders) is a standardized neurophysiology data format.
+# For more information, visit: https://www.nwb.org/
 
 # Save to an NWB file and convert SLEAP training data to NWB training data.
sleap_io/io/main.py (2)

84-91: Clarify the usage of frame_inds, frame_path, and image_format parameters.

The new parameters frame_inds, frame_path, and image_format are only applicable when as_training is True. However, this is not immediately clear from the parameter names or the docstring. Consider renaming these parameters and updating the docstring to make their usage more explicit.

Here's a suggestion:

def save_nwb(
    labels: Labels,
    filename: str,
    as_training: bool = False,
    append: bool = True,
    training_frame_indices: Optional[list[int]] = None,
    training_frame_path: Optional[str] = None,
    training_image_format: str = "png",
):
    """Save a SLEAP dataset to NWB format.

    Args:
        ...
        training_frame_indices: Optional list of labeled frame indices within the Labels to save when saving in training data format. No effect if `as_training` is `False`.
        training_frame_path: The path to a folder to save the extracted frame images to when saving in training data format. No effect if `as_training` is `False`.
        training_image_format: The image format to use when saving extracted frame images. No effect if `as_training` is `False`.
        ...
    """
    ...

229-230: Handle unknown formats consistently.

The load_file function raises a ValueError with a specific error message when an unknown format is encountered. However, the error message in the save_file function is different. Consider using a consistent error message for both functions.

Here's a suggestion:

else:
    raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")
sleap_io/io/nwb.py (1)

339-339: Complete the TODO comment for video.backend.fps.

The function contains a TODO comment indicating the need to change the rate to video.backend.fps when available.

Do you want me to assist in completing this functionality or open a GitHub issue to track this task?
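Until then, the hard-coded rate could be guarded with a fallback. A minimal sketch, assuming the backend may expose an `fps` attribute (the attribute name comes from the TODO and is not verified against the video backend API):

```python
DEFAULT_RATE = 30.0  # current hard-coded value in the PR


def get_video_rate(backend: object) -> float:
    """Return the backend's fps if it exposes one, else the default rate.

    `fps` is an assumed attribute name taken from the TODO comment.
    """
    fps = getattr(backend, "fps", None)
    return float(fps) if fps else DEFAULT_RATE
```

This keeps today's behavior for backends without frame-rate metadata while picking up the real rate automatically once `video.backend.fps` lands.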

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR, between commits 67a62c2 and 62dc540.

📒 Files selected for processing (5)
  • README.md (1 hunks)
  • docs/index.md (1 hunks)
  • pyproject.toml (1 hunks)
  • sleap_io/io/main.py (3 hunks)
  • sleap_io/io/nwb.py (17 hunks)
🧰 Additional context used
Ruff
sleap_io/io/nwb.py

11-11: os imported but unused

Remove unused import: os

(F401)


17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)


185-185: Use enumerate() for index variable instance_counter in for loop

(SIM113)

🔇 Additional comments not posted (5)
sleap_io/io/nwb.py (5)

60-104: LGTM!

The save_frame_images function is well-structured and handles the saving of frame images correctly. The error handling for missing video files is marked with a TODO comment, which is appropriate for future implementation.


107-153: LGTM!

The make_image_series function correctly creates NWB ImageSeries objects from the saved frame images. The grouping of labeled frames by video and the mapping of labeled frame indices to series indices are implemented properly.


156-209: LGTM!

The labels_to_pose_training function correctly converts SLEAP Labels to NWB PoseTraining. The creation of TrainingFrame and SkeletonInstance objects is handled properly, and the PoseTraining object is constructed with the necessary components.

🧰 Tools
Ruff

185-185: Use enumerate() for index variable instance_counter in for loop

(SIM113)
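The SIM113 fix is mechanical; a small self-contained illustration with hypothetical names (not the PR's actual variables):

```python
# Illustration of the Ruff SIM113 fix: let enumerate() supply the index
# instead of incrementing a counter by hand. Names here are hypothetical.
instances = ["a", "b", "c"]

# Before: manual counter.
instance_counter = 0
names_before = []
for _inst in instances:
    names_before.append(f"skeleton_instance_{instance_counter}")
    instance_counter += 1

# After: enumerate() tracks the index.
names_after = [f"skeleton_instance_{i}" for i, _inst in enumerate(instances)]
```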


212-244: LGTM!

The pose_training_to_labels function correctly converts NWB PoseTraining back to SLEAP Labels. The reconstruction of LabeledFrame and Instance objects is implemented properly.


872-872: ⚠️ Potential issue

Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

def build_track_pose_estimation_list(
    track_data_df: pd.DataFrame, timestamps: ArrayLike
) -> List[PoseEstimationSeries]:
+    if not isinstance(track_data_df, pd.DataFrame) or not isinstance(timestamps, (np.ndarray, list)):
+        raise ValueError("Invalid input: Expected a pandas DataFrame and an array-like object.")

Likely invalid or redundant comment.

Comment on lines +27 to +29
# TODO: Replace with ndx-pose>=0.2.0 once released.
# https://github.com/rly/ndx-pose/issues/29
"ndx-pose @ git+https://github.com/rly/ndx-pose@a847ad4be75e60ef9e413b8cbfc99c616fc9fd05",
💡 Codebase verification

Dependency points to an outdated commit

The ndx-pose dependency is currently set to commit a847ad4be75e60ef9e413b8cbfc99c616fc9fd05, while the latest commit on the main branch is 8017291ae9aee84d9f0b6f1d9426b56372cd715b. To ensure you have the most recent updates and fixes, consider updating the dependency to the latest commit until version 0.2.0 is released on PyPI.

  • Update the commit hash to 8017291ae9aee84d9f0b6f1d9426b56372cd715b.
  • Monitor the PyPI repository for the release of ndx-pose>=0.2.0 and update the dependency accordingly.
🔗 Analysis chain

Temporary dependency solution needs follow-up

The addition of the ndx-pose dependency from a specific GitHub commit is a temporary solution. While this allows for immediate use of the required features, it comes with some considerations:

  1. Reproducibility: Using a specific commit ensures consistency but may cause issues if the repository is altered or becomes unavailable.
  2. Maintenance: This approach requires manual updates when new versions are released.
  3. Security: It bypasses the usual package vetting process on PyPI.

To ensure this temporary solution doesn't persist longer than necessary, let's check the current state of the ndx-pose package:

Consider implementing a CI/CD process to automatically check for the availability of ndx-pose>=0.2.0 on PyPI and create a pull request to update the dependency when it becomes available.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if ndx-pose>=0.2.0 is available on PyPI
pip install "ndx-pose>=0.2.0" 2>&1 || echo "ndx-pose>=0.2.0 is not yet available on PyPI"

# Check the latest commit on the main branch of ndx-pose repository
latest_commit=$(gh api repos/rly/ndx-pose/commits/main --jq '.sha')
echo "Latest commit on main branch: $latest_commit"
echo "Current commit in use: a847ad4be75e60ef9e413b8cbfc99c616fc9fd05"

Length of output: 518

Comment on lines +62 to +70
def save_nwb(
    labels: Labels,
    filename: str,
    as_training: bool = False,
    append: bool = True,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
    image_format: str = "png",
):
🛠️ Refactor suggestion

Refactor save_nwb to reduce complexity and improve readability.

The function signature has been updated to include several new parameters. However, the function body can be refactored to reduce complexity and improve readability. Consider the following suggestions:

  • Use a dictionary to map the as_training flag to the corresponding nwb functions. This will eliminate the need for the if-else block.
  • Move the docstring for the as_training parameter to a single line to improve readability.
  • Consider using more descriptive variable names for frame_inds and frame_path.

Here's a refactored version of the function:

def save_nwb(
    labels: Labels,
    filename: str,
    as_training: bool = False,
    append: bool = True,
    training_frame_indices: Optional[list[int]] = None,
    training_frame_path: Optional[str] = None,
    training_image_format: str = "png",
):
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP `Labels` object (see `load_slp`).
        filename: Path to NWB file to save to. Must end in `.nwb`.
        as_training: If `True`, save the dataset as a training dataset.
        append: If `True` (the default), append to existing NWB file. File will be created if it does not exist.
        training_frame_indices: Optional list of labeled frame indices within the Labels to save when saving in training data format.
        training_frame_path: The path to a folder to save the extracted frame images to when saving in training data format.
        training_image_format: The image format to use when saving extracted frame images.

    See also: nwb.write_nwb, nwb.append_nwb, nwb.append_nwb_training
    """
    nwb_funcs = {
        True: {
            "append": nwb.append_nwb_training,
            "write": nwb.write_nwb_training,
        },
        False: {
            "append": nwb.append_nwb,
            "write": nwb.write_nwb,
        },
    }

    # Only the training writers take the frame-extraction options.
    training_kwargs = (
        dict(
            frame_inds=training_frame_indices,
            frame_path=training_frame_path,
            image_format=training_image_format,
        )
        if as_training
        else {}
    )
    action = "append" if append and Path(filename).exists() else "write"
    nwb_funcs[as_training][action](labels, filename, **training_kwargs)

Comment on lines +93 to +103
     See also: nwb.write_nwb, nwb.append_nwb, nwb.append_nwb_training
     """
     if append and Path(filename).exists():
-        nwb.append_nwb(labels, filename)
+        nwb.append_nwb(
+            labels,
+            filename,
+            as_training=as_training,
+            frame_inds=frame_inds,
+            frame_path=frame_path,
+            image_format=image_format,
+        )
🛠️ Refactor suggestion

Refactor the if-else block to reduce duplication.

The if-else block in the save_nwb function contains duplicated code. The only difference between the two branches is the function being called (nwb.append_nwb or nwb.write_nwb). This duplication can be eliminated by using a dictionary to map the append flag to the corresponding function.

Here's a refactored version:

nwb_funcs = {
    True: {
        "append": nwb.append_nwb_training,
        "write": nwb.write_nwb_training,
    },
    False: {
        "append": nwb.append_nwb,
        "write": nwb.write_nwb,
    },
}

# Only the training writers take the frame-extraction options.
training_kwargs = (
    dict(frame_inds=frame_inds, frame_path=frame_path, image_format=image_format)
    if as_training
    else {}
)
action = "append" if append and Path(filename).exists() else "write"
nwb_funcs[as_training][action](labels, filename, **training_kwargs)

Comment on lines +260 to +263
elif format in ("nwb", "nwb_predictions"):
    save_nwb(labels, filename, False)
elif format == "nwb_training":
    save_nwb(labels, filename, True, frame_inds=kwargs.get("frame_inds", None))
🛠️ Refactor suggestion

Use a dictionary to map formats to save functions.

The save_file function uses a series of if-elif statements to determine which save function to call based on the format parameter. This can be simplified by using a dictionary to map formats to their corresponding save functions.

Here's a refactored version:

format_map = {
    "slp": save_slp,
    "nwb": lambda lbl, f, **kw: save_nwb(lbl, f, False, **kw),
    "nwb_predictions": lambda lbl, f, **kw: save_nwb(lbl, f, False, **kw),
    "nwb_training": lambda lbl, f, **kw: save_nwb(
        lbl, f, True, frame_inds=kw.pop("frame_inds", None), **kw
    ),
    "labelstudio": save_labelstudio,
    "jabs": lambda lbl, f, **kw: save_jabs(lbl, kw.pop("pose_version", 5), f, **kw),
}

if format in format_map:
    format_map[format](labels, filename, **kwargs)
else:
    raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")

Comment on lines +17 to +20
try:
    import cv2
except ImportError:
    pass
🛠️ Refactor suggestion

Use contextlib.suppress for import suppression.

Replace the try-except-pass block with contextlib.suppress to improve readability.

-try:
-    import cv2
-except ImportError:
-    pass
+from contextlib import suppress
+with suppress(ImportError):
+    import cv2
🧰 Tools
Ruff

17-20: Use contextlib.suppress(ImportError) instead of try-except-pass

Replace with contextlib.suppress(ImportError)

(SIM105)

Comment on lines +320 to +342
def videos_to_source_videos(videos: list[Video]) -> SourceVideos:  # type: ignore[return]
    """Converts a list of SLEAP Videos to NWB SourceVideos.

    Args:
        videos: A list of SLEAP Videos.

    Returns:
        An NWB SourceVideos object.
    """
    source_videos = []
    for i, video in enumerate(videos):
        image_series = ImageSeries(
            name=f"video_{i}",
            description="N/A",
            unit="NA",
            format="external",
            external_file=[video.filename],
            dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],
            starting_frame=[0],
            rate=30.0,  # TODO - change to `video.backend.fps` when available
        )
        source_videos.append(image_series)
    return SourceVideos(image_series=source_videos)
⚠️ Potential issue

Add error handling for input validation in videos_to_source_videos.

This function should include checks to ensure that the input videos is a list of Video objects:

def videos_to_source_videos(videos: list[Video]) -> SourceVideos:  # type: ignore[return]
+    if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
+        raise TypeError("Expected a list of Video objects")
    source_videos = []
    ...

Comment on lines +247 to +262
def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:  # type: ignore[return]
    """Converts an NWB skeleton to a SLEAP skeleton.

    Args:
        skeleton: An NWB skeleton.

    Returns:
        A SLEAP skeleton.
    """
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(
        nodes=nodes,
        edges=edges,
        name=skeleton.name,
    )
⚠️ Potential issue

Add error handling and type annotations to nwb_skeleton_to_sleap.

The function lacks error handling for invalid inputs and type annotations for better readability and maintainability.

def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:
+    if skeleton is None or not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+        raise ValueError("Invalid skeleton: Missing required data.")
    nodes = [Node(name=node) for node in skeleton.nodes]
    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
    return SLEAPSkeleton(
        nodes=nodes,
        edges=edges,
        name=skeleton.name,
    )


Comment on lines +265 to +291
def slp_skeleton_to_nwb(
    skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton:  # type: ignore[return]
    """Converts SLEAP skeleton to NWB skeleton.

    Args:
        skeleton: A SLEAP skeleton.
        subject: An NWB subject.

    Returns:
        An NWB skeleton.
    """
    if subject is None:
        subject = Subject(species="No specified species", subject_id="No specified id")
    nwb_edges = []
    skeleton_edges = dict(enumerate(skeleton.nodes))
    for i, source in skeleton_edges.items():
        for destination in list(skeleton_edges.values())[i:]:
            if Edge(source, destination) in skeleton.edges:
                nwb_edges.append([i, list(skeleton_edges.values()).index(destination)])

    return NWBSkeleton(
        name=skeleton.name,
        nodes=skeleton.node_names,
        edges=np.array(nwb_edges, dtype=np.uint8),
        subject=subject,
    )
⚠️ Potential issue

Add error handling for input validation in slp_skeleton_to_nwb.

This function should include checks to ensure that the input skeleton is of the correct type and has the necessary attributes:

def slp_skeleton_to_nwb(
    skeleton: SLEAPSkeleton, subject: Optional[Subject] = None
) -> NWBSkeleton:
+    if not isinstance(skeleton, SLEAPSkeleton):
+        raise TypeError("Expected a SLEAPSkeleton object")
+    if not hasattr(skeleton, 'nodes') or not hasattr(skeleton, 'edges'):
+        raise ValueError("Invalid SLEAPSkeleton object: Missing 'nodes' or 'edges'")
    if subject is None:
        subject = Subject(species="No specified species", subject_id="No specified id")
    ...
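Separately from validation, the nested scan over node pairs in `slp_skeleton_to_nwb` is quadratic in the number of nodes; building a name-to-index map once makes the edge conversion linear. A minimal sketch, assuming edges are available as `(source_name, destination_name)` pairs (the real SLEAP `Edge` wraps `Node` objects, so an adapter would be needed):

```python
def edges_to_index_pairs(
    node_names: list[str], edges: list[tuple[str, str]]
) -> list[list[int]]:
    """Convert (source, destination) name pairs into index pairs.

    Hypothetical helper: assumes edges arrive as name pairs, unlike the
    PR's Edge objects, which hold Node instances.
    """
    # Build the node index once instead of scanning the node list per edge.
    index = {name: i for i, name in enumerate(node_names)}
    return [[index[src], index[dst]] for src, dst in edges]
```

For example, `edges_to_index_pairs(["head", "thorax", "abdomen"], [("head", "thorax"), ("thorax", "abdomen")])` yields `[[0, 1], [1, 2]]`, matching the `[source_index, destination_index]` layout the NWB skeleton expects.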

Comment on lines +294 to +317
def instance_to_skeleton_instance(
    instance: Instance,
    skeleton: NWBSkeleton,  # type: ignore[return]
    counter: int,
) -> SkeletonInstance:  # type: ignore[return]
    """Converts a SLEAP Instance to an NWB SkeletonInstance.

    Args:
        instance: A SLEAP Instance.
        skeleton: An NWB Skeleton.
        counter: An integer counter.

    Returns:
        An NWB SkeletonInstance.
    """
    points_list = list(instance.points.values())
    node_locations = np.array([[point.x, point.y] for point in points_list])
    return SkeletonInstance(
        name=f"skeleton_instance_{counter}",
        id=np.uint64(counter),
        node_locations=node_locations,
        node_visibility=[point.visible for point in instance.points.values()],
        skeleton=skeleton,
    )
⚠️ Potential issue

Add error handling for input validation in instance_to_skeleton_instance.

This function should include checks to ensure that the inputs instance and skeleton are of the correct types and have the necessary attributes:

def instance_to_skeleton_instance(
    instance: Instance, skeleton: NWBSkeleton  # type: ignore[return]
) -> SkeletonInstance:
+    if not isinstance(instance, Instance):
+        raise TypeError("Expected an Instance object")
+    if not isinstance(skeleton, NWBSkeleton):
+        raise TypeError("Expected an NWBSkeleton object")
    points_list = list(instance.points.values())
    ...


Comment on lines +663 to +732
def append_nwb_training(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
    image_format: str = "png",
) -> NWBFile:
    """Append training data from a Labels object to an in-memory NWB file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory NWB file.
        pose_estimation_metadata: Metadata for pose estimation.
        frame_inds: The indices of the frames to write. If None, all frames
            are written.
        frame_path: The path to save the frames. If None, the path is the video
            filename without the extension.
        image_format: The format of the image to write. Default is "png".

    Returns:
        An in-memory NWB file with the PoseTraining data appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    subject = Subject(subject_id="No specified id", species="No specified species")
    nwbfile.subject = subject

    for i, video in enumerate(labels.videos):
        video_path = (
            Path(video.filename)
            if isinstance(video.filename, str)
            else video.filename[i]
        )
        processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
        nwb_processing_module = get_processing_module_for_video(
            processing_module_name, nwbfile
        )
        default_metadata["original_videos"] = [f"{video.filename}"]
        default_metadata["labeled_videos"] = [f"{video.filename}"]
        default_metadata.update(pose_estimation_metadata)

        skeletons_list = [
            slp_skeleton_to_nwb(skeleton, subject) for skeleton in labels.skeletons
        ]
        skeletons = Skeletons(skeletons=skeletons_list)
        nwb_processing_module.add(skeletons)
        lf_ind_to_img_path = save_frame_images(
            labels,
            save_path=frame_path,
            image_format=image_format,
            frame_inds=frame_inds,
        )
        image_series_list, lf_ind_to_series_ind = make_image_series(
            labels, lf_ind_to_img_path
        )
        pose_training = labels_to_pose_training(
            labels, skeletons_list, image_series_list, lf_ind_to_series_ind
        )
        nwb_processing_module.add(pose_training)

        _ = nwbfile.create_device(
            name=f"camera {i}",
            description=f"Camera used to record video {i}",
            manufacturer="No specified manufacturer",
        )
    return nwbfile
⚠️ Potential issue

Add error handling for invalid inputs.

The function lacks error handling for invalid inputs, which can improve robustness and maintainability.

def append_nwb_training(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    frame_inds: Optional[list[int]] = None,
    frame_path: Optional[str] = None,
    image_format: str = "png",
) -> NWBFile:
+    if not isinstance(labels, Labels):
+        raise ValueError("Invalid input: Expected a Labels object.")
+    if not isinstance(nwbfile, NWBFile):
+        raise ValueError("Invalid input: Expected an NWBFile object.")
+    if pose_estimation_metadata is not None and not isinstance(pose_estimation_metadata, dict):
+        raise ValueError("Invalid input: Expected a dictionary for pose_estimation_metadata.")
+    if frame_inds is not None and not isinstance(frame_inds, list):
+        raise ValueError("Invalid input: Expected a list for frame_inds.")
+    if frame_path is not None and not isinstance(frame_path, str):
+        raise ValueError("Invalid input: Expected a string for frame_path.")

@talmo talmo mentioned this pull request Sep 28, 2024