
Use instances as model input #46

Merged
merged 5 commits into from May 29, 2024
60 changes: 34 additions & 26 deletions biogtr/models/global_tracking_transformer.py
@@ -85,40 +85,48 @@ def forward(
"""Execute forward pass of GTR Model to get asso matrix.

Args:
frames: List of Frames from chunk containing crops of objects + gt label info
query_frame: Frame index used as query for self attention. Only used in sliding inference where query frame is the last frame in the window.
ref_instances: List of instances from the chunk containing object crops + gt label info
query_instances: List of instances used as queries in the decoder.

Returns:
An N_T x N association matrix
"""
# Extract feature representations with pre-trained encoder.
if any(
[
(not instance.has_features()) and instance.has_crop()
for instance in ref_instances
]
):
ref_crops = torch.concat(
[instance.crop for instance in ref_instances], axis=0
)
ref_z = self.visual_encoder(ref_crops)
for i, z_i in enumerate(ref_z):
ref_instances[i].features = z_i
self.extract_features(ref_instances)

if query_instances:
if any(
[
(not instance.has_features()) and instance.has_crop()
for instance in query_instances
]
):
query_crops = torch.concat(
[instance.crop for instance in query_instances], axis=0
)
query_z = self.visual_encoder(query_crops)
for i, z_i in enumerate(query_z):
query_instances[i].features = z_i
self.extract_features(query_instances)

asso_preds, emb = self.transformer(ref_instances, query_instances)

return asso_preds, emb

def extract_features(
self, instances: list["Instance"], force_recompute: bool = False
) -> None:
"""Extract features from instances using visual encoder backbone.

Args:
instances: A list of instances to compute features for
force_recompute: indicate whether to compute features for all instances regardless of if they have instances
Contributor

Suggested change
force_recompute: indicate whether to compute features for all instances regardless of if they have instances
force_recompute: indicate whether to compute features for all instances regardless of if they have features

"""
if not force_recompute:
instances_to_compute = [
instance
for instance in instances
if instance.has_crop() and not instance.has_features()
]
else:
instances_to_compute = instances

if len(instances_to_compute) == 0:
return
elif len(instances_to_compute) == 1: # handle batch norm error when B=1
instances_to_compute = instances

crops = torch.concatenate([instance.crop for instance in instances_to_compute])

features = self.visual_encoder(crops)

for i, z_i in enumerate(features):
instances_to_compute[i].features = z_i
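
For reference, a self-contained sketch of the lazy feature-extraction pattern introduced above. The `Instance` dataclass and the encoder below are toy stand-ins (assumptions), not the real biogtr classes; the actual method is the `extract_features` shown in the diff.

# Toy sketch of lazy feature extraction: only instances that have a crop but
# no features get passed through the encoder. Classes here are stand-ins.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Instance:
    crop: Optional[torch.Tensor] = None      # (1, C, H, W) image patch
    features: Optional[torch.Tensor] = None  # (d_model,) visual feature

    def has_crop(self) -> bool:
        return self.crop is not None

    def has_features(self) -> bool:
        return self.features is not None


def extract_features(instances, encoder, force_recompute=False):
    """Compute features only for instances that still need them."""
    if force_recompute:
        to_compute = instances
    else:
        to_compute = [
            i for i in instances if i.has_crop() and not i.has_features()
        ]

    if len(to_compute) == 0:
        return
    if len(to_compute) == 1:  # avoid BatchNorm failing on a batch of size 1
        to_compute = instances

    crops = torch.cat([i.crop for i in to_compute], dim=0)
    feats = encoder(crops)
    for inst, z in zip(to_compute, feats):
        inst.features = z


# Toy encoder: flatten + linear projection stands in for the visual backbone.
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(512))
instances = [Instance(crop=torch.rand(1, 3, 32, 32)) for _ in range(4)]
extract_features(instances, encoder)
print(instances[0].features.shape)  # torch.Size([512])
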
27 changes: 13 additions & 14 deletions biogtr/models/gtr_runner.py
@@ -66,7 +66,8 @@ def forward(
"""Execute forward pass of the lightning module.

Args:
instances: a list of dicts where each dict is a frame with gt data
ref_instances: a list of `Instance` objects containing crops and other data needed for the transformer model
query_instances: a list of `Instance` objects used as queries in the decoder. Mostly used for inference.

Returns:
An association matrix between objects
@@ -80,8 +81,8 @@ def training_step(
"""Execute single training step for model.

Args:
train_batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
train_batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.
batch_idx: the batch number used by lightning

Returns:
@@ -98,8 +99,8 @@ def validation_step(
"""Execute single val step for model.

Args:
val_batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
val_batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.
batch_idx: the batch number used by lightning

Returns:
@@ -116,8 +117,8 @@ def test_step(
"""Execute single test step for model.

Args:
val_batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
test_batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.
batch_idx: the batch number used by lightning

Returns:
@@ -134,8 +135,8 @@ def predict_step(self, batch: list[list[Frame]], batch_idx: int) -> list[Frame]:
Computes association + assignment.

Args:
batch: A single batch from the dataset which is a list of dicts
with length `clip_length` where each dict is a frame
batch: A single batch from the dataset which is a list of `Frame` objects
with length `clip_length` containing Instances and other metadata.
batch_idx: the batch number used by lightning

Returns:
@@ -149,18 +150,16 @@ def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]:
"""Run evaluation used by train, test, and val steps.

Args:
frames: A list of dicts where each dict is a frame containing gt data
frames: A list of `Frame` objects with length `clip_length` containing Instances and other metadata.
mode: which metrics to compute and whether to use persistent tracking or not

Returns:
a dict containing the loss and any other metrics specified by `eval_metrics`
"""
try:
frames = [frame for frame in frames if frame.has_instances()]
if len(frames) == 0:
return None

instances = [instance for frame in frames for instance in frame.instances]
if len(instances) == 0:
return None

eval_metrics = self.metrics[mode]
persistent_tracking = self.persistent_tracking[mode]
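
As a small illustration of the change in `_shared_eval_step`, the early-exit check now looks at instances rather than frames. A toy sketch (the `Frame` dataclass below is a stand-in, not the real biogtr class):

# Toy sketch: flatten a clip of frames into instances and skip evaluation
# only when the whole clip is empty. `Frame` here is a stand-in class.
from dataclasses import dataclass, field


@dataclass
class Frame:
    instances: list = field(default_factory=list)


def gather_instances(frames: list) -> list:
    """Flatten a clip of frames into a single list of instances."""
    return [instance for frame in frames for instance in frame.instances]


clip = [Frame(instances=["a", "b"]), Frame(), Frame(instances=["c"])]
instances = gather_instances(clip)
if len(instances) == 0:
    print("skip evaluation")
else:
    print(instances)  # ['a', 'b', 'c']
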
5 changes: 3 additions & 2 deletions biogtr/models/model_utils.py
@@ -6,14 +6,15 @@
import torch


def get_boxes(instances: List[Instance]) -> torch.tensor:
def get_boxes(instances: List[Instance]) -> torch.Tensor:
"""Extract the bounding boxes from the input list of instances.

Args:
instances: List of Instance objects.

Returns:
The bounding boxes normalized by the height and width of the image
An (n_instances, n_points, 4) float tensor containing the bounding boxes
normalized by the height and width of the image
"""
boxes = []
for i, instance in enumerate(instances):
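
To make the documented return shape concrete, here is a hypothetical sketch of normalizing (x1, y1, x2, y2) boxes by image width and height; the helper name and box format are assumptions, not the actual `get_boxes` implementation.

# Hypothetical sketch: normalize pixel-coordinate boxes to [0, 1] by dividing
# x-coordinates by image width and y-coordinates by image height.
import torch


def normalize_boxes(bboxes: torch.Tensor, img_h: int, img_w: int) -> torch.Tensor:
    """bboxes: (n_instances, n_points, 4) in (x1, y1, x2, y2) pixel coords."""
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return bboxes / scale


boxes = torch.tensor([[[10.0, 20.0, 50.0, 80.0]]])  # one instance, one point
print(normalize_boxes(boxes, img_h=100, img_w=200))
# tensor([[[0.0500, 0.2000, 0.2500, 0.8000]]])
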
1 change: 0 additions & 1 deletion biogtr/models/transformer.py
@@ -156,7 +156,6 @@ def forward(
embedding_dict: A dictionary containing the "pos" and "temp" embeddings
if `self.return_embeddings` is False then they are None.
"""

ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
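
For context on the `torch.cat(...).unsqueeze(0)` call above, a minimal shape sketch; the feature dimensions used here are assumptions, not values from the model config.

# Shape sketch: stack per-instance features into a (1, n_instances, d_model)
# batch for the transformer. Dimensions below are illustrative assumptions.
import torch

d_model, n_instances = 512, 6
# Each instance carries a (1, d_model) feature vector from the visual encoder.
features = [torch.rand(1, d_model) for _ in range(n_instances)]

# Concatenate along the instance axis, then add a batch dimension.
ref_features = torch.cat(features, dim=0).unsqueeze(0)
print(ref_features.shape)  # torch.Size([1, 6, 512])
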