Skip to content

Commit

Permalink
extending test cases for mp4 folders
Browse files Browse the repository at this point in the history
  • Loading branch information
emdavis02 committed Jun 13, 2024
1 parent b0ac880 commit dcc7a63
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 19 deletions.
6 changes: 2 additions & 4 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5495,7 +5495,7 @@ def main(args: Optional[list] = None):
output_path_obj = Path(output_path)

# Output path given is a file, but multiple inputs were given
if output_path is not None and (Path.is_file(output_path_obj) and data_path_list.len() > 1):
if output_path is not None and (Path.is_file(output_path_obj) and len(data_path_list) > 1):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
Expand All @@ -5517,9 +5517,7 @@ def main(args: Optional[list] = None):
labels_pr = predictor.predict(provider)

if output_path is None:
#if data_path.as_posix().endswith(".slp"):
# output_path = data_path
#else:

output_path = data_path.parent / (data_path.stem + ".predictions.slp")


Expand Down
129 changes: 114 additions & 15 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import zipfile
from pathlib import Path
from typing import cast
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -65,12 +64,16 @@
# sleap.nn.system.use_cpu_only()

@pytest.fixture
def test_sleap_track_mult_inputs_folder_slp_files():
return "tests/data/videos/slp_multiple_inputs"
def test_sleap_track_mult_inputs_folder_slp():
return "tests/data/videos/multiple_inputs_slp"

@pytest.fixture
def test_sleap_track_mult_inputs_folder():
return "tests/data/videos/multiple_inputs"
def test_sleap_track_mult_inputs_folder_slp_mp4():
return "tests/data/videos/multiple_inputs_slp_mp4"

@pytest.fixture
def test_sleap_track_mult_inputs_folder_mp4():
return "tests/data/videos/multiple_inputs_mp4"


@pytest.fixture
Expand Down Expand Up @@ -1483,25 +1486,21 @@ def test_sleap_track_single_input(
with pytest.raises(ValueError):
sleap_track(args=args)

#@pytest.mark.parametrize("tracking", ["simple", "flow", "simplemaxtracks", "flowmaxtracks", "None"])
def test_sleap_track_mult_input(
@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
def test_sleap_track_mult_input_slp(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
#test_sleap_track_mult_inputs_folder: str,
test_sleap_track_mult_inputs_folder_slp_files: str,
#tracking
test_sleap_track_mult_inputs_folder_slp: str,
tracking
):
slp_path = test_sleap_track_mult_inputs_folder_slp_files
#slp_path = test_sleap_track_mult_inputs_folder
slp_path = test_sleap_track_mult_inputs_folder_slp
slp_path_obj = Path(slp_path)
Labels.save(centered_pair_predictions, slp_path)

# Create sleap-track command
args = (
f"{slp_path} --model {min_centroid_model_path} "
#f"--tracking.tracker {tracking} "
f"--tracking.tracker simple "
f"--tracking.tracker {tracking} "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

Expand Down Expand Up @@ -1537,6 +1536,106 @@ def test_sleap_track_mult_input(
file.unlink()


@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
def test_sleap_track_mult_input_slp_mp4(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
test_sleap_track_mult_inputs_folder_slp_mp4: str,
tracking
):
slp_path = test_sleap_track_mult_inputs_folder_slp_mp4
slp_path_obj = Path(slp_path)

# Create sleap-track command
args = (
f"{slp_path} --model {min_centroid_model_path} "
f"--tracking.tracker {tracking} "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

if Path.is_dir(slp_path_obj):
slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
slp_path_list.append(file_path)
elif Path.is_file(slp_path_obj):
slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()

# Run inference
sleap_track(args=args)
slp_path = Path(slp_path)

# Assert predictions file exists
if Path.is_dir(slp_path_obj):
new_slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
new_slp_path_list.append(file_path)
elif Path.is_file(slp_path):
new_slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()

files_to_remove = set(new_slp_path_list) - set(slp_path_list)
for file in files_to_remove:
file.unlink()

#@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
def test_sleap_track_mult_input_mp4(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
test_sleap_track_mult_inputs_folder_mp4: str,
#tracking
):
slp_path = test_sleap_track_mult_inputs_folder_mp4
slp_path_obj = Path(slp_path)

# Create sleap-track command
args = (
f"{slp_path} --model {min_centroid_model_path} "
#f"--tracking.tracker {tracking} "
f"--tracking.tracker simple "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

if Path.is_dir(slp_path_obj):
slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
slp_path_list.append(file_path)
elif Path.is_file(slp_path_obj):
slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()

# Run inference
sleap_track(args=args)
slp_path = Path(slp_path)

# Assert predictions file exists
if Path.is_dir(slp_path_obj):
new_slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
new_slp_path_list.append(file_path)
elif Path.is_file(slp_path):
new_slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()

#files_to_remove = set(new_slp_path_list) - set(slp_path_list)
#for file in files_to_remove:
# file.unlink()


def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
"""Test flow tracker instances are pruned."""
labels: Labels = centered_pair_predictions
Expand Down

0 comments on commit dcc7a63

Please sign in to comment.