Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
emdavis02 committed Jul 10, 2024
1 parent beb5e1e commit 55bfe4b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
4 changes: 2 additions & 2 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5291,13 +5291,13 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:

# Figure out which input path to use.
data_path = args.data_path

if data_path == None or data_path == "":
raise ValueError(
"You must specify a path to a video or a labels dataset. "
"Run 'sleap-track -h' to see full command documentation."
)

data_path_obj = Path(data_path)

# Check for multiple video inputs
Expand Down
31 changes: 29 additions & 2 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_sleap_track_mult_inputs_folder_mp4():
def test_sleap_track_invalid_input_path():
return "tests/data/videos/invalid_input_test"


@pytest.fixture
def test_sleap_track_output_file():
return "tests/data/videos/output_test_file.slp"
Expand Down Expand Up @@ -1711,8 +1712,8 @@ def test_sleap_track_output_mult(
files_to_remove = set(new_output_path_list) - set(output_path_list)
for file in files_to_remove:
file.unlink()


def test_sleap_track_invalid_output(
test_sleap_track_output_file: str,
min_centroid_model_path: str,
Expand Down Expand Up @@ -1754,6 +1755,32 @@ def test_sleap_track_invalid_input(
sleap_track(args=args)


def test_sleap_track_user_labeled_frames(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
tmpdir,
):
slp_path = str(Path(tmpdir, "old_slp.slp"))
Labels.save(centered_pair_predictions, slp_path)

# Create sleap-track command
args = (
f"{slp_path} --model {min_centroid_model_path} "
"--only-labeled-frames "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

# Run inference
sleap_track(args=args)

# Assert predictions file exists
output_path = f"{slp_path}.predictions.slp"
assert Path(output_path).exists()

# Create invalid sleap-track command


def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
"""Test flow tracker instances are pruned."""
labels: Labels = centered_pair_predictions
Expand Down

0 comments on commit 55bfe4b

Please sign in to comment.