Skip to content

Commit

Permalink
Add tracking score to seekbar header
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Dec 16, 2024
1 parent c5b4fff commit d979342
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 8 deletions.
6 changes: 5 additions & 1 deletion sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,8 @@ def new_instance_menu_action():
"Point Displacement (max)",
"Primary Point Displacement (sum)",
"Primary Point Displacement (max)",
"Tracking Score (mean)",
"Tracking Score (min)",
"Instance Score (sum)",
"Instance Score (min)",
"Point Score (sum)",
Expand Down Expand Up @@ -1406,6 +1408,8 @@ def _set_seekbar_header(self, graph_name: str):
"Point Displacement (max)": data_obj.get_point_displacement_series,
"Primary Point Displacement (sum)": data_obj.get_primary_point_displacement_series,
"Primary Point Displacement (max)": data_obj.get_primary_point_displacement_series,
"Tracking Score (mean)": data_obj.get_tracking_score_series,
"Tracking Score (min)": data_obj.get_tracking_score_series,
"Instance Score (sum)": data_obj.get_instance_score_series,
"Instance Score (min)": data_obj.get_instance_score_series,
"Point Score (sum)": data_obj.get_point_score_series,
Expand All @@ -1419,7 +1423,7 @@ def _set_seekbar_header(self, graph_name: str):
else:
if graph_name in header_functions:
kwargs = dict(video=self.state["video"])
reduction_name = re.search("\\((sum|max|min)\\)", graph_name)
reduction_name = re.search("\\((sum|max|min|mean)\\)", graph_name)
if reduction_name is not None:
kwargs["reduction"] = reduction_name.group(1)
series = header_functions[graph_name](**kwargs)
Expand Down
44 changes: 38 additions & 6 deletions sleap/info/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class StatisticSeries:
are frame index and value are some numerical value for the frame.
Args:
labels: The :class:`Labels` for which to calculate series.
labels: The `Labels` for which to calculate series.
"""

labels: Labels
Expand All @@ -41,7 +41,7 @@ def get_point_score_series(
"""Get series with statistic of point scores in each frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* sum
* min
Expand All @@ -67,7 +67,7 @@ def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]:
"""Get series with statistic of instance scores in each frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* sum
* min
Expand All @@ -93,7 +93,7 @@ def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, flo
same track) from the closest earlier labeled frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to point scores:
* sum
* mean
Expand Down Expand Up @@ -121,7 +121,7 @@ def get_primary_point_displacement_series(
Get sum of displacement for single node of each instance per frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to point scores:
* sum
* mean
Expand Down Expand Up @@ -226,7 +226,7 @@ def _calculate_frame_velocity(
Calculate total point displacement between two given frames.
Args:
lf: The :class:`LabeledFrame` for which we want velocity
lf: The `LabeledFrame` for which we want velocity
last_lf: The frame from which to calculate displacement.
reduce_function: Numpy function (e.g., np.sum, np.nanmean)
is applied to *point* displacement, and then those
Expand All @@ -246,3 +246,35 @@ def _calculate_frame_velocity(
inst_dist = reduce_function(point_dist)
val += inst_dist if not np.isnan(inst_dist) else 0
return val

def get_tracking_score_series(
self, video: Video, reduction: str = "min"
) -> Dict[int, float]:
"""Get series with statistic of tracking scores in each frame.
Args:
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* mean
* min
Returns:
The series dictionary (see class docs for details)
"""
reduce_fn = {
"min": np.nanmin,
"mean": np.nanmean,
}[reduction]

series = dict()

for lf in self.labels.find(video):
vals = [
inst.tracking_score for inst in lf if hasattr(inst, "tracking_score")
]
if vals:
val = reduce_fn(vals)
if not np.isnan(val):
series[lf.frame_idx] = val

return series
15 changes: 14 additions & 1 deletion tests/info/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ def test_frame_statistics(simple_predictions):

x = stats.get_point_displacement_series(video, "max")
assert len(x) == 2
assert len(x) == 2
assert x[0] == 0
assert x[1] == 18.0


def test_get_tracking_score_series(min_tracks_2node_predictions):

stats = StatisticSeries(min_tracks_2node_predictions)
x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "min")
assert len(x) == 1500
assert x[0] == 0.9999966621398926
assert x[1000] == 0.9998022317886353

x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "mean")
assert len(x) == 1500
assert x[0] == 0.9999983310699463
assert x[1000] == 0.9999011158943176

0 comments on commit d979342

Please sign in to comment.