Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for more .nwb training data to SLEAP #104

Open
wants to merge 92 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
70a34ed
nwb to sleap conversion function
keyaloding Jun 28, 2024
f2fd6f5
push
keyaloding Jul 12, 2024
6c062ec
fixed attribute errors
keyaloding Jul 16, 2024
f7d8a0c
test
keyaloding Jul 16, 2024
23a5a83
b
keyaloding Jul 16, 2024
cf4bcf7
n
keyaloding Jul 16, 2024
8948e92
a
keyaloding Jul 17, 2024
be6ccbc
message
keyaloding Jul 17, 2024
0aabd86
update
keyaloding Jul 17, 2024
85d47ab
change
keyaloding Jul 17, 2024
1b8e08b
black
keyaloding Jul 17, 2024
aab5a79
cab
keyaloding Jul 18, 2024
66edcd6
update
keyaloding Jul 18, 2024
1b70d5c
change
keyaloding Jul 18, 2024
c59f996
m
keyaloding Jul 18, 2024
50da4a1
black
keyaloding Jul 18, 2024
c8c173a
black
keyaloding Jul 18, 2024
0742bdf
updated skeleton naming
keyaloding Jul 18, 2024
055e3e8
updated package handling
keyaloding Jul 19, 2024
dcc3da8
added error handling
keyaloding Jul 19, 2024
0ba071e
`append_nwb_training` update
keyaloding Jul 19, 2024
0c7f048
added test
keyaloding Jul 19, 2024
6a6c38c
Merge branch 'talmolab:main' into kloding/nwb_data_structures
keyaloding Jul 22, 2024
b15471e
Code quality
keyaloding Jul 22, 2024
dd522f9
resolved some comments
keyaloding Jul 22, 2024
22ed734
resolved comments + formatting
keyaloding Jul 22, 2024
7d1737d
updated test and fixed comments
keyaloding Jul 22, 2024
6f1a7d1
added img_to_path function (not implemented)
keyaloding Jul 23, 2024
51d5623
updated img_to_path
keyaloding Jul 24, 2024
e511efa
img_to_path
keyaloding Jul 24, 2024
42da933
added test and finished img_to_path
keyaloding Jul 25, 2024
e246f8b
removed slp_pkg function and code quality
keyaloding Jul 25, 2024
13d3a47
attempted to resolve `Skeletons` issue and nwbfile version
keyaloding Jul 26, 2024
e579ebf
updated load_nwb and video writing
keyaloding Jul 26, 2024
90f0ec4
updated processing module handling
keyaloding Jul 26, 2024
24634a1
update
keyaloding Jul 26, 2024
4819536
fixed save issues
keyaloding Jul 26, 2024
432c343
updated saving
keyaloding Jul 27, 2024
7cb04c4
updated code quality and image write function
keyaloding Jul 29, 2024
a66cf8f
implemented metadata indexing
keyaloding Jul 29, 2024
8f2271b
updated loading nwb
keyaloding Jul 29, 2024
f510765
black
keyaloding Jul 29, 2024
183cf55
updated video handling
keyaloding Jul 29, 2024
7e2d138
added `__future__` import
keyaloding Jul 31, 2024
846856b
updated future import
keyaloding Jul 31, 2024
f88ca17
updated test and documentation
keyaloding Aug 1, 2024
0a8fc45
updated saving nwb
keyaloding Aug 1, 2024
73b41b0
Revert "updated video handling"
keyaloding Aug 1, 2024
a9be0eb
Delete save_test.ipynb
keyaloding Aug 2, 2024
1031247
added test fixtures
keyaloding Aug 2, 2024
68c7cfb
fixed recursion error
keyaloding Aug 2, 2024
3bf7c8a
updated error handling
keyaloding Aug 3, 2024
6d12f39
updated error handling
keyaloding Aug 3, 2024
e94c14d
code quality
keyaloding Aug 5, 2024
dac1eb5
code quality
keyaloding Aug 5, 2024
593c139
code quality
keyaloding Aug 5, 2024
68133e1
code quality
keyaloding Aug 5, 2024
d66b858
updated video saving
keyaloding Aug 5, 2024
0abaca4
deleted large test file
keyaloding Aug 7, 2024
4ba67ea
updated video loading and pose_estimation creation
keyaloding Aug 7, 2024
3bac601
reverted PoseEstimation handling
keyaloding Aug 8, 2024
9699be2
updated fixtures
keyaloding Aug 8, 2024
f01e03a
updated fixtures
keyaloding Aug 8, 2024
9a4422d
Delete load_test.ipynb
keyaloding Aug 8, 2024
55ca394
Delete presentation.ipynb
keyaloding Aug 8, 2024
c91e74c
code quality
keyaloding Aug 8, 2024
63faf8b
code quality
keyaloding Aug 8, 2024
56911d6
fixed pydocstyle errors
keyaloding Aug 8, 2024
ac6bb23
updated tests
keyaloding Aug 8, 2024
0be5813
a
keyaloding Aug 8, 2024
46f63c6
fixed test not passing
keyaloding Aug 9, 2024
1cd6180
fixed test not passing
keyaloding Aug 9, 2024
a2f96ab
update
keyaloding Aug 9, 2024
30fbe6f
added cameras
keyaloding Aug 9, 2024
6a0cf27
fixed test fixture
keyaloding Aug 11, 2024
52d45b8
added fixture
keyaloding Aug 11, 2024
bab0917
updated tests
keyaloding Aug 11, 2024
a268e1f
update
keyaloding Aug 12, 2024
c57e572
updated toml
keyaloding Aug 12, 2024
e7b9fd3
z
keyaloding Aug 12, 2024
e932340
camera update
keyaloding Aug 13, 2024
6f3aabe
fixed OrphanBuildContainerError
keyaloding Aug 13, 2024
7a83046
one test left
keyaloding Aug 14, 2024
1a7e58b
passed all tests
keyaloding Aug 14, 2024
90fd397
removed unused function
keyaloding Aug 14, 2024
d8294b7
removed unused function
keyaloding Aug 20, 2024
c3b286e
Merge branch 'talmolab:main' into kloding/nwb_data_structures
keyaloding Aug 24, 2024
dedc06d
code quality update, removed `PoseEstimation` handling
keyaloding Sep 5, 2024
141d1dd
fixed black format issue
keyaloding Sep 5, 2024
67a62c2
added `SkeletonInstance` counter
keyaloding Sep 8, 2024
bf7fadf
Docs and nits
talmo Sep 25, 2024
62dc540
Add multi-video support
talmo Sep 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"attrs",
"h5py>=3.8.0",
"pynwb",
"ndx-pose",
"ndx-pose @ git+https://github.com/rly/ndx-pose@a847ad4be75e60ef9e413b8cbfc99c616fc9fd05",
"pandas",
"simplejson",
"imageio",
Expand Down
82 changes: 82 additions & 0 deletions save_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"minimal_instance.pkg.nwb_images/img_0.png\n"
]
},
{
"ename": "ValueError",
"evalue": "Can't write images with one color channel.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 16\u001b[0m\n\u001b[1;32m 14\u001b[0m img_path \u001b[38;5;241m=\u001b[39m save_path \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimg_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.png\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28mprint\u001b[39m(img_path)\n\u001b[0;32m---> 16\u001b[0m \u001b[43miio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimage\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 17\u001b[0m img_paths\u001b[38;5;241m.\u001b[39mappend(img_path)\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28mprint\u001b[39m(img_paths)\n",
"File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/v3.py:147\u001b[0m, in \u001b[0;36mimwrite\u001b[0;34m(uri, image, plugin, extension, format_hint, **kwargs)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Write an ndimage to the given URI.\u001b[39;00m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03mThe exact behavior depends on the file type and plugin used. To learn about\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 136\u001b[0m \n\u001b[1;32m 137\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m imopen(\n\u001b[1;32m 140\u001b[0m uri,\n\u001b[1;32m 141\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 145\u001b[0m extension\u001b[38;5;241m=\u001b[39mextension,\n\u001b[1;32m 146\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m img_file:\n\u001b[0;32m--> 147\u001b[0m encoded \u001b[38;5;241m=\u001b[39m \u001b[43mimg_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m encoded\n",
"File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/plugins/pillow.py:433\u001b[0m, in \u001b[0;36mPillowPlugin.write\u001b[0;34m(self, ndimage, mode, format, is_batch, **kwargs)\u001b[0m\n\u001b[1;32m 431\u001b[0m is_batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 432\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 433\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt write images with one color channel.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 434\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m4\u001b[39m]:\n\u001b[1;32m 435\u001b[0m \u001b[38;5;66;03m# Note: this makes a channel-last assumption\u001b[39;00m\n\u001b[1;32m 436\u001b[0m is_batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"\u001b[0;31mValueError\u001b[0m: Can't write images with one color channel."
]
}
],
"source": [
"import imageio.v3 as iio\n",
"from pathlib import Path\n",
"import sleap_io as sio\n",
"\n",
"save_path = Path(\"minimal_instance.pkg.nwb_images\")\n",
"try:\n",
" save_path.mkdir(parents=True, exist_ok=True)\n",
"except Exception as e:\n",
" print(f\"An error {e} occurred. The directory could not be created.\")\n",
"img_paths = []\n",
"\n",
"labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n",
"for i, lf in enumerate(labels_original):\n",
" img_path = save_path / f\"img_{i}.png\"\n",
" print(img_path)\n",
" iio.imwrite(img_path, lf.image)\n",
" img_paths.append(img_path)\n",
"print(img_paths)"
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve error handling for image writing.

This cell attempts to write images, but fails due to an unsupported single color channel. This issue should be caught and handled, or the notebook should ensure that only supported image formats are processed.

try:
    for i, lf in enumerate(labels_original):
        img_path = save_path / f"img_{i}.png"
        print(img_path)
        iio.imwrite(img_path, lf.image)
        img_paths.append(img_path)
except ValueError as e:
    print(f"Failed to write image: {e}")

}
],
"metadata": {
"kernelspec": {
"display_name": "io_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
33 changes: 26 additions & 7 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,33 +47,48 @@ def save_slp(
return slp.write_labels(filename, labels, embed=embed)


def load_nwb(filename: str) -> Labels:
def load_nwb(filename: str, as_training: Optional[bool] = None) -> Labels:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to specify this flag when loading. We can just figure it out by scanning the NWB file for PoseTraining versus PoseEstimationSeries.

"""Load an NWB dataset as a SLEAP `Labels` object.

Args:
filename: Path to a NWB file (`.nwb`).
as_training: If `True`, load the dataset as a training dataset.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify conditional logic in load_nwb.

The function includes an as_training parameter to load datasets as training data. However, the logic can be simplified to avoid redundancy.

-  if as_training is None:
-    return nwb.read_nwb(filename)
-  
  if as_training:
    return nwb.read_nwb_training(filename)
  else:
    return nwb.read_nwb(filename)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def load_nwb(filename: str, as_training: Optional[bool] = None) -> Labels:
"""Load an NWB dataset as a SLEAP `Labels` object.
Args:
filename: Path to a NWB file (`.nwb`).
as_training: If `True`, load the dataset as a training dataset.
def load_nwb(filename: str, as_training: Optional[bool] = None) -> Labels:
"""Load an NWB dataset as a SLEAP `Labels` object.
Args:
filename: Path to a NWB file (`.nwb`).
as_training: If `True`, load the dataset as a training dataset.
if as_training:
return nwb.read_nwb_training(filename)
else:
return nwb.read_nwb(filename)


Returns:
The dataset as a `Labels` object.
"""
return nwb.read_nwb(filename)


def save_nwb(labels: Labels, filename: str, append: bool = True):
def save_nwb(
labels: Labels,
filename: str,
as_training: bool = None,
append: bool = True,
**kwargs,
):
"""Save a SLEAP dataset to NWB format.

Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to NWB file to save to. Must end in `.nwb`.
as_training: If `True`, save the dataset as a training dataset.
append: If `True` (the default), append to existing NWB file. File will be
created if it does not exist.

See also: nwb.write_nwb, nwb.append_nwb
"""
if append and Path(filename).exists():
nwb.append_nwb(labels, filename)
if as_training:
if append and Path(filename).exists():
nwb.append_nwb_training(labels, filename)
else:
nwb.write_nwb(labels, filename, None, None, True)

else:
nwb.write_nwb(labels, filename)
if append and Path(filename).exists():
nwb.append_nwb(labels, filename, **kwargs)
else:
nwb.write_nwb(labels, filename)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor save_nwb to reduce complexity and address previous comments.

The function has been updated to handle the as_training parameter. However, the current implementation is complex and can be simplified. Additionally, address the previous comment about reducing the depth of conditionals.

-    if as_training:
-        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
-        if append and Path(filename).exists():
-            nwb.append_nwb_training(pose_training, filename, **kwargs)
-        else:
-            nwb.write_nwb_training(pose_training, filename, **kwargs)
-
-    else:
-        if append and Path(filename).exists():
-            nwb.append_nwb(labels, filename, **kwargs)
-        else:
-            nwb.write_nwb(labels, filename)
+    func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+    action = nwb.append_nwb_training if as_training else nwb.append_nwb
+    write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+    data = func(labels, **kwargs)
+    if append and Path(filename).exists():
+        action(data, filename, **kwargs)
+    else:
+        write(data, filename, **kwargs)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def save_nwb(
labels: Labels,
filename: str,
as_training: bool = None,
append: bool = True,
**kwargs,
):
"""Save a SLEAP dataset to NWB format.
Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to NWB file to save to. Must end in `.nwb`.
as_training: If `True`, save the dataset as a training dataset.
append: If `True` (the default), append to existing NWB file. File will be
created if it does not exist.
See also: nwb.write_nwb, nwb.append_nwb
"""
if append and Path(filename).exists():
nwb.append_nwb(labels, filename)
if as_training:
if append and Path(filename).exists():
nwb.append_nwb_training(labels, filename)
else:
nwb.write_nwb(labels, filename, None, None, True)
else:
nwb.write_nwb(labels, filename)
if append and Path(filename).exists():
nwb.append_nwb(labels, filename, **kwargs)
else:
nwb.write_nwb(labels, filename)
func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
action = nwb.append_nwb_training if as_training else nwb.append_nwb
write = nwb.write_nwb_training if as_training else nwb.write_nwb
data = func(labels, **kwargs)
if append and Path(filename).exists():
action(data, filename, **kwargs)
else:
write(data, filename, **kwargs)



def load_labelstudio(
Expand Down Expand Up @@ -190,6 +205,8 @@ def load_file(
return load_jabs(filename, **kwargs)
elif format == "video":
return load_video(filename, **kwargs)
else:
raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")


def save_file(
Expand Down Expand Up @@ -219,8 +236,10 @@ def save_file(

if format == "slp":
save_slp(labels, filename, **kwargs)
elif format == "nwb":
save_nwb(labels, filename, **kwargs)
elif format in ("nwb", "nwb_predictions"):
save_nwb(labels, filename, as_training=False, **kwargs)
elif format == "nwb_training":
save_nwb(labels, filename, as_training=True, **kwargs)
elif format == "labelstudio":
save_labelstudio(labels, filename, **kwargs)
elif format == "jabs":
Expand Down
Loading
Loading