Skip to content

Commit

Permalink
updated files from PR #104
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Aug 27, 2024
1 parent 5fa52d7 commit 901482d
Show file tree
Hide file tree
Showing 12 changed files with 625 additions and 44 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ labels = sio.load_file("predictions.slp")
sio.save_file(labels, "predictions.nwb")
# Or:
# labels.save("predictions.nwb")

# Save to an NWB file and convert SLEAP training data to NWB training data:
frame_inds = [i for i in range(20)]
sio.save_file(labels, "predictions.nwb", as_training=True, frame_inds=frame_inds)
# This will save the first 20 frames of the video as individual images
```

### Convert labels to raw arrays
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"attrs",
"h5py>=3.8.0",
"pynwb",
"ndx-pose",
"ndx-pose @ git+https://github.com/rly/ndx-pose@a847ad4be75e60ef9e413b8cbfc99c616fc9fd05",
"pandas",
"simplejson",
"imageio",
Expand Down
42 changes: 36 additions & 6 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Optional, Union
from pathlib import Path

from pynwb import NWBHDF5IO


def load_slp(filename: str) -> Labels:
"""Load a SLEAP dataset.
Expand Down Expand Up @@ -59,21 +61,45 @@ def load_nwb(filename: str) -> Labels:
return nwb.read_nwb(filename)


def save_nwb(labels: Labels, filename: str, append: bool = True):
def save_nwb(
labels: Labels,
filename: str,
as_training: bool = False,
append: bool = True,
frame_inds: Optional[list[int]] = None,
frame_path: Optional[str] = None,
):
"""Save a SLEAP dataset to NWB format.
Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to NWB file to save to. Must end in `.nwb`.
as_training: If `True`, save the dataset as a training dataset.
append: If `True` (the default), append to existing NWB file. File will be
created if it does not exist.
frame_inds: Optional list of frame indices to save. If None, all frames
will be saved.
frame_path: The path to save the frames. If None, the path is the video
filename without the extension.
See also: nwb.write_nwb, nwb.append_nwb
See also: nwb.write_nwb, nwb.append_nwb, nwb.append_nwb_training
"""
if append and Path(filename).exists():
nwb.append_nwb(labels, filename)
nwb.append_nwb(
labels,
filename,
as_training=as_training,
frame_inds=frame_inds,
frame_path=frame_path,
)
else:
nwb.write_nwb(labels, filename)
nwb.write_nwb(
labels,
filename,
as_training=as_training,
frame_inds=frame_inds,
frame_path=frame_path,
)


def load_labelstudio(
Expand Down Expand Up @@ -190,6 +216,8 @@ def load_file(
return load_jabs(filename, **kwargs)
elif format == "video":
return load_video(filename, **kwargs)
else:
raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")


def save_file(
Expand Down Expand Up @@ -219,8 +247,10 @@ def save_file(

if format == "slp":
save_slp(labels, filename, **kwargs)
elif format == "nwb":
save_nwb(labels, filename, **kwargs)
elif format in ("nwb", "nwb_predictions"):
save_nwb(labels, filename, False)
elif format == "nwb_training":
save_nwb(labels, filename, True, frame_inds=kwargs.pop("frame_inds", None))
elif format == "labelstudio":
save_labelstudio(labels, filename, **kwargs)
elif format == "jabs":
Expand Down
Loading

0 comments on commit 901482d

Please sign in to comment.