Skip to content

Commit

Permalink
Merge branch 'main' into aadi/remove-kwargs-mutable-defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad authored Jun 3, 2024
2 parents adb3715 + 98106b7 commit 5d021d4
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 114 deletions.
3 changes: 1 addition & 2 deletions biogtr/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frame: A list of Frames (See `biogtr.io.data_structures.Frame` for more info).
frames: A list of Frames (See `biogtr.io.Frame` for more info).
Returns:
Frames: A list of Frames populated with pred_track_ids and asso_matrices
Expand Down
153 changes: 87 additions & 66 deletions biogtr/io/association_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AssociationMatrix:
query_instances: query instances that were associated against ref instances.
"""

matrix: Union[np.ndarray, torch.Tensor] = attrs.field()
matrix: Union[np.ndarray, torch.Tensor]
ref_instances: list[Instance] = attrs.field()
query_instances: list[Instance] = attrs.field()

Expand Down Expand Up @@ -56,7 +56,7 @@ def _check_query_instances(self, attribute, value):
raise ValueError(
(
"Query instances must equal number of rows in Association matrix"
f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
f"Found {len(value)} query instances but {self.matrix.shape[0]} rows."
)
)

Expand All @@ -83,82 +83,100 @@ def numpy(self) -> np.ndarray:
return self.matrix

def to_dataframe(
self, row_label: str = "gt", col_label: str = "gt"
self, row_labels: str = "gt", col_labels: str = "gt"
) -> pd.DataFrame:
"""Convert the association matrix to a pandas DataFrame.
Args:
row_label: How to label the rows(queries).
If `gt` then label by gt track id.
If `pred` then label by pred track id.
Otherwise label by the query_instance indices
col_label: How to label the columns(references).
If `gt` then label by gt track id.
If `pred` then label by pred track id.
Otherwise label by the ref_instance indices
row_labels: How to label the rows(queries).
If list, then must match # of rows/queries
If `"gt"` then label by gt track id.
If `"pred"` then label by pred track id.
Otherwise label by the query_instance indices
col_labels: How to label the columns(references).
If list, then must match # of columns/refs
If `"gt"` then label by gt track id.
If `"pred"` then label by pred track id.
Otherwise label by the ref_instance indices
Returns:
The association matrix as a pandas dataframe.
"""
matrix = self.numpy()

if row_label.lower() == "gt":
row_inds = [
instance.gt_track_id.item() for instance in self.query_instances
]

elif row_label.lower() == "pred":
row_inds = [
instance.pred_track_id.item() for instance in self.query_instances
]

if not isinstance(row_labels, str):
if len(row_labels) == len(self.query_instances):
row_inds = row_labels
else:
raise ValueError(
(
f"Mismatched # of rows and labels!",
f"Found {len(row_labels)} with {len(self.query_instances)} rows",
)
)
else:
row_inds = np.arange(len(self.query_instances))
if row_labels == "gt":
row_inds = [
instance.gt_track_id.item() for instance in self.query_instances
]

if col_label.lower() == "gt":
col_inds = [instance.gt_track_id.item() for instance in self.ref_instances]
elif row_labels == "pred":
row_inds = [
instance.pred_track_id.item() for instance in self.query_instances
]

elif col_label.lower() == "pred":
col_inds = [
instance.pred_track_id.item() for instance in self.ref_instances
]
else:
row_inds = np.arange(len(self.query_instances))

if not isinstance(col_labels, str):
if len(col_labels) == len(self.ref_instances):
col_inds = col_labels
else:
raise ValueError(
(
f"Mismatched # of columns and labels!",
f"Found {len(col_labels)} with {len(self.ref_instances)} columns",
)
)
else:
col_inds = np.arange(len(self.ref_instances))
if col_labels == "gt":
col_inds = [
instance.gt_track_id.item() for instance in self.ref_instances
]

elif col_labels == "pred":
col_inds = [
instance.pred_track_id.item() for instance in self.ref_instances
]

else:
col_inds = np.arange(len(self.ref_instances))

asso_df = pd.DataFrame(matrix, index=row_inds, columns=col_inds)

return asso_df

def reduce(
self,
to: tuple[str] = ("inst", "traj"),
by: tuple[str] = (None, "pred"),
row_dims: str = "instance",
col_dims: str = "track",
row_grouping: str = None,
col_grouping: str = "pred",
reduce_method: callable = np.sum,
) -> pd.DataFrame:
"""Reduce association matrix rows/columns to inst/traj x traj/inst.
"""Aggregate the association matrix by specified dimensions and grouping.
Args:
to: A tuple indicating how to reduce rows/columns. Either inst (remains unchanged), or traj (reduces matrix from n_query/n_ref to n_traj
by: A tuple indicating how to group rows/columns when aggregating. Either "pred" or "gt"
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
row_dims: A str indicating how to what dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_rows=n_traj).
col_dims: A str indicating how to dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_cols=n_traj)
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
reduce_method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
"""
row_to, col_to = to
row_by, col_by = by

if row_to not in ("inst", "traj") or col_to not in ("inst", "traj"):
raise ValueError(
f"Can only reduce to inst/traj x inst/traj but ({row_to} x {col_to}) was requested!"
)

if row_by not in ("pred", "gt", None) or col_by not in ("pred", "gt", None):
raise ValueError(
f"Can aggregate by [gt, pred, None] but {row_by} and {col_by} was requested!"
)

n_rows = len(self.query_instances)
n_cols = len(self.ref_instances)

Expand All @@ -168,26 +186,25 @@ def reduce(
col_inds = [i for i in range(len(self.ref_instances))]
row_inds = [i for i in range(len(self.query_instances))]

if "tra" in col_to:
col_tracks = self.get_tracks(self.ref_instances, col_by)
if col_dims == "track":
col_tracks = self.get_tracks(self.ref_instances, col_grouping)
col_inds = list(col_tracks.keys())
n_cols = len(col_inds)
if "tra" in row_to:
row_tracks = self.get_tracks(self.query_instances, row_by)

if row_dims == "track":
row_tracks = self.get_tracks(self.query_instances, row_grouping)
row_inds = list(row_tracks.keys())
n_rows = len(row_inds)

reduced_matrix = []
for row_track, row_instances in row_tracks.items():
# print(row_instances)

for col_track, col_instances in col_tracks.items():
asso_matrix = self[row_instances, col_instances]
# print(col_instances)
# print(asso_matrix)
if "tra" in col_to:

if col_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=1)
if "tra" in row_to:

if row_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=0)
reduced_matrix.append(asso_matrix)

Expand All @@ -200,7 +217,7 @@ def __getitem__(self, inds) -> np.ndarray:
Args:
inds: A tuple of query indices and reference indices.
Indices can be either:
Indices can be either:
A single instance or integer.
A list of instances or integers.
Expand All @@ -211,6 +228,7 @@ def __getitem__(self, inds) -> np.ndarray:

query_ind = self.__getindices__(query_inst, self.query_instances)
ref_ind = self.__getindices__(ref_inst, self.ref_instances)

try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()
except IndexError as e:
Expand All @@ -229,18 +247,21 @@ def __getindices__(
Args:
instance: The instance(s) to be retrieved
Can either be a single int/instance or a list of int/instances/
Can either be a single int/instance or a list of int/instances
instance_lookup: A list of Instances to be used to retrieve indices
Returns:
A np array of indices.
"""
if isinstance(instance, Instance):
ind = np.array([instance_lookup.index(instance)])

elif instance is None:
ind = np.arange(len(instance_lookup))

elif np.isscalar(instance):
ind = np.array([instance])

else:
instances = instance
if not [isinstance(inst, (Instance, int)) for inst in instance]:
Expand Down Expand Up @@ -272,7 +293,7 @@ def get_tracks(
Returns:
A dictionary of track_id:instances
"""
if "pred" in label.lower():
if label == "pred":
traj_ids = set([instance.pred_track_id.item() for instance in instances])
traj = {
track_id: [
Expand All @@ -282,7 +303,8 @@ def get_tracks(
]
for track_id in traj_ids
}
elif "gt" in label.lower():

elif label == "gt":
traj_ids = set(
[instance.gt_track_id.item() for instance in self.ref_instances]
)
Expand All @@ -294,9 +316,8 @@ def get_tracks(
]
for track_id in traj_ids
}

else:
raise ValueError(
f"Can only group tracks by `pred` or `gt` {label.lower()} was requested!"
)
raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")

return traj
6 changes: 4 additions & 2 deletions biogtr/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_model(self) -> GlobalTrackingTransformer:

if ckpt_path is not None and len(ckpt_path) > 0:
return GTRRunner.load_from_checkpoint(ckpt_path).model

return GlobalTrackingTransformer(**model_params)

def get_tracker_cfg(self) -> dict:
Expand All @@ -104,13 +104,15 @@ def get_gtr_runner(self):
scheduler_params = self.cfg.scheduler
loss_params = self.cfg.loss
gtr_runner_params = self.cfg.runner

model_params = self.cfg.model

ckpt_path = model_params.pop("ckpt_path", None)

if ckpt_path is not None and ckpt_path != "":

model = GTRRunner.load_from_checkpoint(
self.cfg.model.ckpt_path,
ckpt_path,
tracker_cfg=tracker_params,
train_metrics=self.cfg.runner.metrics.train,
val_metrics=self.cfg.runner.metrics.val,
Expand Down
60 changes: 34 additions & 26 deletions biogtr/models/global_tracking_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,40 +84,48 @@ def forward(
"""Execute forward pass of GTR Model to get asso matrix.
Args:
frames: List of Frames from chunk containing crops of objects + gt label info
query_frame: Frame index used as query for self attention. Only used in sliding inference where query frame is the last frame in the window.
ref_instances: List of instances from chunk containing crops of objects + gt label info
query_instances: list of instances used as query in decoder.
Returns:
An N_T x N association matrix
"""
# Extract feature representations with pre-trained encoder.
if any(
[
(not instance.has_features()) and instance.has_crop()
for instance in ref_instances
]
):
ref_crops = torch.concat(
[instance.crop for instance in ref_instances], axis=0
)
ref_z = self.visual_encoder(ref_crops)
for i, z_i in enumerate(ref_z):
ref_instances[i].features = z_i
self.extract_features(ref_instances)

if query_instances:
if any(
[
(not instance.has_features()) and instance.has_crop()
for instance in query_instances
]
):
query_crops = torch.concat(
[instance.crop for instance in query_instances], axis=0
)
query_z = self.visual_encoder(query_crops)
for i, z_i in enumerate(query_z):
query_instances[i].features = z_i
self.extract_features(query_instances)

asso_preds = self.transformer(ref_instances, query_instances)

return asso_preds

def extract_features(
self, instances: list["Instance"], force_recompute: bool = False
) -> None:
"""Extract features from instances using visual encoder backbone.
Args:
instances: A list of instances to compute features for
force_recompute: indicate whether to compute features for all instances regardless of if they have instances
"""
if not force_recompute:
instances_to_compute = [
instance
for instance in instances
if instance.has_crop() and not instance.has_features()
]
else:
instances_to_compute = instances

if len(instances_to_compute) == 0:
return
elif len(instances_to_compute) == 1: # handle batch norm error when B=1
instances_to_compute = instances

crops = torch.concatenate([instance.crop for instance in instances_to_compute])

features = self.visual_encoder(crops)

for i, z_i in enumerate(features):
instances_to_compute[i].features = z_i
Loading

0 comments on commit 5d021d4

Please sign in to comment.