Skip to content

Commit

Permalink
Add explicit check for file accessibility
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Sep 29, 2024
1 parent b0af70f commit 8e179ca
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 82 deletions.
70 changes: 69 additions & 1 deletion sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,75 @@
Instance,
PredictedInstance,
)
from sleap_io.io.utils import convert_predictions_to_dataframe


def convert_predictions_to_dataframe(labels: Labels) -> pd.DataFrame:
    """Convert predictions data to a Pandas dataframe.

    Args:
        labels: A general label object.

    Returns:
        pd.DataFrame: A pandas data frame with the structured data with
        hierarchical columns. The column hierarchy is:
                "video_path",
                "skeleton_name",
                "track_name",
                "node_name",
                followed by the value name ("x", "y" or "score").
        And it is indexed by the frames.

    Raises:
        ValueError: If no frames in the label objects contain predicted instances.
    """
    # Build one flat record per (labeled frame, predicted instance, skeleton node).
    data_list = []
    for labeled_frame in labels.labeled_frames:
        for instance in labeled_frame.predicted_instances:
            skeleton = instance.skeleton
            # These values are constant for every node of the instance.
            track_name = instance.track.name if instance.track else "untracked"
            video_path = labeled_frame.video.filename
            for node in skeleton.nodes:
                point = instance.points[node]
                data_list.append(
                    dict(
                        frame_idx=labeled_frame.frame_idx,
                        x=point.x,
                        y=point.y,
                        score=point.score,  # type: ignore[attr-defined]
                        node_name=node.name,
                        skeleton_name=skeleton.name,
                        track_name=track_name,
                        video_path=video_path,
                    )
                )

    if not data_list:
        raise ValueError("No predicted instances found in labels object")

    labels_df = pd.DataFrame(data_list)

    # Reformat the data with columns for dict-like hierarchical data access.
    index = [
        "skeleton_name",
        "track_name",
        "node_name",
        "video_path",
        "frame_idx",
    ]

    # NOTE(review): unstack requires unique index entries, so two instances in the
    # same frame sharing video, skeleton and track name (e.g. several "untracked"
    # instances) would presumably make pandas raise here -- confirm with callers.
    labels_tidy_df = (
        labels_df.set_index(index)
        .unstack(level=[0, 1, 2, 3])  # Everything but frame_idx becomes a column level
        .swaplevel(0, -1, axis=1)  # video_path on top while x, y score on bottom
        .sort_index(axis=1)  # Better format for columns
        .sort_index(axis=0)  # Sorts by frames
    )

    return labels_tidy_df


def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
Expand Down
13 changes: 5 additions & 8 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
Labels,
)
from sleap_io.io.video import VideoBackend, ImageVideo, MediaVideo, HDF5Video
from sleap_io.io.utils import (
read_hdf5_attrs,
read_hdf5_dataset,
)
from sleap_io.io.utils import read_hdf5_attrs, read_hdf5_dataset, is_file_accessible
from enum import IntEnum
from pathlib import Path
import imageio.v3 as iio
Expand Down Expand Up @@ -107,11 +104,11 @@ def make_video(
backend = None
if open_backend:
try:
if not video_path.exists():
if not is_file_accessible(video_path):
# Check for the same filename in the same directory as the labels file.
video_path_ = Path(labels_path).parent / video_path.name
if video_path_.exists() and video_path.stat():
video_path = video_path_
candidate_video_path = Path(labels_path).parent / video_path.name
if is_file_accessible(candidate_video_path):
video_path = candidate_video_path
else:
# TODO (TP): Expand capabilities of path resolution to support more
# complex path finding strategies.
Expand Down
84 changes: 17 additions & 67 deletions sleap_io/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from __future__ import annotations
import h5py # type: ignore[import]
import numpy as np
import pandas as pd # type: ignore[import]
from typing import Any, Union, Optional, Generator
from sleap_io import Labels, LabeledFrame, PredictedInstance
from typing import Any, Union, Optional
from pathlib import Path


def read_hdf5_dataset(filename: str, dataset: str) -> np.ndarray:
Expand Down Expand Up @@ -175,72 +174,23 @@ def _overwrite_hdf5_attr(
_overwrite_hdf5_attr(ds, attr_name, attr_value)


def is_file_accessible(filename: str | Path) -> bool:
    """Check if a file is accessible.

    Args:
        filename: Path to a file.

    Returns:
        `True` if the file is accessible, `False` otherwise.

    Notes:
        This checks if the file is readable by the current user by reading one byte
        from the file.
    """
    filename = Path(filename)
    try:
        with open(filename, "rb") as f:
            f.read(1)
    except (OSError, ValueError):
        # OSError already covers FileNotFoundError and PermissionError (both are
        # subclasses); ValueError is raised by open() for malformed paths such as
        # ones containing an embedded null byte.
        return False
    return True
12 changes: 8 additions & 4 deletions sleap_io/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Tuple, Optional, Optional
import numpy as np
from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video, ImageVideo
from sleap_io.io.utils import is_file_accessible
from pathlib import Path


Expand Down Expand Up @@ -197,21 +198,24 @@ def __getitem__(self, inds: int | list[int] | slice) -> np.ndarray:
return self.backend[inds]

def exists(self, check_all: bool = False) -> bool:
    """Check if the video file exists and is accessible.

    Args:
        check_all: If `True`, check that all filenames in a list exist. If `False`
            (the default), check that the first filename exists.

    Returns:
        `True` if the file exists and is accessible, `False` otherwise.
    """
    if isinstance(self.filename, list):
        if check_all:
            # Short-circuits on the first inaccessible file.
            return all(is_file_accessible(f) for f in self.filename)
        return is_file_accessible(self.filename[0])
    return is_file_accessible(self.filename)

@property
def is_open(self) -> bool:
Expand Down
36 changes: 35 additions & 1 deletion tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import simplejson as json
import pytest
from pathlib import Path

import shutil
from sleap_io.io.video import ImageVideo, HDF5Video, MediaVideo


Expand Down Expand Up @@ -363,3 +363,37 @@ def test_lazy_video_read(slp_real_data):

labels = read_labels(slp_real_data, open_videos=False)
assert labels.video.backend is None


def test_video_path_resolution(slp_real_data, tmp_path):
    """Check video path resolution when loading labels.

    A video whose stored path is missing should be picked up from the directory of
    the labels file when a file with the same name is there, and should be left
    unresolved when that candidate file cannot be read.
    """
    video_name = "centered_pair_low_quality.mp4"

    labels = read_labels(slp_real_data)
    assert (
        Path(labels.video.filename).as_posix()
        == "tests/data/videos/centered_pair_low_quality.mp4"
    )
    shutil.copyfile(labels.video.filename, tmp_path / video_name)
    labels.video.replace_filename(
        "fake/path/to/centered_pair_low_quality.mp4", open=False
    )
    labels.save(tmp_path / "labels.slp")

    # Resolve when the same video filename is found in the labels directory.
    labels = read_labels(tmp_path / "labels.slp")
    assert (
        Path(labels.video.filename).as_posix() == (tmp_path / video_name).as_posix()
    )
    assert labels.video.exists()

    # Make the video file inaccessible.
    labels.video.replace_filename("new_fake/path/to/inaccessible.mp4", open=False)
    labels.save(tmp_path / "labels2.slp")
    inaccessible_path = tmp_path / "inaccessible.mp4"
    shutil.copyfile(tmp_path / video_name, inaccessible_path)
    inaccessible_path.chmod(0o000)
    try:
        # chmod cannot revoke read access in every environment (e.g. on Windows or
        # when running as root), so probe the file before relying on it.
        try:
            with open(inaccessible_path, "rb") as f:
                f.read(1)
            still_readable = True
        except OSError:
            still_readable = False
        if still_readable:
            pytest.skip("Cannot make a file unreadable in this environment.")

        # Fail to resolve when the video file is inaccessible.
        labels = read_labels(tmp_path / "labels2.slp")
        assert not labels.video.exists()
        assert (
            Path(labels.video.filename).as_posix()
            == "new_fake/path/to/inaccessible.mp4"
        )
    finally:
        # Restore permissions so the temporary file can be inspected and removed.
        inaccessible_path.chmod(0o644)
23 changes: 22 additions & 1 deletion tests/model/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_video_exists(centered_pair_low_quality_video, centered_pair_frame_paths
assert video.exists(check_all=True) is False


def test_video_open_close(centered_pair_low_quality_path):
def test_video_open_close(centered_pair_low_quality_path, centered_pair_frame_paths):
video = Video(centered_pair_low_quality_path)
assert video.is_open
assert type(video.backend) == MediaVideo
Expand Down Expand Up @@ -91,6 +91,10 @@ def test_video_open_close(centered_pair_low_quality_path):
video.open(grayscale=True)
assert video.shape == (1100, 384, 384, 1)

video.open(centered_pair_frame_paths)
assert video.shape == (3, 384, 384, 1)
assert type(video.backend) == ImageVideo


def test_video_replace_filename(
centered_pair_low_quality_path, centered_pair_frame_paths
Expand Down Expand Up @@ -142,3 +146,20 @@ def test_grayscale(centered_pair_low_quality_path):
video.open()
assert video.grayscale == True
assert video.shape[-1] == 1


def test_open_backend_preference(centered_pair_low_quality_path):
    """Check that the `open_backend` preference controls backend opening."""
    # With default arguments the backend is opened on construction.
    eager_video = Video(centered_pair_low_quality_path)
    assert eager_video.is_open
    assert type(eager_video.backend) == MediaVideo

    # With open_backend=False no backend is created, and indexing is an error.
    lazy_video = Video(centered_pair_low_quality_path, open_backend=False)
    assert lazy_video.is_open is False
    assert lazy_video.backend is None
    with pytest.raises(ValueError):
        lazy_video[0]

    # Once the preference is switched on, indexing opens the backend.
    lazy_video.open_backend = True
    _ = lazy_video[0]
    assert lazy_video.is_open
    assert type(lazy_video.backend) == MediaVideo

0 comments on commit 8e179ca

Please sign in to comment.