add tests for every combination related to kalman args
eberrigan committed Dec 18, 2024
1 parent: c3c5372 · commit: 2342a9d
Showing 1 changed file with 166 additions and 16 deletions.
182 changes: 166 additions & 16 deletions tests/nn/test_tracking_integration.py
@@ -9,26 +9,176 @@
from sleap.io.dataset import Labels, LabeledFrame


similarity_args = [
    "instance",
    "normalized_instance",
    "object_keypoint",
    "centroid",
    "iou",
]
match_args = ["hungarian", "greedy"]


@pytest.mark.parametrize(
    "tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"]
)
def test_kalman_tracker(tmpdir, centered_pair_predictions_slp_path, tracker_name):
    cli = (
        f"--tracking.tracker {tracker_name} "
        "--tracking.max_tracking 1 --tracking.max_tracks 2 "
        "--frames 200-300 "
        "--tracking.similarity instance "
        "--tracking.match hungarian "
        "--tracking.track_window 5 "
        "--tracking.kf_init_frame_count 10 "
        "--tracking.kf_node_indices 0,1 "
        f"-o {tmpdir}/{tracker_name}.slp "
        f"{centered_pair_predictions_slp_path}"
    )
    inference_cli(cli.split(" "))
@pytest.mark.parametrize("similarity", similarity_args)
@pytest.mark.parametrize("match", match_args)
def test_kalman_tracker(
    tmpdir, centered_pair_predictions_slp_path, tracker_name, similarity, match
):

    if tracker_name == "flow" or tracker_name == "flowmaxtracks":
        # Expecting ValueError for "flow" or "flowmaxtracks" due to Kalman filter requiring a simple tracker
        with pytest.raises(
            ValueError,
            match="Kalman filter requires simple tracker for initial tracking.",
        ):
            cli = (
                f"--tracking.tracker {tracker_name} "
                "--tracking.max_tracking 1 --tracking.max_tracks 2 "
                f"--tracking.similarity {similarity} "
                f"--tracking.match {match} "
                "--tracking.track_window 5 "
                "--tracking.kf_init_frame_count 10 "
                "--tracking.kf_node_indices 0,1 "
                f"-o {tmpdir}/{tracker_name}.slp "
                f"{centered_pair_predictions_slp_path}"
            )
            inference_cli(cli.split(" "))
    else:
        # For simple or simplemaxtracks, continue with other tests
        # Check for ValueError when similarity is "normalized_instance"
        if similarity == "normalized_instance":
            with pytest.raises(
                ValueError,
                match="Kalman filter does not support normalized_instance_similarity.",
            ):
                cli = (
                    f"--tracking.tracker {tracker_name} "
                    "--tracking.max_tracking 1 --tracking.max_tracks 2 "
                    f"--tracking.similarity {similarity} "
                    f"--tracking.match {match} "
                    "--tracking.track_window 5 "
                    "--tracking.kf_init_frame_count 10 "
                    "--tracking.kf_node_indices 0,1 "
                    f"-o {tmpdir}/{tracker_name}.slp "
                    f"{centered_pair_predictions_slp_path}"
                )
                inference_cli(cli.split(" "))
            return

        # Check for ValueError when kf_node_indices is None, which is the default
        with pytest.raises(
            ValueError,
            match="Kalman filter requires node indices for instance tracking.",
        ):
            cli = (
                f"--tracking.tracker {tracker_name} "
                "--tracking.max_tracking 1 --tracking.max_tracks 2 "
                f"--tracking.similarity {similarity} "
                f"--tracking.match {match} "
                "--tracking.track_window 5 "
                "--tracking.kf_init_frame_count 10 "
                f"-o {tmpdir}/{tracker_name}.slp "
                f"{centered_pair_predictions_slp_path}"
            )
            inference_cli(cli.split(" "))

        # Test for missing max_tracks and target_instance_count with kf_init_frame_count
        with pytest.raises(
            ValueError,
            match="Kalman filter requires max tracks or target instance count.",
        ):
            cli = (
                f"--tracking.tracker {tracker_name} "
                f"--tracking.similarity {similarity} "
                f"--tracking.match {match} "
                "--tracking.track_window 5 "
                "--tracking.kf_init_frame_count 10 "
                "--tracking.kf_node_indices 0,1 "
                f"-o {tmpdir}/{tracker_name}.slp "
                f"{centered_pair_predictions_slp_path}"
            )
            inference_cli(cli.split(" "))

        # Test with target_instance_count and without max_tracks
        cli = (
            f"--tracking.tracker {tracker_name} "
            f"--tracking.similarity {similarity} "
            f"--tracking.match {match} "
            "--tracking.track_window 5 "
            "--tracking.kf_init_frame_count 10 "
            "--tracking.kf_node_indices 0,1 "
            "--tracking.target_instance_count 2 "
            f"-o {tmpdir}/{tracker_name}_target_instance_count.slp "
            f"{centered_pair_predictions_slp_path}"
        )
        inference_cli(cli.split(" "))

        labels = sleap.load_file(f"{tmpdir}/{tracker_name}_target_instance_count.slp")
        assert len(labels.tracks) == 2

        # Test with target_instance_count and with max_tracks
        cli = (
            f"--tracking.tracker {tracker_name} "
            "--tracking.max_tracking 1 --tracking.max_tracks 2 "
            f"--tracking.similarity {similarity} "
            f"--tracking.match {match} "
            "--tracking.track_window 5 "
            "--tracking.kf_init_frame_count 10 "
            "--tracking.kf_node_indices 0,1 "
            "--tracking.target_instance_count 2 "
            f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp "
            f"{centered_pair_predictions_slp_path}"
        )
        inference_cli(cli.split(" "))

        labels = sleap.load_file(f"{tmpdir}/{tracker_name}.slp")
        assert len(labels.tracks) == 2
        labels = sleap.load_file(
            f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp"
        )
        assert len(labels.tracks) == 2

        # Test with "--tracking.pre_cull_iou_threshold", "0.8"
        cli = (
            f"--tracking.tracker {tracker_name} "
            "--tracking.max_tracking 1 --tracking.max_tracks 2 "
            f"--tracking.similarity {similarity} "
            f"--tracking.match {match} "
            "--tracking.track_window 5 "
            "--tracking.kf_init_frame_count 10 "
            "--tracking.kf_node_indices 0,1 "
            "--tracking.target_instance_count 2 "
            "--tracking.pre_cull_iou_threshold 0.8 "
            f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp "
            f"{centered_pair_predictions_slp_path}"
        )
        inference_cli(cli.split(" "))

        labels = sleap.load_file(
            f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp"
        )
        assert len(labels.tracks) == 2

        # Test with "--tracking.pre_cull_to_target", "1"
        cli = (
            f"--tracking.tracker {tracker_name} "
            "--tracking.max_tracking 1 --tracking.max_tracks 2 "
            f"--tracking.similarity {similarity} "
            f"--tracking.match {match} "
            "--tracking.track_window 5 "
            "--tracking.kf_init_frame_count 10 "
            "--tracking.kf_node_indices 0,1 "
            "--tracking.target_instance_count 2 "
            "--tracking.pre_cull_to_target 1 "
            f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp "
            f"{centered_pair_predictions_slp_path}"
        )
        inference_cli(cli.split(" "))
        labels = sleap.load_file(
            f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp"
        )
        assert len(labels.tracks) == 2


def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
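
For reference, the parametrization above expands to 4 × 5 × 2 = 40 tracker/similarity/match combinations. The minimal sketch below (illustrative only, not part of the changed file) enumerates them and notes which combinations the test expects to raise, based on the pytest.raises patterns in the diff:

from itertools import product

trackers = ["simple", "simplemaxtracks", "flow", "flowmaxtracks"]
similarities = ["instance", "normalized_instance", "object_keypoint", "centroid", "iou"]
matches = ["hungarian", "greedy"]

for tracker, similarity, match in product(trackers, similarities, matches):
    if tracker in ("flow", "flowmaxtracks"):
        # Kalman filter requires a simple tracker for initial tracking.
        expected = "ValueError (flow trackers unsupported)"
    elif similarity == "normalized_instance":
        # Kalman filter does not support normalized_instance_similarity.
        expected = "ValueError (similarity unsupported)"
    else:
        expected = "runs; also exercises the missing-argument error paths"
    print(f"{tracker:15s} {similarity:20s} {match:10s} -> {expected}")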
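
To run just this test matrix locally, an invocation along these lines should work, assuming a SLEAP development environment with pytest and the repository's test fixtures installed:

import subprocess

# Run only the new parametrized Kalman-args test; -q keeps the 40-case output compact.
subprocess.run(
    ["pytest", "tests/nn/test_tracking_integration.py::test_kalman_tracker", "-q"],
    check=True,
)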
