Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tracker module #70

Merged
merged 29 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
182d1fe
Refactor preprocessing config
gitttt-1234 Jul 24, 2024
b5e2732
Merge train and val data configs
gitttt-1234 Jul 24, 2024
59082bf
Remove pipeline name
gitttt-1234 Jul 24, 2024
d42c206
Modify backbone_config
gitttt-1234 Jul 25, 2024
0ae3b7d
Modify ckpts
gitttt-1234 Jul 25, 2024
402d8e7
Fix inference tests
gitttt-1234 Jul 25, 2024
34b2457
Fix device for inference
gitttt-1234 Jul 26, 2024
624f1e3
Fix scale in inference
gitttt-1234 Jul 27, 2024
9a32aba
Fix Predictor
gitttt-1234 Jul 29, 2024
8f1e105
Modify `bottom_up` to `bottomup`
gitttt-1234 Jul 29, 2024
2b7e7b3
Fix bottomup inference
gitttt-1234 Jul 29, 2024
afe6efe
Fix scale in augmentation
gitttt-1234 Jul 30, 2024
558a292
Add tracker
gitttt-1234 Jul 31, 2024
e4ecb40
Fix tracker queue
gitttt-1234 Aug 1, 2024
caabf97
gerge branch 'divya/refactor-aug-config' of https://github.com/talmol…
gitttt-1234 Aug 6, 2024
1c78bcd
Add local queues
gitttt-1234 Aug 6, 2024
19f904c
mergeMerge branch 'divya/refactor-aug-config' of https://github.com/t…
gitttt-1234 Aug 7, 2024
cd3cfb2
Modify local queues\
gitttt-1234 Aug 7, 2024
24aa072
Add features
gitttt-1234 Aug 8, 2024
bc2472e
Add optical flow
gitttt-1234 Aug 13, 2024
87d6516
merge with baseMerge branch 'divya/tracker' of https://github.com/tal…
gitttt-1234 Aug 14, 2024
c81c3ae
Add Optical flow
gitttt-1234 Aug 15, 2024
2645683
Add tracking score
gitttt-1234 Aug 16, 2024
84df63a
Refactor candidate update
gitttt-1234 Aug 20, 2024
f427281
Integrate with Predictors
gitttt-1234 Aug 21, 2024
aa1fe5a
Fix lint
gitttt-1234 Aug 23, 2024
6c9262c
Fix tracks
gitttt-1234 Aug 23, 2024
5e77de9
Resume training and automatically compute crop size for TopDownConfma…
gitttt-1234 Sep 11, 2024
58a9715
Merge branch 'divya/fix-topdown-aug' into divya/tracker
gitttt-1234 Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sleap_nn/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tracker related modules."""
118 changes: 118 additions & 0 deletions sleap_nn/tracking/candidates/fixed_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Module to generate Fixed window candidates."""

from typing import Optional, List, Deque, Union
from sleap_nn.tracking.track_instance import TrackInstances
import sleap_io as sio
from collections import defaultdict, deque
import attrs
import numpy as np


class FixedWindowCandidates:
"""Fixed-window method for candidate generation.

This module handles `tracker_queue` using the fixed window method, where track assignments
are determined based on the last `window_size` frames.

Attributes:
window_size: Number of previous frames to compare the current predicted instance with.
Default: 8.
instance_score_threshold: Instance score threshold for creating new tracks.
Default: 0.0.
tracker_queue: Deque object that stores the past `window_size` tracked instances.
current_tracks: List of track IDs that are being tracked.
"""

def __init__(self, window_size: int = 8, instance_score_threshold: float = 0.0):
"""Initialize class variables."""
self.window_size = window_size
self.instance_score_threshold = instance_score_threshold
self.tracker_queue = deque(maxlen=self.window_size)
self.current_tracks = []

def get_track_instances(
self,
feature_list: List[Union[np.array]],
untracked_instances: List[sio.PredictedInstance],
frame_idx: int,
image: np.array,
) -> TrackInstances:
"""Return an instance of `TrackInstances` object for the `untracked_instances`."""

track_instance = TrackInstances(
src_instances=untracked_instances,
track_ids=[None] * len(untracked_instances),
features=feature_list,
instance_scores=[instance.score for instance in untracked_instances],
frame_idx=frame_idx,
image=image,
)
return track_instance

def get_features_from_track_id(self, track_id: int) -> List[np.array]:
"""Return list of features for instances in queue with the given `track_id`."""
output = []
for t in self.tracker_queue:
if track_id in t.track_ids:
output.append(t.features[t.track_ids.index(track_id)])
return output

def get_new_track_id(self) -> int:
"""Return a new track_id."""
if not self.current_tracks:
new_track_id = 0
else:
new_track_id = max(self.current_tracks) + 1
return new_track_id

def add_new_tracks(
self, new_track_instances: TrackInstances, add_to_queue: bool = True
) -> TrackInstances:
"""Add new track IDs to the `TrackInstances` object and to the tracker queue."""
is_new_track = False
for i, score in enumerate(new_track_instances.instance_scores):
if (
score > self.instance_score_threshold
and new_track_instances.track_ids[i] is None
):
is_new_track = True
new_tracks_id = self.get_new_track_id()
new_track_instances.track_ids[i] = new_tracks_id
self.current_tracks.append(new_tracks_id)

if add_to_queue and is_new_track:
self.tracker_queue.append(new_track_instances)

return new_track_instances

def update_candidates(
self, track_instances: TrackInstances, row_inds: np.array, col_inds: np.array
) -> TrackInstances:
"""Assign tracks to `TrackInstances` based on the output of track matching algorithm.

Args:
track_instances: `TrackInstances` instance with features.
row_inds: List of indices for the `track_instances` object that has an assigned
track.
col_inds: List of track IDs that have been assigned a new instance.

"""
add_to_queue = True
if np.any(row_inds) and np.any(col_inds):

for row, col in zip(row_inds, col_inds):
track_instances.track_ids[row] = col

# update tracks to queue
self.tracker_queue.append(track_instances)
add_to_queue = False

# Create new tracks for instances with unassigned tracks from track matching
new_track_instances_inds = [
x for x in range(len(track_instances.features)) if x not in row_inds
]
if new_track_instances_inds:
track_instances = self.add_new_tracks(
track_instances, add_to_queue=add_to_queue
)
return track_instances
139 changes: 139 additions & 0 deletions sleap_nn/tracking/candidates/local_queues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Module to generate Tracking local queue candidates."""

from typing import Dict, Optional, List, Deque, Union
import numpy as np
import sleap_io as sio
from sleap_nn.tracking.track_instance import TrackInstanceLocalQueue
from collections import defaultdict, deque


class LocalQueueCandidates:
"""Track local queues method for candidate generation.

This module handles `tracker_queue` using the local queues method, where track assignments
are determined based on the last `window_instances` instances for each track.

Attributes:
window_size: Number of previous frames to compare the current predicted instance with.
Default: 8.
max_tracks: Maximum number of new tracks that can be created. Default: 10.
instance_score_threshold: Instance score threshold for creating new tracks.
Default: 0.0.
tracker_queue: Dictionary that stores the past frames of all the tracks identified
so far as `deque`.
current_tracks: List of track IDs that are being tracked.
"""

def __init__(
self,
window_size: int = 8,
max_tracks: int = 10,
instance_score_threshold: float = 0.0,
):
"""Initialize class variables."""
self.window_size = window_size
self.max_tracks = max_tracks
self.instance_score_threshold = instance_score_threshold
self.tracker_queue = defaultdict(Deque)
self.current_tracks = []

def get_track_instances(
self,
feature_list: List[Union[np.array]],
untracked_instances: List[sio.PredictedInstance],
frame_idx: int,
image: np.array,
) -> List[TrackInstanceLocalQueue]:
"""Return a list of `TrackInstanceLocalQueue` instances from `untracked_instances`."""

track_instances = []
for ind, (feat, instance) in enumerate(zip(feature_list, untracked_instances)):
track_instance = TrackInstanceLocalQueue(
src_instance=instance,
src_instance_idx=ind,
track_id=None,
feature=feat,
instance_score=instance.score,
frame_idx=frame_idx,
image=image,
)
track_instances.append(track_instance)
return track_instances

def get_features_from_track_id(self, track_id: int) -> List[np.array]:
"""Return list of features for instances in queue with the given `track_id`."""
return [t.feature for t in self.tracker_queue[track_id]]

def get_new_track_id(self) -> int:
"""Return a new track_id."""
if not self.current_tracks:
new_track_id = 0
else:
new_track_id = max(self.current_tracks) + 1
if new_track_id > self.max_tracks:
raise Exception("Exceeding max tracks")
self.tracker_queue[new_track_id] = deque(maxlen=self.window_size)
return new_track_id

def add_new_tracks(
self, new_track_instances: List[TrackInstanceLocalQueue]
) -> List[TrackInstanceLocalQueue]:
"""Add new track IDs to the `TrackInstanceLocalQueue` objects and to the tracker queue."""

track_instances = []
for t in new_track_instances:
if t.instance_score > self.instance_score_threshold:
new_track_id = self.get_new_track_id()
t.track_id = new_track_id
self.current_tracks.append(new_track_id)
self.tracker_queue[new_track_id].append(t)
track_instances.append(t)

return track_instances

def update_candidates(
self,
track_instances: List[TrackInstanceLocalQueue],
row_inds: np.array,
col_inds: np.array,
) -> List[TrackInstanceLocalQueue]:
"""Assign tracks to `TrackInstanceLocalQueue` objects based on the output of track matching algorithm.

Args:
track_instances: List of TrackInstanceLocalQueue objects with features.
row_inds: List of indices for the `track_instances` object that has an assigned
track.
col_inds: List of track IDs that have been assigned a new instance.

"""
if np.any(row_inds) and np.any(col_inds):
for row, col in zip(row_inds, col_inds):
track_instances[row].track_id = col

for track_instance in track_instances:
if track_instance.track_id is not None:
self.tracker_queue[track_instance.track_id].append(track_instance)

# Create new tracks for instances with unassigned tracks from track matching
new_track_instances_inds = [
x for x in range(len(track_instances)) if x not in row_inds
]
if new_track_instances_inds:
for ind in new_track_instances_inds:
if (
track_instances[ind].instance_score
> self.instance_score_threshold
):
self.add_new_tracks(track_instances[ind])

return track_instances

def get_instances_groupby_frame_idx(
self,
) -> Dict[int, List[TrackInstanceLocalQueue]]:
"""Return dictionary with list of `TrackInstanceLocalQueue` objects grouped by frame index."""
instances_dict = defaultdict(list)
for track_id, instances in self.tracker_queue.items():
for instance in instances:
instances_dict[track_id].append(instance)
return instances_dict
49 changes: 49 additions & 0 deletions sleap_nn/tracking/track_instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""TrackInstance Data structure for Tracker queue."""

from typing import List, Optional, Union
import attrs
import numpy as np
import sleap_io as sio


@attrs.define
class TrackInstances:
"""Data structure for instances in tracker queue for fixed window method."""

src_instances: List[sio.PredictedInstance]
features: List[np.array]
instance_scores: List[float] = None
track_ids: Optional[List[int]] = []
tracking_scores: Optional[List[float]] = None # TODO
frame_idx: Optional[float] = None
image: Optional[np.array] = None


@attrs.define
class TrackInstanceLocalQueue:
"""Data structure for instances in tracker queue for Local Queue method."""

src_instance: sio.PredictedInstance
src_instance_idx: int
feature: np.array
instance_score: float = None
track_id: Optional[int] = None
tracking_score: Optional[float] = None
frame_idx: Optional[float] = None
image: Optional[np.array] = None


@attrs.define
class ShiftedInstance:
"""Data structure for `FlowShiftTracker`.

Note: This data structure is only used to get the shifted points for the instances
in the tracker queue (has an assigned track ID).
"""

src_track_instance: Union[TrackInstances, List[TrackInstanceLocalQueue]]
shifted_pts: np.array
src_instance_idx: int
frame_idx: int
shift_score: float
track_id: int
Loading
Loading