Skip to content

Commit

Permalink
push
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jul 12, 2024
1 parent 70a34ed commit f2fd6f5
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 60 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"attrs",
"h5py>=3.8.0",
"pynwb",
"ndx-pose",
"ndx-pose @ git+https://github.com/rly/ndx-pose@a847ad4be75e60ef9e413b8cbfc99c616fc9fd05",
"pandas",
"simplejson",
"imageio",
Expand Down
34 changes: 26 additions & 8 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,50 @@ def save_slp(
return slp.write_labels(filename, labels, embed=embed)


def load_nwb(filename: str) -> Labels:
def load_nwb(filename: str, as_training: Optional[bool]=None) -> Labels:
"""Load an NWB dataset as a SLEAP `Labels` object.
Args:
filename: Path to a NWB file (`.nwb`).
as_training: If `True`, load the dataset as a training dataset.
Returns:
The dataset as a `Labels` object.
"""
return nwb.read_nwb(filename)
if as_training is None:
return

if as_training:
return nwb.read_nwb_training(filename)
else:
return nwb.read_nwb(filename)


def save_nwb(labels: Labels, filename: str, append: bool = True):
def save_nwb(labels: Labels, filename: str, as_training: bool = None, append: bool = True, **kwargs):
"""Save a SLEAP dataset to NWB format.
Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to NWB file to save to. Must end in `.nwb`.
as_training: If `True`, save the dataset as a training dataset.
append: If `True` (the default), append to existing NWB file. File will be
created if it does not exist.
See also: nwb.write_nwb, nwb.append_nwb
"""
if append and Path(filename).exists():
nwb.append_nwb(labels, filename)
else:
nwb.write_nwb(labels, filename)
if as_training:
pose_training = nwb.labels_to_pose_training(labels, **kwargs)
if append and Path(filename).exists():
nwb.append_nwb_training(pose_training, filename, **kwargs)
else:
nwb.write_nwb_training(pose_training, filename, **kwargs)

else:
if append and Path(filename).exists():
nwb.append_nwb(labels, filename, **kwargs)
else:
nwb.write_nwb(labels, filename)


def load_labelstudio(
filename: str, skeleton: Optional[Union[Skeleton, list[str]]] = None
Expand Down Expand Up @@ -190,6 +206,8 @@ def load_file(
return load_jabs(filename, **kwargs)
elif format == "video":
return load_video(filename, **kwargs)
else:
raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")


def save_file(
Expand Down Expand Up @@ -219,7 +237,7 @@ def save_file(

if format == "slp":
save_slp(labels, filename, **kwargs)
elif format == "nwb":
elif format in ("nwb", "nwb_training", "nwb_predictions"):
save_nwb(labels, filename, **kwargs)
elif format == "labelstudio":
save_labelstudio(labels, filename, **kwargs)
Expand Down
Loading

0 comments on commit f2fd6f5

Please sign in to comment.