Skip to content

Commit

Permalink
Add function to get error per view
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Nov 27, 2023
1 parent bd0f32d commit 7ee80b0
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3772,6 +3772,7 @@ def _calculate_reprojection_error(
session: RecordingSession,
instances: Dict[int, Dict[Camcorder, List[Instance]]],
per_instance: bool = False,
per_view: bool = False,
) -> Union[
Dict[int, float], Dict[int, Dict[Camcorder, List[Tuple[Instance, float]]]]
]:
Expand All @@ -3784,6 +3785,9 @@ def _calculate_reprojection_error(
per_instance: If True, then return a dict with frame identifier keys and
values of another inner dict with `Camcorder` keys and
`List[Tuple[Instance, float]]` values.
per_view: If True, then return a dict with frame identifier keys and values
of another inner dict with `Camcorder` keys and
`Tuple[Tuple[str, str], float]` values. If per_instance is True, then that takes precendence.
Returns:
Dict with frame identifier keys (not the frame index) and values of another
Expand All @@ -3801,9 +3805,10 @@ def _calculate_reprojection_error(
session=session, instances=instances
)
for frame_id, instances_in_frame in instances_and_coords.items():
frame_error = {} if per_instance else 0
frame_error = {} if per_instance or per_view else 0
for cam, instances_in_view in instances_in_frame.items():
# Compare instance coordinates here
instance_ids = []
view_error = [] if per_instance else 0
for inst, inst_coords in instances_in_view:
node_errors = np.nan_to_num(inst.numpy() - inst_coords)
Expand All @@ -3814,8 +3819,13 @@ def _calculate_reprojection_error(
else:
view_error += instance_error

inst_id = inst.track if inst.track is not None else "None"
instance_ids.append(inst_id)

if per_instance:
frame_error[cam] = view_error
elif per_view:
frame_error[cam] = (tuple(instance_ids), view_error)
else:
frame_error += view_error

Expand All @@ -3837,6 +3847,18 @@ def calculate_error_per_instance(

return reprojection_error_per_instance

@staticmethod
def calculate_error_per_view(
session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]]
) -> Dict[int, float]:
"""Calculate reprojection error per instance."""

reprojection_error_per_view = TriangulateSession._calculate_reprojection_error(
session=session, instances=instances, per_view=True
)

return reprojection_error_per_view

@staticmethod
def calculate_error_per_frame(
session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]]
Expand Down

0 comments on commit 7ee80b0

Please sign in to comment.