Use instances as model input #46

Merged
merged 5 commits on May 29, 2024
Changes from all commits
30 changes: 29 additions & 1 deletion biogtr/data_structures.py
@@ -124,6 +124,8 @@ def __init__(
self._device = device
self.to(self._device)

self._frame = None

def __repr__(self) -> str:
"""Return string representation of the Instance."""
return (
@@ -421,6 +423,26 @@ def has_features(self) -> bool:
else:
return True

@property
def frame(self) -> "Frame":
"""Get the frame the instance belongs to.

Returns:
The back reference to the `Frame` that this `Instance` belongs to.
"""
return self._frame

@frame.setter
def frame(self, frame: "Frame") -> None:
"""Set the back reference to the `Frame` that this `Instance` belongs to.

This field is set when instances are added to `Frame` object.

Args:
frame: A `Frame` object containing the metadata for the frame that the instance belongs to
"""
self._frame = frame

aaprasad marked this conversation as resolved.
@property
def pose(self) -> dict[str, ArrayLike]:
"""Get the pose of the instance.
@@ -580,9 +602,12 @@ def __init__(
self._img_shape = img_shape
else:
self._img_shape = torch.tensor([img_shape])

if instances is None:
self.instances = []
else:
for instance in instances:
instance.frame = self
self._instances = instances

self._asso_output = asso_output
@@ -612,7 +637,7 @@ def __repr__(self) -> str:
f"img_shape={self._img_shape}, "
f"num_detected={self.num_detected}, "
f"asso_output={self._asso_output}, "
f"traj_score={self._traj_score}, "
f"traj_score={list(self._traj_score.keys())}, "
f"matches={self._matches}, "
f"instances={self._instances}, "
f"device={self._device}"
@@ -796,6 +821,9 @@ def instances(self, instances: List[Instance]) -> None:
Args:
instances: A list of Instances that appear in the frame.
"""
for instance in instances:
instance.frame = self

self._instances = instances

def has_instances(self) -> bool:
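The net effect of these hunks is that attaching instances to a `Frame` (at construction or via the `instances` setter) now back-populates each `Instance.frame`. Below is a minimal standalone sketch of that pattern; the classes are simplified stand-ins, not the real biogtr data structures.

```python
# Simplified stand-ins for biogtr's Frame/Instance, illustrating only the
# back-reference pattern added in this diff (not the real API).
from __future__ import annotations
from typing import Optional


class Instance:
    def __init__(self) -> None:
        self._frame: Optional[Frame] = None

    @property
    def frame(self) -> Optional["Frame"]:
        """The Frame this instance belongs to (set when added to a Frame)."""
        return self._frame

    @frame.setter
    def frame(self, frame: "Frame") -> None:
        self._frame = frame


class Frame:
    def __init__(self, instances: Optional[list[Instance]] = None) -> None:
        # Assigning through the property back-populates each instance.frame,
        # mirroring the loops added in __init__ and the instances setter above.
        self.instances = instances if instances is not None else []

    @property
    def instances(self) -> list[Instance]:
        return self._instances

    @instances.setter
    def instances(self, instances: list[Instance]) -> None:
        for instance in instances:
            instance.frame = self
        self._instances = instances


if __name__ == "__main__":
    insts = [Instance() for _ in range(3)]
    frame = Frame(instances=insts)
    assert all(inst.frame is frame for inst in insts)
```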
14 changes: 10 additions & 4 deletions biogtr/inference/tracker.py
@@ -141,7 +141,9 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
# W: width.

for batch_idx, frame_to_track in enumerate(frames):
tracked_frames = self.track_queue.collate_tracks()
tracked_frames = self.track_queue.collate_tracks(
device=frame_to_track.frame_id.device
)
if self.verbose:
warnings.warn(
f"Current number of tracks is {self.track_queue.n_tracks}"
@@ -229,8 +231,12 @@ def _run_global_tracker(
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window.

_ = model.eval()

query_frame = frames[query_ind]

query_instances = query_frame.instances
all_instances = [instance for frame in frames for instance in frame.instances]

if self.verbose:
print(f"Frame {query_frame.frame_id.item()}")

@@ -253,7 +259,7 @@

# (L=1, n_query, total_instances)
with torch.no_grad():
asso_output, embed = model(frames, query_frame=query_ind)
asso_output, embed = model(all_instances, query_instances)
# if model.transformer.return_embedding:
# query_frame.embeddings = embed TODO add embedding to Instance Object
# if query_frame == 1:
@@ -321,6 +327,7 @@
]
nonquery_inds = [i for i in range(total_instances) if i not in query_inds]

# instead should we do model(nonquery_instances, query_instances)?
asso_nonquery = asso_output[:, nonquery_inds] # (n_query, n_nonquery)

asso_nonquery_df = pd.DataFrame(
@@ -332,10 +339,9 @@

query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)

pred_boxes, _ = model_utils.get_boxes_times(frames)
pred_boxes = model_utils.get_boxes(all_instances)
query_boxes = pred_boxes[query_inds] # n_k x 4
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4
# TODO: Insert postprocessing.

unique_ids = torch.unique(instance_ids) # (n_nonquery,)

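With frames flattened into one instance list, the tracker addresses the query frame's detections by global indices into that list. Below is a standalone sketch of that bookkeeping, using the example counts from the comment above (`instances_per_frame: [4, 5, 6, 7]`); it is illustrative, not the actual tracker code.

```python
# Standalone illustration of the flatten-and-index bookkeeping implied by
# passing all_instances to the model instead of frames (values are made up).
instances_per_frame = [4, 5, 6, 7]  # detections in each frame of the window
query_ind = 3                        # query frame is the last frame in the window

total_instances = sum(instances_per_frame)                     # 22
offset = sum(instances_per_frame[:query_ind])                  # 15
query_inds = list(range(offset, offset + instances_per_frame[query_ind]))
nonquery_inds = [i for i in range(total_instances) if i not in query_inds]

# asso_output[:, nonquery_inds] would then have shape (n_query, n_nonquery).
assert (len(query_inds), len(nonquery_inds)) == (7, 15)
```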
56 changes: 42 additions & 14 deletions biogtr/models/global_tracking_transformer.py
@@ -2,13 +2,13 @@

from biogtr.models.transformer import Transformer
from biogtr.models.visual_encoder import VisualEncoder
from biogtr.data_structures import Frame
from torch import nn
from biogtr.data_structures import Instance
import torch
Comment on lines +5 to +6

coderabbitai bot (Contributor), May 15, 2024:

Remove unused import.

- from biogtr.data_structures import Instance

The `Instance` class is not used in this file. Removing unused imports helps keep the code clean and maintainable.

aaprasad (Contributor, Author):

It's used for type hinting.

coderabbitai bot (Contributor):

@aaprasad, thank you for pointing that out. You're absolutely right; the `Instance` class is used for type hinting in the `forward` method parameters. I appreciate your correction, and I'll make sure to consider this in future reviews. 🐇


# todo: do we want to handle params with configs already here?


class GlobalTrackingTransformer(nn.Module):
class GlobalTrackingTransformer(torch.nn.Module):
"""Modular GTR model composed of visual encoder + transformer used for tracking."""

def __init__(
@@ -79,26 +79,54 @@ def __init__(
decoder_self_attn=decoder_self_attn,
)

def forward(self, frames: list[Frame], query_frame: int = None):
def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
):
"""Execute forward pass of GTR Model to get asso matrix.

Args:
frames: List of Frames from chunk containing crops of objects + gt label info
query_frame: Frame index used as query for self attention. Only used in sliding inference where query frame is the last frame in the window.
ref_instances: List of instances from chunk containing crops of objects + gt label info
query_instances: list of instances used as query in decoder.

Returns:
An N_T x N association matrix
"""
# Extract feature representations with pre-trained encoder.
for frame in filter(
lambda f: f.has_instances() and not f.has_features(), frames
):
crops = frame.get_crops()
z = self.visual_encoder(crops)
self.extract_features(ref_instances)

for i, z_i in enumerate(z):
frame.instances[i].features = z_i
if query_instances:
self.extract_features(query_instances)

asso_preds, emb = self.transformer(frames, query_frame=query_frame)
asso_preds, emb = self.transformer(ref_instances, query_instances)

return asso_preds, emb

def extract_features(
self, instances: list["Instance"], force_recompute: bool = False
) -> None:
"""Extract features from instances using visual encoder backbone.

Args:
instances: A list of instances to compute features for
force_recompute: indicate whether to compute features for all instances regardless of if they have instances
Review comment (Contributor):

Suggested change
- force_recompute: indicate whether to compute features for all instances regardless of if they have instances
+ force_recompute: indicate whether to compute features for all instances regardless of if they have features

"""
if not force_recompute:
instances_to_compute = [
instance
for instance in instances
if instance.has_crop() and not instance.has_features()
]
else:
instances_to_compute = instances

if len(instances_to_compute) == 0:
return
elif len(instances_to_compute) == 1: # handle batch norm error when B=1
instances_to_compute = instances

crops = torch.concatenate([instance.crop for instance in instances_to_compute])

features = self.visual_encoder(crops)

for i, z_i in enumerate(features):
instances_to_compute[i].features = z_i
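The `len(...) == 1` branch exists because batch-norm layers in training mode need more than one sample per channel; when only a single instance is missing features, the method falls back to recomputing features for all of them. A standalone sketch of that workaround with a stand-in encoder (not the real `VisualEncoder`):

```python
# Stand-in encoder demonstrating why a batch of one is avoided: BatchNorm1d in
# training mode raises "Expected more than 1 value per channel" for B=1.
import torch
from torch import nn

encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.BatchNorm1d(8),  # the layer that breaks with a single-sample batch
)

crops = [torch.randn(1, 3, 64, 64) for _ in range(5)]
needs_features = [False, False, False, False, True]  # only one new crop

to_compute = [c for c, needed in zip(crops, needs_features) if needed]
if len(to_compute) == 1:
    # Fall back to recomputing every crop so the batch size is > 1.
    to_compute = crops

features = encoder(torch.concatenate(to_compute))
print(features.shape)  # torch.Size([5, 8])
```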
74 changes: 39 additions & 35 deletions biogtr/models/gtr_runner.py
@@ -8,6 +8,7 @@
from biogtr.training.losses import AssoLoss
from biogtr.models.model_utils import init_optimizer, init_scheduler
from pytorch_lightning import LightningModule
from biogtr.data_structures import Frame, Instance
Review comment (Contributor):

Remove unused import.

- from biogtr.data_structures import Frame, Instance
+ from biogtr.data_structures import Frame

The `Instance` class is not used in this file. Removing unused imports helps keep the code clean and maintainable.



class GTRRunner(LightningModule):
@@ -59,28 +60,29 @@ def __init__(
self.metrics = metrics
self.persistent_tracking = persistent_tracking

def forward(self, instances) -> torch.Tensor:
def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
Review comment (Contributor): Update the docstring for the new signature.

) -> torch.Tensor:
"""Execute forward pass of the lightning module.

Args:
instances: a list of dicts where each dict is a frame with gt data
ref_instances: a list of `Instance` objects containing crops and other data needed for transformer model
query_instances: a list of `Instance` objects used as queries in the decoder. Mostly used for inference.

Returns:
An association matrix between objects
"""
if sum([frame.num_detected for frame in instances]) > 0:
asso_preds, _ = self.model(instances)
return asso_preds
return None
asso_preds, _ = self.model(ref_instances, query_instances)
return asso_preds

def training_step(
self, train_batch: list[dict], batch_idx: int
self, train_batch: list[list[Frame]], batch_idx: int
Review comment (Contributor): Update docstring.

) -> dict[str, float]:
"""Execute single training step for model.

Args:
train_batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
train_batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.

batch_idx: the batch number used by lightning

Returns:
@@ -92,13 +94,13 @@ def training_step(
return result

def validation_step(
self, val_batch: list[dict], batch_idx: int
self, val_batch: list[list[Frame]], batch_idx: int
Review comment (Contributor): Update docstring.

) -> dict[str, float]:
"""Execute single val step for model.

Args:
val_batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
val_batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.
batch_idx: the batch number used by lightning

Returns:
@@ -109,12 +111,14 @@

return result

def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:
def test_step(
self, test_batch: list[list[Frame]], batch_idx: int
Review comment (Contributor): Update docstring.

) -> dict[str, float]:
"""Execute single test step for model.

Args:
val_batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
test_batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.
batch_idx: the batch number used by lightning

Returns:
@@ -125,57 +129,57 @@ def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:

return result

def predict_step(self, batch: list[dict], batch_idx: int) -> dict:
def predict_step(self, batch: list[list[Frame]], batch_idx: int) -> list[Frame]:
"""Run inference for model.

Computes association + assignment.

Args:
batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.
batch_idx: the batch number used by lightning

Returns:
A list of dicts where each dict is a frame containing the predicted track ids
"""
self.tracker.persistent_tracking = True
instances_pred = self.tracker(self.model, batch[0])
return instances_pred
frames_pred = self.tracker(self.model, batch[0])
return frames_pred

def _shared_eval_step(self, instances, mode):
def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]:
"""Run evaluation used by train, test, and val steps.

Args:
instances: A list of dicts where each dict is a frame containing gt data
frames: A list of `Frame` objects with length `clip_length` containing Instances and other metadata.
mode: which metrics to compute and whether to use persistent tracking or not

Returns:
a dict containing the loss and any other metrics specified by `eval_metrics`
"""
try:
instances = [frame for frame in instances if frame.has_instances()]
instances = [instance for frame in frames for instance in frame.instances]
if len(instances) == 0:
return None

eval_metrics = self.metrics[mode]
persistent_tracking = self.persistent_tracking[mode]

logits = self(instances)

if not logits:
return None

loss = self.loss(logits, instances)
loss = self.loss(logits, frames)

return_metrics = {"loss": loss}
if eval_metrics is not None and len(eval_metrics) > 0:
self.tracker.persistent_tracking = persistent_tracking
instances_pred = self.tracker(self.model, instances)
instances_mm = metrics.to_track_eval(instances_pred)
clearmot = metrics.get_pymotmetrics(instances_mm, eval_metrics)

frames_pred = self.tracker(self.model, frames)

frames_mm = metrics.to_track_eval(frames_pred)
clearmot = metrics.get_pymotmetrics(frames_mm, eval_metrics)

return_metrics.update(clearmot.to_dict())
return_metrics["batch_size"] = len(instances)
return_metrics["batch_size"] = len(frames)
except Exception as e:
print(
f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}"
)
print(f"Failed on frame {frames[0].frame_id} of video {frames[0].video_id}")
raise (e)

return return_metrics
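The updated type hints imply the following batch layout: a batch is a list holding one clip, a clip is a list of `Frame` objects of length `clip_length`, and each frame carries its `Instance`s. A standalone sketch with simplified stand-ins (not the real classes):

```python
# Simplified stand-ins showing the batch layout implied by
# list[list[Frame]] and the flattening done in _shared_eval_step.
from dataclasses import dataclass, field


@dataclass
class Instance:
    track_id: int


@dataclass
class Frame:
    frame_id: int
    instances: list[Instance] = field(default_factory=list)


clip_length = 4
clip = [
    Frame(frame_id=i, instances=[Instance(track_id=0), Instance(track_id=1)])
    for i in range(clip_length)
]
batch = [clip]  # what the Lightning *_step hooks receive

frames = batch[0]  # as in predict_step
instances = [inst for frame in frames for inst in frame.instances]
assert len(instances) == 2 * clip_length
```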