Skip to content

Commit

Permalink
final edits
Browse files Browse the repository at this point in the history
  • Loading branch information
emdavis02 committed Jul 25, 2024
1 parent a69a65c commit 31eb3fb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
2 changes: 1 addition & 1 deletion sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5629,7 +5629,7 @@ def main(args: Optional[list] = None):
# Save results.
try:
labels_pr.save(output_path)
except Exception as e:
except Exception:
print("WARNING: Provided output path invalid.")
fallback_path = data_path_obj.with_suffix(".predictions.slp")
labels_pr.save(fallback_path)
Expand Down
32 changes: 20 additions & 12 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import tensorflow as tf
import tensorflow_hub as hub
from numpy.testing import assert_array_equal, assert_allclose
from sleap.io.video import available_video_exts

import sleap
from sleap.gui.learning import runners
Expand Down Expand Up @@ -1556,12 +1557,10 @@ def test_sleap_track_mult_input_slp(
sleap_track(args=args)

# Assert predictions file exists
expected_extensions = {
".mp4",
} # Add other video formats if necessary
expected_extensions = available_video_exts()

for file_path in slp_path_list:
if file_path in expected_extensions:
if file_path.suffix in expected_extensions:
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()

Expand Down Expand Up @@ -1602,9 +1601,10 @@ def test_sleap_track_mult_input_slp_mp4(
# Run inference
sleap_track(args=args)

# Assert predictions file exists
expected_extensions = available_video_exts()

for file_path in slp_path_list:
if file_path.suffix == ".mp4":
if file_path.suffix in expected_extensions:
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()

Expand Down Expand Up @@ -1643,8 +1643,10 @@ def test_sleap_track_mult_input_mp4(
sleap_track(args=args)

# Assert predictions file exists
expected_extensions = available_video_exts()

for file_path in slp_path_list:
if file_path.suffix == ".mp4":
if file_path.suffix in expected_extensions:
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()

Expand Down Expand Up @@ -1686,8 +1688,10 @@ def test_sleap_track_output_mult(
slp_path = Path(slp_path)

# Check if there are any files in the directory
expected_extensions = available_video_exts()

for file_path in slp_path_list:
if file_path.suffix == ".mp4":
if file_path.suffix in expected_extensions:
expected_output_file = output_path_obj / (
file_path.stem + ".predictions.slp"
)
Expand Down Expand Up @@ -1808,8 +1812,10 @@ def test_sleap_track_csv_input(
sleap_track(args=args)

# Assert predictions file exists
expected_extensions = available_video_exts()

for file_path in slp_path_list:
if file_path.suffix == ".mp4":
if file_path.suffix in expected_extensions:
expected_output_file = file_path.with_suffix(".TESTpredictions.slp")
assert Path(expected_output_file).exists()

Expand Down Expand Up @@ -1839,7 +1845,7 @@ def test_sleap_track_invalid_csv(
).split()

# Run inference and expect ValueError for missing 'data_path' column
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=f"Column containing valid data_paths does not exist in the CSV file: {csv_missing_column_path}"):
sleap_track(args=args_missing_column)

# Create sleap-track command for empty CSV file
Expand All @@ -1850,7 +1856,7 @@ def test_sleap_track_invalid_csv(
).split()

# Run inference and expect ValueError for empty CSV file
with pytest.raises(ValueError):
with pytest.raises(ValueError, match = f"CSV file is empty: {csv_empty_path}"):
sleap_track(args=args_empty)


Expand Down Expand Up @@ -1894,8 +1900,10 @@ def test_sleap_track_text_file_input(
sleap_track(args=args)

# Assert predictions file exists
expected_extensions = available_video_exts()

for file_path in slp_path_list:
if file_path.suffix == ".mp4":
if file_path.suffix in expected_extensions:
expected_output_file = Path(file_path).with_suffix(".predictions.slp")
assert Path(expected_output_file).exists()

Expand Down

0 comments on commit 31eb3fb

Please sign in to comment.