Skip to content

Commit

Permalink
store embeddings in Instance object instead of returning
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed May 21, 2024
1 parent 7ff22e7 commit 89007a9
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 35 deletions.
2 changes: 1 addition & 1 deletion biogtr/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _run_global_tracker(

# (L=1, n_query, total_instances)
with torch.no_grad():
asso_matrix, embed = model(all_instances, query_instances)
asso_matrix = model(all_instances, query_instances)
# if model.transformer.return_embedding:
# query_frame.embeddings = embed TODO add embedding to Instance Object
# if query_frame == 1:
Expand Down
6 changes: 5 additions & 1 deletion biogtr/io/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,12 @@ def to(self, map_location: str):
for key, val in self._traj_score.items():
if isinstance(val, torch.Tensor):
self._traj_score[key] = val.to(map_location)

Check warning on line 116 in biogtr/io/frame.py

View check run for this annotation

Codecov / codecov/patch

biogtr/io/frame.py#L115-L116

Added lines #L115 - L116 were not covered by tests
for instance in self.instances:
instance = instance.to(map_location)

if isinstance(map_location, str):
self._device = map_location

Check warning on line 121 in biogtr/io/frame.py

View check run for this annotation

Codecov / codecov/patch

biogtr/io/frame.py#L121

Added line #L121 was not covered by tests

self._device = map_location
return self

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion biogtr/io/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def to(self, map_location):
self._bbox = self._bbox.to(map_location)
self._crop = self._crop.to(map_location)
self._features = self._features.to(map_location)
self.device = map_location
if isinstance(map_location, str):
self.device = map_location

Check warning on line 150 in biogtr/io/instance.py

View check run for this annotation

Codecov / codecov/patch

biogtr/io/instance.py#L150

Added line #L150 was not covered by tests

return self

Expand Down
6 changes: 3 additions & 3 deletions biogtr/models/global_tracking_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(

def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
):
) -> list["AssociationMatrix"]:
"""Execute forward pass of GTR Model to get asso matrix.
Args:
Expand Down Expand Up @@ -119,6 +119,6 @@ def forward(
for i, z_i in enumerate(query_z):
query_instances[i].features = z_i

Check warning on line 120 in biogtr/models/global_tracking_transformer.py

View check run for this annotation

Codecov / codecov/patch

biogtr/models/global_tracking_transformer.py#L118-L120

Added lines #L118 - L120 were not covered by tests

asso_preds, emb = self.transformer(ref_instances, query_instances)
asso_preds = self.transformer(ref_instances, query_instances)

return asso_preds, emb
return asso_preds
2 changes: 1 addition & 1 deletion biogtr/models/gtr_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def forward(
Returns:
An association matrix between objects
"""
asso_preds, _ = self.model(ref_instances, query_instances)
asso_preds = self.model(ref_instances, query_instances)
return asso_preds

def training_step(
Expand Down
26 changes: 11 additions & 15 deletions biogtr/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _reset_parameters(self):

def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
) -> tuple[list[AssociationMatrix], dict[str, torch.Tensor]]:
) -> list[AssociationMatrix]:
"""Execute a forward pass through the transformer and attention head.
Args:
Expand All @@ -154,8 +154,6 @@ def forward(
L: number of decoder blocks
n_query: number of instances in current query/frame
total_instances: number of instances in window
embedding_dict: A dictionary containing the "pos" and "temp" embeddings
if `self.return_embeddings` is False then they are None.
"""
ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
Expand All @@ -165,10 +163,6 @@ def forward(
# instances_per_frame = [frame.num_detected for frame in frames]
total_instances = len(ref_instances)
embed_dim = ref_features.shape[-1]
embeddings_dict = {
"ref": {"pos": None, "temp": None},
"query": {"pos": None, "temp": None},
}
# print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
ref_boxes = get_boxes(ref_instances) # total_instances, 4
ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
Expand All @@ -177,12 +171,13 @@ def forward(
window_length = len(ref_times.unique())

ref_temp_emb = self.temp_emb(ref_times / window_length)
if self.return_embedding:
embeddings_dict["ref"]["temp"] = ref_temp_emb

ref_pos_emb = self.pos_emb(ref_boxes)

if self.return_embedding:
embeddings_dict["ref"]["pos"] = ref_pos_emb
for i, instance in enumerate(ref_instances):
instance.add_embedding("pos", ref_pos_emb[i])
instance.add_embedding("temp", ref_temp_emb[i])

ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0

Expand Down Expand Up @@ -223,12 +218,8 @@ def forward(
query_boxes = get_boxes(query_instances)

query_temp_emb = self.temp_emb(query_times / window_length)
if self.return_embedding:
embeddings_dict["query"]["temp"] = query_temp_emb

query_pos_emb = self.pos_emb(query_boxes)
if self.return_embedding:
embeddings_dict["query"]["pos"] = query_pos_emb

query_emb = (query_pos_emb + query_temp_emb) / 2.0

Expand All @@ -238,6 +229,11 @@ def forward(
else:
query_instances = ref_instances

if self.return_embedding:
for i, instance in enumerate(query_instances):
instance.add_embedding("pos", query_pos_emb[i])
instance.add_embedding("temp", query_temp_emb[i])

decoder_features = self.decoder(
query_features,
encoder_features,
Expand All @@ -262,7 +258,7 @@ def forward(
asso_output.append(asso_matrix)

# (L=1, n_query, total_instances)
return (asso_output, embeddings_dict)
return asso_output


class TransformerEncoderLayer(nn.Module):
Expand Down
48 changes: 35 additions & 13 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def test_transformer_basic():
)

instances = [instance for frame in frames for instance in frame.instances]
asso_preds, _ = transformer(instances)
asso_preds = transformer(instances)

assert asso_preds[0].matrix.size() == (num_detected * num_frames,) * 2

Expand Down Expand Up @@ -458,15 +458,26 @@ def test_transformer_embedding():
assert transformer.pos_emb.mode == "learned"
assert transformer.temp_emb.mode == "learned"

asso_preds, embeddings = transformer(instances)
asso_preds = transformer(instances)

assert asso_preds[0].matrix.size() == (num_detected * num_frames,) * 2

for emb_type, embedding in embeddings["ref"].items():
assert embedding.size() == (
num_detected * num_frames,
feats,
), f"{emb_type}, {embedding.size()}"
pos_emb = torch.concat(
[instance.get_embedding("pos") for instance in instances], axis=0
)
temp_emb = torch.concat(
[instance.get_embedding("pos") for instance in instances], axis=0
)

assert pos_emb.size() == (
len(instances),
feats,
), pos_emb.shape

assert temp_emb.size() == (
len(instances),
feats,
), temp_emb.shape


def test_tracking_transformer():
Expand Down Expand Up @@ -511,12 +522,23 @@ def test_tracking_transformer():
return_embedding=True,
)
instances = [instance for frame in frames for instance in frame.instances]
asso_preds, embeddings = tracking_transformer(instances)
asso_preds = tracking_transformer(instances)

assert asso_preds[0].matrix.size() == (num_detected * num_frames,) * 2

for emb_type, embedding in embeddings["ref"].items():
assert embedding.size() == (
num_detected * num_frames,
feats,
), embeddings
pos_emb = torch.concat(
[instance.get_embedding("pos") for instance in instances], axis=0
)
temp_emb = torch.concat(
[instance.get_embedding("pos") for instance in instances], axis=0
)

assert pos_emb.size() == (
len(instances),
feats,
), pos_emb.shape

assert temp_emb.size() == (
len(instances),
feats,
), temp_emb.shape

0 comments on commit 89007a9

Please sign in to comment.