-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding support for more .nwb training data to SLEAP #104
base: main
Are you sure you want to change the base?
Changes from 11 commits
70a34ed
f2fd6f5
6c062ec
f7d8a0c
23a5a83
cf4bcf7
8948e92
be6ccbc
0aabd86
85d47ab
1b8e08b
aab5a79
66edcd6
1b70d5c
c59f996
50da4a1
c8c173a
0742bdf
055e3e8
dcc3da8
0ba071e
0c7f048
6a6c38c
b15471e
dd522f9
22ed734
7d1737d
6f1a7d1
51d5623
e511efa
42da933
e246f8b
13d3a47
e579ebf
90f0ec4
24634a1
4819536
432c343
7cb04c4
a66cf8f
8f2271b
f510765
183cf55
7e2d138
846856b
f88ca17
0a8fc45
73b41b0
a9be0eb
1031247
68c7cfb
3bf7c8a
6d12f39
e94c14d
dac1eb5
593c139
68133e1
d66b858
0abaca4
4ba67ea
3bac601
9699be2
f01e03a
9a4422d
55ca394
c91e74c
63faf8b
56911d6
ac6bb23
0be5813
46f63c6
1cd6180
a2f96ab
30fbe6f
6a0cf27
52d45b8
bab0917
a268e1f
c57e572
e7b9fd3
e932340
6f3aabe
7a83046
1a7e58b
90fd397
d8294b7
c3b286e
dedc06d
141d1dd
67a62c2
bf7fadf
62dc540
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"minimal_instance.pkg.nwb_images/img_0.png\n" | ||
] | ||
}, | ||
{ | ||
"ename": "ValueError", | ||
"evalue": "Can't write images with one color channel.", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[2], line 16\u001b[0m\n\u001b[1;32m 14\u001b[0m img_path \u001b[38;5;241m=\u001b[39m save_path \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimg_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.png\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28mprint\u001b[39m(img_path)\n\u001b[0;32m---> 16\u001b[0m \u001b[43miio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimage\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 17\u001b[0m img_paths\u001b[38;5;241m.\u001b[39mappend(img_path)\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28mprint\u001b[39m(img_paths)\n", | ||
"File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/v3.py:147\u001b[0m, in \u001b[0;36mimwrite\u001b[0;34m(uri, image, plugin, extension, format_hint, **kwargs)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Write an ndimage to the given URI.\u001b[39;00m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03mThe exact behavior depends on the file type and plugin used. To learn about\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 136\u001b[0m \n\u001b[1;32m 137\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m imopen(\n\u001b[1;32m 140\u001b[0m uri,\n\u001b[1;32m 141\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 145\u001b[0m extension\u001b[38;5;241m=\u001b[39mextension,\n\u001b[1;32m 146\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m img_file:\n\u001b[0;32m--> 147\u001b[0m encoded \u001b[38;5;241m=\u001b[39m \u001b[43mimg_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m encoded\n", | ||
"File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/plugins/pillow.py:433\u001b[0m, in \u001b[0;36mPillowPlugin.write\u001b[0;34m(self, ndimage, mode, format, is_batch, **kwargs)\u001b[0m\n\u001b[1;32m 431\u001b[0m is_batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 432\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 433\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt write images with one color channel.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 434\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m ndimage\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m4\u001b[39m]:\n\u001b[1;32m 435\u001b[0m \u001b[38;5;66;03m# Note: this makes a channel-last assumption\u001b[39;00m\n\u001b[1;32m 436\u001b[0m is_batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", | ||
"\u001b[0;31mValueError\u001b[0m: Can't write images with one color channel." | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import imageio.v3 as iio\n", | ||
"from pathlib import Path\n", | ||
"import sleap_io as sio\n", | ||
"\n", | ||
"save_path = Path(\"minimal_instance.pkg.nwb_images\")\n", | ||
"try:\n", | ||
" save_path.mkdir(parents=True, exist_ok=True)\n", | ||
"except Exception as e:\n", | ||
" print(f\"An error {e} occurred. The directory could not be created.\")\n", | ||
"img_paths = []\n", | ||
"\n", | ||
"labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n", | ||
"for i, lf in enumerate(labels_original):\n", | ||
" img_path = save_path / f\"img_{i}.png\"\n", | ||
" print(img_path)\n", | ||
" iio.imwrite(img_path, lf.image)\n", | ||
" img_paths.append(img_path)\n", | ||
"print(img_paths)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "io_dev", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -47,33 +47,48 @@ def save_slp( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return slp.write_labels(filename, labels, embed=embed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def load_nwb(filename: str) -> Labels: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def load_nwb(filename: str, as_training: Optional[bool] = None) -> Labels: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to specify this flag when loading. We can just figure it out by scanning the NWB file for |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Load an NWB dataset as a SLEAP `Labels` object. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filename: Path to a NWB file (`.nwb`). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
as_training: If `True`, load the dataset as a training dataset. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify conditional logic in The function includes an - if as_training is None:
- return nwb.read_nwb(filename)
-
if as_training:
return nwb.read_nwb_training(filename)
else:
return nwb.read_nwb(filename) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
The dataset as a `Labels` object. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return nwb.read_nwb(filename) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def save_nwb(labels: Labels, filename: str, append: bool = True): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def save_nwb( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
labels: Labels, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filename: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
as_training: bool = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
append: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Save a SLEAP dataset to NWB format. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
labels: A SLEAP `Labels` object (see `load_slp`). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filename: Path to NWB file to save to. Must end in `.nwb`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
as_training: If `True`, save the dataset as a training dataset. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
append: If `True` (the default), append to existing NWB file. File will be | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
created if it does not exist. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
See also: nwb.write_nwb, nwb.append_nwb | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if append and Path(filename).exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nwb.append_nwb(labels, filename) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if as_training: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if append and Path(filename).exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nwb.append_nwb_training(labels, filename) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nwb.write_nwb(labels, filename, None, None, True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nwb.write_nwb(labels, filename) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if append and Path(filename).exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nwb.append_nwb(labels, filename, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nwb.write_nwb(labels, filename) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor The function has been updated to handle the - if as_training:
- pose_training = nwb.labels_to_pose_training(labels, **kwargs)
- if append and Path(filename).exists():
- nwb.append_nwb_training(pose_training, filename, **kwargs)
- else:
- nwb.write_nwb_training(pose_training, filename, **kwargs)
-
- else:
- if append and Path(filename).exists():
- nwb.append_nwb(labels, filename, **kwargs)
- else:
- nwb.write_nwb(labels, filename)
+ func = nwb.labels_to_pose_training if as_training else lambda l, **kw: l
+ action = nwb.append_nwb_training if as_training else nwb.append_nwb
+ write = nwb.write_nwb_training if as_training else nwb.write_nwb
+
+ data = func(labels, **kwargs)
+ if append and Path(filename).exists():
+ action(data, filename, **kwargs)
+ else:
+ write(data, filename, **kwargs) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def load_labelstudio( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -190,6 +205,8 @@ def load_file( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return load_jabs(filename, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif format == "video": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return load_video(filename, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def save_file( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -219,8 +236,10 @@ def save_file( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if format == "slp": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
save_slp(labels, filename, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif format == "nwb": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
save_nwb(labels, filename, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif format in ("nwb", "nwb_predictions"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
save_nwb(labels, filename, as_training=False, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif format == "nwb_training": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
save_nwb(labels, filename, as_training=True, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif format == "labelstudio": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
save_labelstudio(labels, filename, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif format == "jabs": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Improve error handling for image writing.
This cell attempts to write images, but fails due to an unsupported single color channel. This issue should be caught and handled, or the notebook should ensure that only supported image formats are processed.