Skip to content

Commit

Permalink
fix docstrings + small typo
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed Dec 4, 2023
1 parent a03cca9 commit 7c4c645
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions biogtr/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ def __init__(
"""Initialize a tracker to run inference.
Args:
window_size: the size of the window used during sliding inference
use_vis_feats: Whether or not to use visual feature extractor
overlap_thresh: the trajectory overlap threshold to be used for assignment
mult_thresh: Whether or not to use weight threshold
decay_time: weight for `decay_time` postprocessing
window_size: the size of the window used during sliding inference.
use_vis_feats: Whether or not to use visual feature extractor.
overlap_thresh: the trajectory overlap threshold to be used for assignment.
mult_thresh: Whether or not to use weight threshold.
decay_time: weight for `decay_time` postprocessing.
iou: Either [None, '', "mult" or "max"]
Whether to use multiplicative or max iou reweighting
max_center_dist: distance threshold for filtering trajectory score matrix
persistent_tracking: whether to keep a buffer across chunks or not
Whether to use multiplicative or max iou reweighting.
max_center_dist: distance threshold for filtering trajectory score matrix.
persistent_tracking: whether to keep a buffer across chunks or not.
max_gap: the max number of frames a trajectory can be missing before termination.
verbose: Whether or not to turn on debug printing after each operation.
"""
self.track_queue = TrackQueue(
window_size=window_size, max_gap=max_gap, verbose=verbose
Expand Down Expand Up @@ -232,7 +234,7 @@ def _run_global_tracker(

# (L=1, n_query, total_instances)
with torch.no_grad():
asso_output, embed = model(frames, query_frame=query_frame)
asso_output, embed = model(frames, query_frame=query_ind)

Check warning on line 237 in biogtr/inference/tracker.py

View check run for this annotation

Codecov / codecov/patch

biogtr/inference/tracker.py#L237

Added line #L237 was not covered by tests
# if model.transformer.return_embedding:
# query_frame.embeddings = embed TODO add embedding to Instance Object
# if query_frame == 1:
Expand Down Expand Up @@ -262,7 +264,7 @@ def _run_global_tracker(
[
x.get_pred_track_ids()
for batch_idx, x in enumerate(frames)
if batch_idx != query_frame
if batch_idx != query_ind
],
dim=0,
).view(
Expand All @@ -280,8 +282,8 @@ def _run_global_tracker(
query_inds = [

Check warning on line 282 in biogtr/inference/tracker.py

View check run for this annotation

Codecov / codecov/patch

biogtr/inference/tracker.py#L282

Added line #L282 was not covered by tests
x
for x in range(
sum(instances_per_frame[:query_frame]),
sum(instances_per_frame[: query_frame + 1]),
sum(instances_per_frame[:query_ind]),
sum(instances_per_frame[: query_ind + 1]),
)
]
nonquery_inds = [i for i in range(total_instances) if i not in query_inds]
Expand All @@ -308,7 +310,7 @@ def _run_global_tracker(

# (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_k x n_traj
traj_score = post_processing.weight_decay_time(
asso_nonquery, self.decay_time, reid_features, window_size, query_frame
asso_nonquery, self.decay_time, reid_features, window_size, query_ind
)

traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj)
Expand Down

0 comments on commit 7c4c645

Please sign in to comment.