Skip to content

Commit

Permalink
one test left
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Aug 14, 2024
1 parent 6f3aabe commit 7a83046
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 19 deletions.
80 changes: 69 additions & 11 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ def read_nwb(path: str) -> Labels:
track_keys: List[str] = list(test_processing_module.fields["data_interfaces"])

# Get track
test_pose_estimation: PoseEstimation = test_processing_module[track_keys[0]]
for key in track_keys:
if isinstance(test_processing_module[key], PoseEstimation):
test_pose_estimation = test_processing_module[key]
break
node_names = test_pose_estimation.nodes[:]
edge_inds = test_pose_estimation.edges[:]

Expand All @@ -363,6 +366,10 @@ def read_nwb(path: str) -> Labels:
pose_estimation_series = processing_module[track_key][node_name]
except TypeError:
continue
except KeyError:
pose_estimation_series = processing_module[
"track=untracked"
].pose_estimation_series[node_name]
timestamps = np.union1d(
timestamps, get_timestamps(pose_estimation_series)
)
Expand All @@ -375,10 +382,20 @@ def read_nwb(path: str) -> Labels:
tracks_numpy = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, np.float32)
confidence = np.full((n_frames, n_tracks, n_nodes), np.nan, np.float32)
for track_idx, track_key in enumerate(_track_keys):
pose_estimation = processing_module[track_key]
try:
pose_estimation = processing_module[track_key]
if not isinstance(pose_estimation, PoseEstimation):
raise KeyError
except KeyError:
pose_estimation = processing_module["track=untracked"]

for node_idx, node_name in enumerate(node_names):
pose_estimation_series = pose_estimation[node_name]
try:
pose_estimation_series = pose_estimation[node_name]
except KeyError:
pose_estimation_series = pose_estimation.pose_estimation_series[
node_name
]
frame_inds = np.searchsorted(
timestamps, get_timestamps(pose_estimation_series)
)
Expand Down Expand Up @@ -416,15 +433,24 @@ def read_nwb(path: str) -> Labels:
):
if np.isnan(inst_pts).all():
continue
insts.append(
Instance.from_numpy(
points=inst_pts, # (n_nodes, 2)
point_scores=inst_confs, # (n_nodes,)
instance_score=inst_confs.mean(), # ()
skeleton=skeleton,
track=track if is_tracked else None,
try:
insts.append(
Instance.from_numpy(
points=inst_pts, # (n_nodes, 2)
point_scores=inst_confs, # (n_nodes,)
instance_score=inst_confs.mean(), # ()
skeleton=skeleton,
track=track if is_tracked else None,
)
)
except TypeError:
insts.append(
Instance.from_numpy(
points=inst_pts,
skeleton=skeleton,
track=track if is_tracked else None,
)
)
)
if len(insts) > 0:
lfs.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=insts)
Expand Down Expand Up @@ -535,6 +561,38 @@ def write_nwb(
io.write(nwbfile)


def handle_orphan_container_error(labels: Labels, nwbfile: NWBFile) -> NWBFile:
"""Handle orphan container error by adding a skeleton to the processing module.
Args:
labels: A general labels object.
nwbfile: An in-memory nwbfile where the data is to be appended.
Returns:
An in-memory nwbfile with the data from the labels object appended.
"""
processing_module = nwbfile.processing[
f"SLEAP_VIDEO_000_{Path(labels.videos[0].filename).stem}"
]
if "track=untracked" in processing_module.containers:
pose_estimation = processing_module.containers["track=untracked"]
skeleton = pose_estimation.skeleton
skeletons = Skeletons(skeletons=[skeleton])
else:
skeletons = []
for i in range(len(labels.tracks)):
pose_estimation = processing_module.containers[f"track=track_{i}"]
skeleton = pose_estimation.skeleton
skeletons.append(skeleton)
try:
processing_module.add(skeletons)
except ValueError:
skeleton = pose_estimation.skeleton
skeletons = Skeletons(skeletons=[skeleton])
processing_module.add(skeletons)
return nwbfile


def append_nwb_data(
labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
Expand Down
14 changes: 6 additions & 8 deletions tests/io/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,19 @@ def test_load_slp(slp_typical):
assert type(load_file(slp_typical)) == Labels


def test_nwb(tmp_path, slp_typical):
def test_nwb(tmp_path, slp_typical, slp_predictions_with_provenance):
labels = load_slp(slp_typical)
save_nwb(labels, tmp_path / "test_nwb.nwb", False)
loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert type(loaded_labels) == Labels
assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels
assert len(loaded_labels) == len(labels)

labels2 = load_slp(slp_typical)
labels2.videos[0].filename = "test"
save_nwb(labels2, tmp_path / "test_nwb.nwb", append=True)
loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert type(loaded_labels) == Labels
assert len(loaded_labels) == (len(labels) + len(labels2))
assert len(loaded_labels.videos) == 2
labels2 = load_slp(slp_predictions_with_provenance)
save_nwb(labels2, tmp_path / "test_nwb2.nwb", False)
loaded_labels2 = load_nwb(tmp_path / "test_nwb2.nwb")
assert type(loaded_labels2) == Labels
assert len(loaded_labels2) == len(labels2)


def test_nwb_training(tmp_path, slp_typical):
Expand Down

0 comments on commit 7a83046

Please sign in to comment.