Commit
use instances as input into model instead of frames
aaprasad committed May 15, 2024
1 parent 7887d56 commit 9bab7bc
Showing 5 changed files with 159 additions and 110 deletions.
14 changes: 10 additions & 4 deletions biogtr/inference/tracker.py
@@ -141,7 +141,9 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
# W: width.

for batch_idx, frame_to_track in enumerate(frames):
- tracked_frames = self.track_queue.collate_tracks()
+ tracked_frames = self.track_queue.collate_tracks(
+ device=frame_to_track.frame_id.device
+ )
if self.verbose:
warnings.warn(
f"Current number of tracks is {self.track_queue.n_tracks}"
@@ -229,8 +231,12 @@ def _run_global_tracker(
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window.

_ = model.eval()

query_frame = frames[query_ind]

+ query_instances = query_frame.instances
+ all_instances = [instance for frame in frames for instance in frame.instances]

if self.verbose:
print(f"Frame {query_frame.frame_id.item()}")

@@ -253,7 +259,7 @@

# (L=1, n_query, total_instances)
with torch.no_grad():
- asso_output, embed = model(frames, query_frame=query_ind)
+ asso_output, embed = model(all_instances, query_instances)
# if model.transformer.return_embedding:
# query_frame.embeddings = embed TODO add embedding to Instance Object
# if query_frame == 1:
@@ -321,6 +327,7 @@
]
nonquery_inds = [i for i in range(total_instances) if i not in query_inds]

+ # instead should we do model(nonquery_instances, query_instances)?
asso_nonquery = asso_output[:, nonquery_inds] # (n_query, n_nonquery)

asso_nonquery_df = pd.DataFrame(
@@ -332,10 +339,9 @@

query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)

- pred_boxes, _ = model_utils.get_boxes_times(frames)
+ pred_boxes = model_utils.get_boxes(all_instances)
query_boxes = pred_boxes[query_inds] # n_k x 4
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4
- # TODO: Insert postprocessing.

unique_ids = torch.unique(instance_ids) # (n_nonquery,)

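For context (not part of the diff): the change above makes the tracker flatten a window of Frames into a single list of Instances before calling the model. A minimal sketch of that flattening; the helper name and the returned frame-index list are illustrative, only Frame.instances is taken from the code above:

def flatten_frames_to_instances(frames):
    """Flatten a window of Frame objects into one instance list (sketch, not biogtr API).

    Returns the flat instance list plus, for each instance, the index of the
    frame it came from, so results can be regrouped per frame after the model
    runs on instances instead of frames.
    """
    instances, frame_idx = [], []
    for i, frame in enumerate(frames):
        for instance in frame.instances:
            instances.append(instance)
            frame_idx.append(i)
    return instances, frame_idx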
42 changes: 31 additions & 11 deletions biogtr/models/global_tracking_transformer.py
@@ -2,13 +2,13 @@

from biogtr.models.transformer import Transformer
from biogtr.models.visual_encoder import VisualEncoder
- from biogtr.data_structures import Frame
- from torch import nn
+ from biogtr.data_structures import Instance
+ import torch

# todo: do we want to handle params with configs already here?


- class GlobalTrackingTransformer(nn.Module):
+ class GlobalTrackingTransformer(torch.nn.Module):
"""Modular GTR model composed of visual encoder + transformer used for tracking."""

def __init__(
@@ -79,7 +79,9 @@ def __init__(
decoder_self_attn=decoder_self_attn,
)

- def forward(self, frames: list[Frame], query_frame: int = None):
+ def forward(
+ self, ref_instances: list[Instance], query_instances: list[Instance] = None
+ ):
"""Execute forward pass of GTR Model to get asso matrix.
Args:
@@ -90,15 +92,33 @@ def forward(self, frames: list[Frame], query_frame: int = None):
An N_T x N association matrix
"""
# Extract feature representations with pre-trained encoder.
- for frame in filter(
- lambda f: f.has_instances() and not f.has_features(), frames
+ if any(
+ [
+ (not instance.has_features()) and instance.has_crop()
+ for instance in ref_instances
+ ]
):
- crops = frame.get_crops()
- z = self.visual_encoder(crops)
+ ref_crops = torch.concat(
+ [instance.crop for instance in ref_instances], axis=0
+ )
+ ref_z = self.visual_encoder(ref_crops)
+ for i, z_i in enumerate(ref_z):
+ ref_instances[i].features = z_i

- for i, z_i in enumerate(z):
- frame.instances[i].features = z_i
+ if query_instances:
+ if any(
+ [
+ (not instance.has_features()) and instance.has_crop()
+ for instance in query_instances
+ ]
+ ):
+ query_crops = torch.concat(
+ [instance.crop for instance in query_instances], axis=0
+ )
+ query_z = self.visual_encoder(query_crops)
+ for i, z_i in enumerate(query_z):
+ query_instances[i].features = z_i

- asso_preds, emb = self.transformer(frames, query_frame=query_frame)
+ asso_preds, emb = self.transformer(ref_instances, query_instances)

return asso_preds, emb
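For context (not part of the diff): the new forward pass batches the crops of any instances that still lack features, runs the visual encoder once on that batch, and writes the resulting feature vectors back onto each Instance. A self-contained sketch of that pattern with toy stand-ins; ToyInstance and the Sequential encoder below are assumptions for illustration, not the biogtr classes:

import torch

class ToyInstance:
    """Stand-in for biogtr.data_structures.Instance with only the fields used here."""

    def __init__(self, crop):
        self.crop = crop  # (1, C, H, W) crop tensor
        self.features = None

    def has_crop(self):
        return self.crop is not None

    def has_features(self):
        return self.features is not None

# Toy encoder standing in for VisualEncoder.
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(512))

instances = [ToyInstance(torch.randn(1, 3, 64, 64)) for _ in range(4)]

# Encode only when some instance still needs features, and do it in one batch.
if any(not inst.has_features() and inst.has_crop() for inst in instances):
    crops = torch.concat([inst.crop for inst in instances], axis=0)  # (N, C, H, W)
    feats = encoder(crops)  # (N, 512)
    for inst, z in zip(instances, feats):
        inst.features = z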
57 changes: 31 additions & 26 deletions biogtr/models/gtr_runner.py
@@ -8,6 +8,7 @@
from biogtr.training.losses import AssoLoss
from biogtr.models.model_utils import init_optimizer, init_scheduler
from pytorch_lightning import LightningModule
+ from biogtr.data_structures import Frame, Instance


class GTRRunner(LightningModule):
@@ -59,7 +60,9 @@ def __init__(
self.metrics = metrics
self.persistent_tracking = persistent_tracking

- def forward(self, instances) -> torch.Tensor:
+ def forward(
+ self, ref_instances: list[Instance], query_instances: list[Instance] = None
+ ) -> torch.Tensor:
"""Execute forward pass of the lightning module.
Args:
@@ -68,13 +71,11 @@ def forward(self, instances) -> torch.Tensor:
Returns:
An association matrix between objects
"""
- if sum([frame.num_detected for frame in instances]) > 0:
- asso_preds, _ = self.model(instances)
- return asso_preds
- return None
+ asso_preds, _ = self.model(ref_instances, query_instances)
+ return asso_preds

def training_step(
- self, train_batch: list[dict], batch_idx: int
+ self, train_batch: list[list[Frame]], batch_idx: int
) -> dict[str, float]:
"""Execute single training step for model.
@@ -92,7 +93,7 @@ def training_step(
return result

def validation_step(
- self, val_batch: list[dict], batch_idx: int
+ self, val_batch: list[list[Frame]], batch_idx: int
) -> dict[str, float]:
"""Execute single val step for model.
@@ -109,7 +110,9 @@

return result

- def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:
+ def test_step(
+ self, test_batch: list[list[Frame]], batch_idx: int
+ ) -> dict[str, float]:
"""Execute single test step for model.
Args:
@@ -125,7 +128,7 @@ def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:

return result

- def predict_step(self, batch: list[dict], batch_idx: int) -> dict:
+ def predict_step(self, batch: list[list[Frame]], batch_idx: int) -> list[Frame]:
"""Run inference for model.
Computes association + assignment.
@@ -139,43 +142,45 @@ def predict_step(self, batch: list[dict], batch_idx: int) -> dict:
A list of dicts where each dict is a frame containing the predicted track ids
"""
self.tracker.persistent_tracking = True
- instances_pred = self.tracker(self.model, batch[0])
- return instances_pred
+ frames_pred = self.tracker(self.model, batch[0])
+ return frames_pred

- def _shared_eval_step(self, instances, mode):
+ def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]:
"""Run evaluation used by train, test, and val steps.
Args:
- instances: A list of dicts where each dict is a frame containing gt data
+ frames: A list of dicts where each dict is a frame containing gt data
mode: which metrics to compute and whether to use persistent tracking or not
Returns:
a dict containing the loss and any other metrics specified by `eval_metrics`
"""
try:
instances = [frame for frame in instances if frame.has_instances()]
frames = [frame for frame in frames if frame.has_instances()]
if len(frames) == 0:
return None

instances = [instance for frame in frames for instance in frame.instances]

eval_metrics = self.metrics[mode]
persistent_tracking = self.persistent_tracking[mode]

logits = self(instances)

if not logits:
return None

loss = self.loss(logits, instances)
loss = self.loss(logits, frames)

return_metrics = {"loss": loss}
if eval_metrics is not None and len(eval_metrics) > 0:
self.tracker.persistent_tracking = persistent_tracking
instances_pred = self.tracker(self.model, instances)
instances_mm = metrics.to_track_eval(instances_pred)
clearmot = metrics.get_pymotmetrics(instances_mm, eval_metrics)

frames_pred = self.tracker(self.model, frames)

frames_mm = metrics.to_track_eval(frames_pred)
clearmot = metrics.get_pymotmetrics(frames_mm, eval_metrics)

return_metrics.update(clearmot.to_dict())
return_metrics["batch_size"] = len(instances)
return_metrics["batch_size"] = len(frames)
except Exception as e:
- print(
- f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}"
- )
+ print(f"Failed on frame {frames[0].frame_id} of video {frames[0].video_id}")
raise (e)

return return_metrics
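For context (not part of the diff): the refactored _shared_eval_step drops frames with no detections, flattens the remaining frames into instances for the model, but still computes the loss against the frames themselves. A condensed, duck-typed sketch of that control flow; model and loss_fn stand in for GlobalTrackingTransformer and AssoLoss, and the dict keys mirror the code above:

from typing import Optional

def shared_eval_step(frames, model, loss_fn) -> Optional[dict]:
    """Sketch of the evaluation flow above (not the biogtr implementation).

    Filters out frames without instances, feeds the flattened instance list
    to the model, and computes the loss against the remaining frames.
    """
    frames = [f for f in frames if f.has_instances()]
    if len(frames) == 0:
        return None

    instances = [inst for f in frames for inst in f.instances]
    logits = model(instances)  # association logits over all instances
    loss = loss_fn(logits, frames)  # loss is still computed per frame
    return {"loss": loss, "batch_size": len(frames)}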