From ebd2e1ec8f062efaf2588884bb46e92a8030ddcd Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 9 Dec 2022 12:35:58 -0800 Subject: [PATCH] Don't create instances during inference if no points were found (#1073) * Don't create instances during inference if no points were found * Add points check for all predictors * Fix single instance predictor logic and test * Add tests for all predictors Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- sleap/nn/inference.py | 27 +++++++++----- tests/nn/test_inference.py | 75 +++++++++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 11 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index b5baca687..dc7962c63 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -762,7 +762,7 @@ def call( tf.int64 ) # (batch_size, n_centroids, 1, 1, 2) dists = a - b # (batch_size, n_centroids, n_insts, n_nodes, 2) - dists = tf.sqrt(tf.reduce_sum(dists ** 2, axis=-1)) # reduce over xy + dists = tf.sqrt(tf.reduce_sum(tf.math.square(dists), axis=-1)) # reduce over xy dists = tf.reduce_min(dists, axis=-1) # reduce over nodes dists = dists.to_tensor( tf.cast(np.NaN, tf.float32) @@ -1453,14 +1453,17 @@ def _make_labeled_frames_from_generator( ex["instance_peak_vals"], ): # Loop over instances. - predicted_instances = [ - sleap.instance.PredictedInstance.from_arrays( - points=points[0], - point_confidences=confidences[0], - instance_score=np.nansum(confidences[0]), - skeleton=skeleton, - ) - ] + if np.isnan(points[0]).all(): + predicted_instances = [] + else: + predicted_instances = [ + sleap.instance.PredictedInstance.from_arrays( + points=points[0], + point_confidences=confidences[0], + instance_score=np.nansum(confidences[0]), + skeleton=skeleton, + ) + ] predicted_frames.append( sleap.LabeledFrame( @@ -2434,6 +2437,9 @@ def _make_labeled_frames_from_generator( # Loop over instances. predicted_instances = [] for pts, confs, score in zip(points, confidences, scores): + if np.isnan(pts).all(): + continue + predicted_instances.append( sleap.instance.PredictedInstance.from_arrays( points=pts, @@ -2999,6 +3005,9 @@ def _make_labeled_frames_from_generator( # Loop over instances. predicted_instances = [] for pts, confs, score in zip(points, confidences, scores): + if np.isnan(pts).all(): + continue + predicted_instances.append( sleap.instance.PredictedInstance.from_arrays( points=pts, diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 54050ceb8..a1fbb7353 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -573,14 +573,25 @@ def test_single_instance_predictor( def test_single_instance_predictor_high_peak_thresh( min_labels_robot, min_single_instance_robot_model_path ): + predictor = SingleInstancePredictor.from_trained_models( + min_single_instance_robot_model_path, peak_threshold=0 + ) + predictor.verbosity = "none" + labels_pr = predictor.predict(min_labels_robot) + assert len(labels_pr) == 2 + assert len(labels_pr[0]) == 1 + assert labels_pr[0][0].n_visible_points == 2 + assert len(labels_pr[1]) == 1 + assert labels_pr[1][0].n_visible_points == 2 + predictor = SingleInstancePredictor.from_trained_models( min_single_instance_robot_model_path, peak_threshold=1.5 ) predictor.verbosity = "none" labels_pr = predictor.predict(min_labels_robot) assert len(labels_pr) == 2 - assert labels_pr[0][0].n_visible_points == 0 - assert labels_pr[1][0].n_visible_points == 0 + assert len(labels_pr[0]) == 0 + assert len(labels_pr[1]) == 0 def test_topdown_predictor_centroid(min_labels, min_centroid_model_path): @@ -612,6 +623,16 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path): assert len(labels_pr[0].instances) == 2 +def test_topdown_predictor_centroid_high_threshold(min_labels, min_centroid_model_path): + predictor = TopDownPredictor.from_trained_models( + centroid_model_path=min_centroid_model_path, peak_threshold=1.5 + ) + predictor.verbosity = "none" + labels_pr = predictor.predict(min_labels) + assert len(labels_pr) == 1 + assert len(labels_pr[0].instances) == 0 + + def test_topdown_predictor_centered_instance( min_labels, min_centered_instance_model_path ): @@ -636,6 +657,18 @@ def test_topdown_predictor_centered_instance( assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5) +def test_topdown_predictor_centered_instance_high_threshold( + min_labels, min_centered_instance_model_path +): + predictor = TopDownPredictor.from_trained_models( + confmap_model_path=min_centered_instance_model_path, peak_threshold=1.5 + ) + predictor.verbosity = "none" + labels_pr = predictor.predict(min_labels) + assert len(labels_pr) == 1 + assert len(labels_pr[0].instances) == 0 + + def test_bottomup_predictor(min_labels, min_bottomup_model_path): predictor = BottomUpPredictor.from_trained_models( model_path=min_bottomup_model_path @@ -666,6 +699,16 @@ def test_bottomup_predictor(min_labels, min_bottomup_model_path): assert len(labels_pr[0]) == 0 +def test_bottomup_predictor_high_peak_thresh(min_labels, min_bottomup_model_path): + predictor = BottomUpPredictor.from_trained_models( + model_path=min_bottomup_model_path, peak_threshold=1.5 + ) + predictor.verbosity = "none" + labels_pr = predictor.predict(min_labels) + assert len(labels_pr) == 1 + assert len(labels_pr[0].instances) == 0 + + def test_bottomup_multiclass_predictor( min_tracks_2node_labels, min_bottomup_multiclass_model_path ): @@ -698,6 +741,20 @@ def test_bottomup_multiclass_predictor( labels_pr[0][1].track.name == "male" +def test_bottomup_multiclass_predictor_high_threshold( + min_tracks_2node_labels, min_bottomup_multiclass_model_path +): + labels_gt = sleap.Labels(min_tracks_2node_labels[[0]]) + predictor = BottomUpMultiClassPredictor.from_trained_models( + model_path=min_bottomup_multiclass_model_path, + peak_threshold=1.5, + integral_refinement=False, + ) + labels_pr = predictor.predict(labels_gt) + assert len(labels_pr) == 1 + assert len(labels_pr[0].instances) == 0 + + def test_topdown_multiclass_predictor( min_tracks_2node_labels, min_topdown_multiclass_model_path ): @@ -724,6 +781,20 @@ def test_topdown_multiclass_predictor( ) +def test_topdown_multiclass_predictor_high_threshold( + min_tracks_2node_labels, min_topdown_multiclass_model_path +): + labels_gt = sleap.Labels(min_tracks_2node_labels[[0]]) + predictor = TopDownMultiClassPredictor.from_trained_models( + confmap_model_path=min_topdown_multiclass_model_path, + peak_threshold=1.5, + integral_refinement=False, + ) + labels_pr = predictor.predict(labels_gt) + assert len(labels_pr) == 1 + assert len(labels_pr[0].instances) == 0 + + def test_load_model( min_single_instance_robot_model_path, min_centroid_model_path,