Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tracker module #70

Merged
merged 29 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
182d1fe
Refactor preprocessing config
gitttt-1234 Jul 24, 2024
b5e2732
Merge train and val data configs
gitttt-1234 Jul 24, 2024
59082bf
Remove pipeline name
gitttt-1234 Jul 24, 2024
d42c206
Modify backbone_config
gitttt-1234 Jul 25, 2024
0ae3b7d
Modify ckpts
gitttt-1234 Jul 25, 2024
402d8e7
Fix inference tests
gitttt-1234 Jul 25, 2024
34b2457
Fix device for inference
gitttt-1234 Jul 26, 2024
624f1e3
Fix scale in inference
gitttt-1234 Jul 27, 2024
9a32aba
Fix Predictor
gitttt-1234 Jul 29, 2024
8f1e105
Modify `bottom_up` to `bottomup`
gitttt-1234 Jul 29, 2024
2b7e7b3
Fix bottomup inference
gitttt-1234 Jul 29, 2024
afe6efe
Fix scale in augmentation
gitttt-1234 Jul 30, 2024
558a292
Add tracker
gitttt-1234 Jul 31, 2024
e4ecb40
Fix tracker queue
gitttt-1234 Aug 1, 2024
caabf97
gerge branch 'divya/refactor-aug-config' of https://github.com/talmol…
gitttt-1234 Aug 6, 2024
1c78bcd
Add local queues
gitttt-1234 Aug 6, 2024
19f904c
mergeMerge branch 'divya/refactor-aug-config' of https://github.com/t…
gitttt-1234 Aug 7, 2024
cd3cfb2
Modify local queues\
gitttt-1234 Aug 7, 2024
24aa072
Add features
gitttt-1234 Aug 8, 2024
bc2472e
Add optical flow
gitttt-1234 Aug 13, 2024
87d6516
merge with baseMerge branch 'divya/tracker' of https://github.com/tal…
gitttt-1234 Aug 14, 2024
c81c3ae
Add Optical flow
gitttt-1234 Aug 15, 2024
2645683
Add tracking score
gitttt-1234 Aug 16, 2024
84df63a
Refactor candidate update
gitttt-1234 Aug 20, 2024
f427281
Integrate with Predictors
gitttt-1234 Aug 21, 2024
aa1fe5a
Fix lint
gitttt-1234 Aug 23, 2024
6c9262c
Fix tracks
gitttt-1234 Aug 23, 2024
5e77de9
Resume training and automatically compute crop size for TopDownConfma…
gitttt-1234 Sep 11, 2024
58a9715
Merge branch 'divya/fix-topdown-aug' into divya/tracker
gitttt-1234 Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 111 additions & 106 deletions sleap_nn/inference/predictors.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sleap_nn/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tracker related modules."""
143 changes: 143 additions & 0 deletions sleap_nn/tracking/candidates/fixed_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Module to generate Fixed window candidates."""

from typing import Optional, List, Deque, Union
from sleap_nn.tracking.track_instance import TrackInstances, TrackedInstanceFeature
import sleap_io as sio
from collections import deque
import numpy as np


class FixedWindowCandidates:
"""Fixed-window method for candidate generation.

This module handles `tracker_queue` using the fixed window method, where track assignments
are determined based on the last `window_size` frames.

Attributes:
window_size: Number of previous frames to compare the current predicted instance with.
Default: 5.
instance_score_threshold: Instance score threshold for creating new tracks.
Default: 0.0.
tracker_queue: Deque object that stores the past `window_size` tracked instances.
current_tracks: List of track IDs that are being tracked.
"""

def __init__(self, window_size: int = 5, instance_score_threshold: float = 0.0):
"""Initialize class variables."""
self.window_size = window_size
self.instance_score_threshold = instance_score_threshold
self.tracker_queue = deque(maxlen=self.window_size)
self.current_tracks = []

def get_track_instances(
self,
feature_list: List[Union[np.array]],
untracked_instances: List[sio.PredictedInstance],
frame_idx: int,
image: np.array,
) -> TrackInstances:
"""Return an instance of `TrackInstances` object for the `untracked_instances`."""
track_instance = TrackInstances(
src_instances=untracked_instances,
track_ids=[None] * len(untracked_instances),
tracking_scores=[None] * len(untracked_instances),
features=feature_list,
instance_scores=[instance.score for instance in untracked_instances],
frame_idx=frame_idx,
image=image,
)
return track_instance

def get_features_from_track_id(
self, track_id: int, candidates_list: Optional[Deque] = None
) -> List[TrackedInstanceFeature]:
"""Return list of `TrackedInstanceFeature` objects for instances in tracker queue with the given `track_id`.

Note: If `candidates_list` is `None`, then features of all the instances in the
tracker queue are returned by default. Else, only the features from the given
candidates_list are returned.
"""
output = []
tracked_candidates = (
candidates_list if candidates_list is not None else self.tracker_queue
)
for t in tracked_candidates:
if track_id in t.track_ids:
track_idx = t.track_ids.index(track_id)
tracked_instance_feature = TrackedInstanceFeature(
feature=t.features[track_idx],
src_predicted_instance=t.src_instances[track_idx],
frame_idx=t.frame_idx,
tracking_score=t.tracking_scores[track_idx],
instance_score=t.instance_scores[track_idx],
shifted_keypoints=None,
)
output.append(tracked_instance_feature)
return output

def get_new_track_id(self) -> int:
"""Return a new track_id."""
if not self.current_tracks:
new_track_id = 0
else:
new_track_id = max(self.current_tracks) + 1
return new_track_id

def add_new_tracks(
self, current_instances: TrackInstances, add_to_queue: bool = True
) -> TrackInstances:
"""Add new track IDs to the `TrackInstances` object and to the tracker queue."""
is_new_track = False
for i, score in enumerate(current_instances.instance_scores):
if (
score > self.instance_score_threshold
and current_instances.track_ids[i] is None
):
is_new_track = True
new_tracks_id = self.get_new_track_id()
current_instances.track_ids[i] = new_tracks_id
current_instances.tracking_scores[i] = 1.0
self.current_tracks.append(new_tracks_id)

if add_to_queue and is_new_track:
self.tracker_queue.append(current_instances)

return current_instances

def update_tracks(
self,
current_instances: TrackInstances,
row_inds: np.array,
col_inds: np.array,
tracking_scores: List[float],
) -> TrackInstances:
"""Assign tracks to `TrackInstances` based on the output of track matching algorithm.

Args:
current_instances: `TrackInstances` instance with features and unassigned tracks.
row_inds: List of indices for the `current_instances` object that has an assigned
track.
col_inds: List of track IDs that have been assigned a new instance.
tracking_scores: List of tracking scores from the cost matrix.

"""
add_to_queue = True
if np.any(row_inds) and np.any(col_inds):

for idx, (row, col) in enumerate(zip(row_inds, col_inds)):
current_instances.track_ids[row] = col
current_instances.tracking_scores[row] = tracking_scores[idx]

# update tracks to queue
self.tracker_queue.append(current_instances)
add_to_queue = False

# Create new tracks for instances with unassigned tracks from track matching
new_current_instances_inds = [
x for x in range(len(current_instances.features)) if x not in row_inds
]
if new_current_instances_inds:
current_instances = self.add_new_tracks(
current_instances, add_to_queue=add_to_queue
)
return current_instances
164 changes: 164 additions & 0 deletions sleap_nn/tracking/candidates/local_queues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Module to generate Tracking local queue candidates."""

from typing import Dict, Optional, List, Deque, DefaultDict, Union
import numpy as np
import sleap_io as sio
from sleap_nn.tracking.track_instance import (
TrackInstanceLocalQueue,
TrackedInstanceFeature,
)
from collections import defaultdict, deque


class LocalQueueCandidates:
"""Track local queues method for candidate generation.

This module handles `tracker_queue` using the local queues method, where track assignments
are determined based on the last `window_size` instances for each track.

Attributes:
window_size: Number of previous frames to compare the current predicted instance with.
Default: 5.
max_tracks: Maximum number of new tracks that can be created. Default: None.
instance_score_threshold: Instance score threshold for creating new tracks.
Default: 0.0.
tracker_queue: Dictionary that stores the past frames of all the tracks identified
so far as `deque`.
current_tracks: List of track IDs that are being tracked.
"""

def __init__(
self,
window_size: int = 5,
max_tracks: Optional[int] = None,
instance_score_threshold: float = 0.0,
):
"""Initialize class variables."""
self.window_size = window_size
self.max_tracks = max_tracks
self.instance_score_threshold = instance_score_threshold
self.tracker_queue = defaultdict(Deque)
self.current_tracks = []

def get_track_instances(
self,
feature_list: List[Union[np.array]],
untracked_instances: List[sio.PredictedInstance],
frame_idx: int,
image: np.array,
) -> List[TrackInstanceLocalQueue]:
"""Return a list of `TrackInstanceLocalQueue` instances for the `untracked_instances`."""
track_instances = []
for ind, (feat, instance) in enumerate(zip(feature_list, untracked_instances)):
track_instance = TrackInstanceLocalQueue(
src_instance=instance,
src_instance_idx=ind,
track_id=None,
feature=feat,
instance_score=instance.score,
frame_idx=frame_idx,
image=image,
)
track_instances.append(track_instance)
return track_instances

def get_features_from_track_id(
self, track_id: int, candidates_list: Optional[DefaultDict[int, Deque]] = None
) -> List[TrackedInstanceFeature]:
"""Return list of `TrackedInstanceFeature` objects for instances in tracker queue with the given `track_id`.

Note: If `candidates_list` is `None`, then features of all the instances in the
tracker queue are returned by default. Else, only the features from the given
candidates_list are returned.
"""
tracked_instances = (
candidates_list if candidates_list is not None else self.tracker_queue
)
output = []
for t in tracked_instances[track_id]:
tracked_instance_feature = TrackedInstanceFeature(
feature=t.feature,
src_predicted_instance=t.src_instance,
frame_idx=t.frame_idx,
tracking_score=t.tracking_score,
instance_score=t.instance_score,
shifted_keypoints=None,
)
output.append(tracked_instance_feature)
return output

def get_new_track_id(self) -> int:
"""Return a new track_id."""
if not self.current_tracks:
new_track_id = 0
else:
new_track_id = max(self.current_tracks) + 1
if self.max_tracks is not None and new_track_id > self.max_tracks: # TODO
raise Exception("Exceeding max tracks")
self.tracker_queue[new_track_id] = deque(maxlen=self.window_size)
return new_track_id

def add_new_tracks(
self, current_instances: List[TrackInstanceLocalQueue]
) -> List[TrackInstanceLocalQueue]:
"""Add new track IDs to the `TrackInstanceLocalQueue` objects and to the tracker queue."""
track_instances = []
for t in current_instances:
if t.instance_score > self.instance_score_threshold:
new_track_id = self.get_new_track_id()
t.track_id = new_track_id
t.tracking_score = 1.0
self.current_tracks.append(new_track_id)
self.tracker_queue[new_track_id].append(t)
track_instances.append(t)

return track_instances

def update_tracks(
self,
current_instances: List[TrackInstanceLocalQueue],
row_inds: np.array,
col_inds: np.array,
tracking_scores: List[float],
) -> List[TrackInstanceLocalQueue]:
"""Assign tracks to `TrackInstanceLocalQueue` objects based on the output of track matching algorithm.

Args:
current_instances: List of TrackInstanceLocalQueue objects with features and unassigned tracks.
row_inds: List of indices for the `current_instances` object that has an assigned
track.
col_inds: List of track IDs that have been assigned a new instance.
tracking_scores: List of tracking scores from the cost matrix.

"""
if np.any(row_inds) and np.any(col_inds):
for idx, (row, col) in enumerate(zip(row_inds, col_inds)):
current_instances[row].track_id = col
current_instances[row].tracking_score = tracking_scores[idx]

for track_instance in current_instances:
if track_instance.track_id is not None:
self.tracker_queue[track_instance.track_id].append(track_instance)

# Create new tracks for instances with unassigned tracks from track matching
new_current_instances_inds = [
x for x in range(len(current_instances)) if x not in row_inds
]
if new_current_instances_inds:
for ind in new_current_instances_inds:
self.add_new_tracks(current_instances[ind])

return current_instances

def get_instances_groupby_frame_idx(
self, candidates_list: Optional[DefaultDict[int, Deque]]
) -> Dict[int, List[TrackInstanceLocalQueue]]:
"""Return dictionary with list of `TrackInstanceLocalQueue` objects grouped by frame index."""
instances_dict = defaultdict(list)
tracked_instances = (
candidates_list if candidates_list is not None else self.tracker_queue
)
for _, instances in tracked_instances.items():
for instance in instances:
instances_dict[instance.frame_idx].append(instance)
return instances_dict
50 changes: 50 additions & 0 deletions sleap_nn/tracking/track_instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""TrackInstance Data structure for Tracker queue."""

from typing import List, Optional
import attrs
import numpy as np
import sleap_io as sio


@attrs.define
class TrackInstances:
"""Data structure for instances in tracker queue for fixed window method."""

src_instances: List[sio.PredictedInstance]
features: List[np.array]
instance_scores: List[float] = None
track_ids: Optional[List[int]] = None
tracking_scores: Optional[List[float]] = None
frame_idx: Optional[float] = None
image: Optional[np.array] = None


@attrs.define
class TrackInstanceLocalQueue:
"""Data structure for instances in tracker queue for Local Queue method."""

src_instance: sio.PredictedInstance
src_instance_idx: int
feature: np.array
instance_score: float = None
track_id: Optional[int] = None
tracking_score: Optional[float] = None
frame_idx: Optional[float] = None
image: Optional[np.array] = None


@attrs.define
class TrackedInstanceFeature:
"""Data structure for tracked instances.

This data structure is used for updating the previous tracked instances and get the
features of the tracked instances. `shifted_keypoints` is used only for the `FlowShiftTracker`
to store the optical flow shifted instances.
"""

feature: np.ndarray
src_predicted_instance: sio.PredictedInstance
frame_idx: int
tracking_score: float
instance_score: float
shifted_keypoints: np.ndarray = None
Loading
Loading