From 271126921e611a3a2552422705a611a7732e68c0 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 4 May 2023 15:53:34 -0700 Subject: [PATCH 1/6] Add tracking_score attribute to `Instance`s --- sleap/instance.py | 6 +++++- tests/io/test_dataset.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index c14038552..9a2af39a5 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -369,6 +369,10 @@ class Instance: # The underlying Point array type that this instances point array should be. _point_array_type = PointArray + tracking_score: Optional[float] = attr.ib( + default=None, converter=lambda x: None if x is None else float(x) + ) + @from_predicted.validator def _validate_from_predicted_( self, attribute, from_predicted: Optional["PredictedInstance"] @@ -1309,7 +1313,7 @@ def __len__(self) -> int: """Return number of instances associated with frame.""" return len(self.instances) - def __getitem__(self, index) -> Instance: + def __getitem__(self, index) -> Union[Instance, PredictedInstance]: """Return instance (retrieved by index).""" return self.instances.__getitem__(index) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index a6fa7fdd7..fac0040b1 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -8,7 +8,6 @@ from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track from sleap.io.video import Video, MediaVideo from sleap.io.dataset import Labels, load_file -from sleap.io.legacy import load_labels_json_old from sleap.io.format.ndx_pose import NDXPoseAdaptor from sleap.io.format import filehandle from sleap.gui.suggestions import VideoFrameSuggestions, SuggestionFrame @@ -746,6 +745,36 @@ def test_dont_unify_skeletons(): labels.to_dict() +def test_instance_cattr(centered_pair_predictions: Labels, tmpdir: str): + labels = centered_pair_predictions + lf = labels.labeled_frames[0] + pred_inst: PredictedInstance = lf[0] + skeleton = pred_inst.skeleton + track = pred_inst.track + + # Initialize Instance + instance = Instance.from_pointsarray( + points=pred_inst.numpy(), skeleton=skeleton, track=track + ) + instance.from_predicted = pred_inst + assert instance.tracking_score is None + labels.add_instance(lf, instance) + + pred_tracking_score = pred_inst.tracking_score + inst_tracking_score = instance.tracking_score + + filename = str(PurePath(tmpdir, "labels.slp")) + labels.save(filename) + + labels_loaded = sleap.load_file(filename) + lf_loaded = labels_loaded.labeled_frames[0] + pred_inst_loaded = lf_loaded.predicted_instances[0] + instance_loaded = lf_loaded.user_instances[0] + + assert pred_inst_loaded.tracking_score == pred_tracking_score + assert instance_loaded.tracking_score == inst_tracking_score + + def test_instance_access(): labels = Labels() From 91466b64183fcff442b083216f74c733b2facc21 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 4 May 2023 16:04:46 -0700 Subject: [PATCH 2/6] Update adaptors to write tracking scores for `Instance`s --- sleap/info/write_tracking_h5.py | 2 +- sleap/io/format/hdf5.py | 4 ++-- sleap/io/format/nix.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 8bd583230..613b360c0 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -176,10 +176,10 @@ def get_occupancy_and_points_matrices( occupancy_matrix[track_i, frame_i] = 1 locations_matrix[frame_i, ..., track_i] = inst.numpy() + tracking_scores[frame_i, ..., track_i] = inst.tracking_score or np.nan if type(inst) == PredictedInstance: point_scores[frame_i, ..., track_i] = inst.scores instance_scores[frame_i, ..., track_i] = inst.score - tracking_scores[frame_i, ..., track_i] = inst.tracking_score return ( occupancy_matrix, diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py index 353f88e3a..e43a29a91 100644 --- a/sleap/io/format/hdf5.py +++ b/sleap/io/format/hdf5.py @@ -438,11 +438,9 @@ def append_unique(old, new): if instance_type is PredictedInstance: score = instance.score pid = pred_point_id + pred_point_id_offset - tracking_score = instance.tracking_score else: score = np.nan pid = point_id + point_id_offset - tracking_score = np.nan # Keep track of any from_predicted instance links, we will # insert the correct instance_id in the dataset after we are @@ -451,6 +449,8 @@ def append_unique(old, new): instances_with_from_predicted.append(instance_id) instances_from_predicted.append(instance.from_predicted) + tracking_score = instance.tracking_score or np.nan + # Copy all the data instances[instance_id] = ( instance_id + instance_id_offset, diff --git a/sleap/io/format/nix.py b/sleap/io/format/nix.py index 4c39ec8b6..2f61cf999 100644 --- a/sleap/io/format/nix.py +++ b/sleap/io/format/nix.py @@ -239,13 +239,12 @@ def chunked_write( positions[index, :, node_map[m]] = np.array([np.nan, np.nan]) centroids[index, :] = inst.centroid + trackscore[index] = inst.tracking_score or 0.0 if hasattr(inst, "score"): instscore[index] = inst.score - trackscore[index] = inst.tracking_score pointscore[index, :] = inst.scores else: instscore[index] = 0.0 - trackscore[index] = 0.0 pointscore[index, :] = dflt_pointscore frameid_array[start:end] = indices[: end - start] From e4fc1b05cba6a3ea16bdf0d998cc99981a7945e4 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 4 May 2023 17:05:17 -0700 Subject: [PATCH 3/6] Set default `Instance.tracking_score` = 0.0 --- sleap/info/write_tracking_h5.py | 2 +- sleap/instance.py | 10 ++++------ sleap/io/format/hdf5.py | 2 +- sleap/io/format/nix.py | 2 +- tests/io/test_dataset.py | 14 ++++++++------ 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 613b360c0..d488b351e 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -176,7 +176,7 @@ def get_occupancy_and_points_matrices( occupancy_matrix[track_i, frame_i] = 1 locations_matrix[frame_i, ..., track_i] = inst.numpy() - tracking_scores[frame_i, ..., track_i] = inst.tracking_score or np.nan + tracking_scores[frame_i, ..., track_i] = inst.tracking_score if type(inst) == PredictedInstance: point_scores[frame_i, ..., track_i] = inst.scores instance_scores[frame_i, ..., track_i] = inst.score diff --git a/sleap/instance.py b/sleap/instance.py index 9a2af39a5..ea52c076e 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -357,6 +357,7 @@ class Instance: frame: A back reference to the :class:`LabeledFrame` that this :class:`Instance` belongs to. This field is set when instances are added to :class:`LabeledFrame` objects. + tracking_score: The instance-level track matching score. """ skeleton: Skeleton = attr.ib() @@ -369,9 +370,7 @@ class Instance: # The underlying Point array type that this instances point array should be. _point_array_type = PointArray - tracking_score: Optional[float] = attr.ib( - default=None, converter=lambda x: None if x is None else float(x) - ) + tracking_score: float = attr.ib(default=0.0, converter=float) @from_predicted.validator def _validate_from_predicted_( @@ -666,7 +665,8 @@ def __repr__(self) -> str: f"video={self.video}, " f"frame_idx={self.frame_idx}, " f"points=[{pts}], " - f"track={self.track}" + f"track={self.track}, " + f"tracking_score={self.tracking_score:.2f}" ")" ) @@ -1002,11 +1002,9 @@ class PredictedInstance(Instance): Args: score: The instance-level grouping prediction score. - tracking_score: The instance-level track matching score. """ score: float = attr.ib(default=0.0, converter=float) - tracking_score: float = attr.ib(default=0.0, converter=float) # The underlying Point array type that this instances point array should be. _point_array_type = PredictedPointArray diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py index e43a29a91..4b9462083 100644 --- a/sleap/io/format/hdf5.py +++ b/sleap/io/format/hdf5.py @@ -449,7 +449,7 @@ def append_unique(old, new): instances_with_from_predicted.append(instance_id) instances_from_predicted.append(instance.from_predicted) - tracking_score = instance.tracking_score or np.nan + tracking_score = instance.tracking_score # Copy all the data instances[instance_id] = ( diff --git a/sleap/io/format/nix.py b/sleap/io/format/nix.py index 2f61cf999..fa2830464 100644 --- a/sleap/io/format/nix.py +++ b/sleap/io/format/nix.py @@ -239,7 +239,7 @@ def chunked_write( positions[index, :, node_map[m]] = np.array([np.nan, np.nan]) centroids[index, :] = inst.centroid - trackscore[index] = inst.tracking_score or 0.0 + trackscore[index] = inst.tracking_score if hasattr(inst, "score"): instscore[index] = inst.score pointscore[index, :] = inst.scores diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index fac0040b1..8a3f623ed 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -757,12 +757,10 @@ def test_instance_cattr(centered_pair_predictions: Labels, tmpdir: str): points=pred_inst.numpy(), skeleton=skeleton, track=track ) instance.from_predicted = pred_inst - assert instance.tracking_score is None + assert instance.tracking_score == 0.0 + instance.tracking_score = 0.5 labels.add_instance(lf, instance) - pred_tracking_score = pred_inst.tracking_score - inst_tracking_score = instance.tracking_score - filename = str(PurePath(tmpdir, "labels.slp")) labels.save(filename) @@ -771,8 +769,8 @@ def test_instance_cattr(centered_pair_predictions: Labels, tmpdir: str): pred_inst_loaded = lf_loaded.predicted_instances[0] instance_loaded = lf_loaded.user_instances[0] - assert pred_inst_loaded.tracking_score == pred_tracking_score - assert instance_loaded.tracking_score == inst_tracking_score + assert pred_inst_loaded.tracking_score == pred_inst.tracking_score + assert instance_loaded.tracking_score == instance.tracking_score def test_instance_access(): @@ -1557,3 +1555,7 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir): # Read from NWB file read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename)) assert_read_labels_match(centered_pair_predictions, read_labels) + + +if __name__ == "__main__": + pytest.main([f"{__file__}::test_instance_cattr"]) From eb8fda48d9c51a22c18814cdec27399756a13b75 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 4 May 2023 17:36:54 -0700 Subject: [PATCH 4/6] Bump `FORMAT_ID` and read/write `Instance.tracking_score` --- sleap/io/format/hdf5.py | 8 ++++++-- tests/io/test_dataset.py | 12 +++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py index 4b9462083..75c9d3347 100644 --- a/sleap/io/format/hdf5.py +++ b/sleap/io/format/hdf5.py @@ -28,11 +28,12 @@ class LabelsV1Adaptor(format.adaptor.Adaptor): - FORMAT_ID = 1.2 + FORMAT_ID = 1.3 # 1.0 points with gridline coordinates, top left corner at (0, 0) # 1.1 points with midpixel coordinates, top left corner at (-0.5, -0.5) - # 1.2 adds track score to read and write functions + # 1.2 adds tracking score for PredictedInstance to read and write functions + # 1.3 adds tracking score for Instance to read and write functions @property def handles(self): @@ -180,6 +181,9 @@ def read( skeleton=skeleton, track=track, points=points[i["point_id_start"] : i["point_id_end"]], + tracking_score=i["tracking_score"] + if (format_id is not None and format_id >= 1.3) + else 0.0, ) else: # PredictedInstance instance = PredictedInstance( diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 8a3f623ed..2f7fdef3a 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -758,9 +758,11 @@ def test_instance_cattr(centered_pair_predictions: Labels, tmpdir: str): ) instance.from_predicted = pred_inst assert instance.tracking_score == 0.0 - instance.tracking_score = 0.5 labels.add_instance(lf, instance) + instance.tracking_score = 0.5 + pred_inst.tracking_score = 0.7 + filename = str(PurePath(tmpdir, "labels.slp")) labels.save(filename) @@ -769,8 +771,8 @@ def test_instance_cattr(centered_pair_predictions: Labels, tmpdir: str): pred_inst_loaded = lf_loaded.predicted_instances[0] instance_loaded = lf_loaded.user_instances[0] - assert pred_inst_loaded.tracking_score == pred_inst.tracking_score - assert instance_loaded.tracking_score == instance.tracking_score + assert round(pred_inst_loaded.tracking_score, 1) == pred_inst.tracking_score + assert round(instance_loaded.tracking_score, 1) == instance.tracking_score def test_instance_access(): @@ -1555,7 +1557,3 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir): # Read from NWB file read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename)) assert_read_labels_match(centered_pair_predictions, read_labels) - - -if __name__ == "__main__": - pytest.main([f"{__file__}::test_instance_cattr"]) From d88264075f7a5b6317afc66f260845d0232d7c5c Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 24 May 2023 15:46:04 -0700 Subject: [PATCH 5/6] Prefer user instances when tracking --- sleap/nn/tracking.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index b861c359f..fedc3597d 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1189,7 +1189,12 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances) + # Prefer user instances over predicted instances + untracked_instances = ( + lf.user_instances if lf.has_user_instances else lf.predicted_instances + ) + track_args = {"untracked_instances": untracked_instances} + if tracker.uses_image: track_args["img"] = lf.video[lf.frame_idx] else: From 1ea047db60ea4ad28bd5338d40f2b4b387fefe4e Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 24 May 2023 17:26:23 -0700 Subject: [PATCH 6/6] Add back instances that were not tracked --- sleap/nn/tracking.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index fedc3597d..eacb55e6a 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1190,20 +1190,26 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele inst.track = None # Prefer user instances over predicted instances - untracked_instances = ( - lf.user_instances if lf.has_user_instances else lf.predicted_instances - ) - track_args = {"untracked_instances": untracked_instances} + instances = [] + if lf.has_user_instances: + instances_to_track = lf.user_instances + if lf.has_predicted_instances: + instances = lf.predicted_instances + else: + instances_to_track = lf.predicted_instances + + track_args = {"untracked_instances": instances_to_track} if tracker.uses_image: track_args["img"] = lf.video[lf.frame_idx] else: track_args["img"] = None + instances.extend(tracker.track(**track_args)) new_lf = LabeledFrame( frame_idx=lf.frame_idx, video=lf.video, - instances=tracker.track(**track_args), + instances=instances, ) new_lfs.append(new_lf)