Skip to content

Commit

Permalink
clarify reduce args by separating row/columns and using better names.
Browse files Browse the repository at this point in the history
handle string parsing better
  • Loading branch information
aaprasad committed May 29, 2024
1 parent 430dff6 commit 74c4889
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 52 deletions.
154 changes: 104 additions & 50 deletions biogtr/io/association_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ class AssociationMatrix:
ref_instances: list[Instance] = attrs.field()
query_instances: list[Instance] = attrs.field()

AVAILABLE_REDUCTIONS = attrs.field(
init=False,
default={
"instance": ["inst", "instance"],
"track": ["track", "traj", "trajectory"],
None: ["", None],
},
)
AVAILABLE_GROUPINGS = attrs.field(init=False, default=["pred", "gt", None])

@ref_instances.validator
def _check_ref_instances(self, attribute, value):
"""Check to ensure that the number of association matrix columns and reference instances match.
Expand Down Expand Up @@ -56,7 +66,7 @@ def _check_query_instances(self, attribute, value):
raise ValueError(
(
"Query instances must equal number of rows in Association matrix"
f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
f"Found {len(value)} query instances but {self.matrix.shape[0]} rows."
)
)

Expand All @@ -83,16 +93,18 @@ def numpy(self) -> np.ndarray:
return self.matrix

def to_dataframe(
self, row_label: str = "gt", col_label: str = "gt"
self, row_labels: str = "gt", col_labels: str = "gt"
) -> pd.DataFrame:
"""Convert the association matrix to a pandas DataFrame.
Args:
row_label: How to label the rows(queries).
If `gt` then label by gt track id.
If `pred` then label by pred track id.
row_labels: How to label the rows(queries).
If list, then must match # of rows/queries
If `"gt"` then label by gt track id.
If `"pred"` then label by pred track id.
Otherwise label by the query_instance indices
col_label: How to label the columns(references).
col_labels: How to label the columns(references).
If list, then must match # of columns/refs
If `gt` then label by gt track id.
If `pred` then label by pred track id.
Otherwise label by the ref_instance indices
Expand All @@ -102,61 +114,97 @@ def to_dataframe(
"""
matrix = self.numpy()

if row_label.lower() == "gt":
row_inds = [
instance.gt_track_id.item() for instance in self.query_instances
]

elif row_label.lower() == "pred":
row_inds = [
instance.pred_track_id.item() for instance in self.query_instances
]

if not isinstance(row_labels, str):
if len(row_labels) == len(self.query_instances):
row_inds = row_labels
else:
raise ValueError(
(
f"Mismatched # of rows and labels!",
f"Found {len(row_labels)} with {len(self.query_instances)} rows",
)
)
else:
row_inds = np.arange(len(self.query_instances))
if row_labels.lower() == "gt":
row_inds = [
instance.gt_track_id.item() for instance in self.query_instances
]

if col_label.lower() == "gt":
col_inds = [instance.gt_track_id.item() for instance in self.ref_instances]
elif row_labels.lower() == "pred":
row_inds = [
instance.pred_track_id.item() for instance in self.query_instances
]

elif col_label.lower() == "pred":
col_inds = [
instance.pred_track_id.item() for instance in self.ref_instances
]
else:
row_inds = np.arange(len(self.query_instances))

if not isinstance(col_labels, str):
if len(col_labels) == len(self.ref_instances):
col_inds = col_labels
else:
raise ValueError(
(
f"Mismatched # of columns and labels!",
f"Found {len(col_labels)} with {len(self.ref_instances)} columns",
)
)
else:
col_inds = np.arange(len(self.ref_instances))
if col_labels.lower() == "gt":
col_inds = [
instance.gt_track_id.item() for instance in self.ref_instances
]

elif col_labels.lower() == "pred":
col_inds = [
instance.pred_track_id.item() for instance in self.ref_instances
]

else:
col_inds = np.arange(len(self.ref_instances))

asso_df = pd.DataFrame(matrix, index=row_inds, columns=col_inds)

return asso_df

def reduce(
self,
to: tuple[str] = ("inst", "traj"),
by: tuple[str] = (None, "pred"),
row_dims: str = "inst",
col_dims: str = "traj",
row_grouping: str = None,
col_grouping: str = "pred",
reduce_method: callable = np.sum,
) -> pd.DataFrame:
"""Reduce association matrix rows/columns to inst/traj x traj/inst.
"""Aggregate the association matrix by specified dimensions and grouping.
Args:
to: A tuple indicating how to reduce rows/columns. Either inst (remains unchanged), or traj (reduces matrix from n_query/n_ref to n_traj
by: A tuple indicating how to group rows/columns when aggregating. Either "pred" or "gt"
row_dims: A str indicating how to what dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_rows=n_traj)
col_dims: A str indicating how to dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_cols=n_traj)
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
"""
row_to, col_to = to
row_by, col_by = by

if row_to not in ("inst", "traj") or col_to not in ("inst", "traj"):
if (
row_dims is not None
and row_dims.lower()
not in [key for keys in self.AVAILABLE_REDUCTIONS.values() for key in keys]
) or (
col_dims is not None
and col_dims.lower()
not in [key for keys in self.AVAILABLE_REDUCTIONS.values() for key in keys]
):
raise ValueError(
f"Can only reduce to inst/traj x inst/traj but ({row_to} x {col_to}) was requested!"
f"Can only reduce to inst/traj x inst/traj but ({row_dims} x {col_dims}) was requested!"
)

if row_by not in ("pred", "gt", None) or col_by not in ("pred", "gt", None):
if (
row_grouping not in self.AVAILABLE_GROUPINGS
or col_grouping not in self.AVAILABLE_GROUPINGS
):
raise ValueError(
f"Can aggregate by [gt, pred, None] but {row_by} and {col_by} was requested!"
f"Can aggregate by [gt, pred, None] but {row_grouping} and {col_grouping} was requested!"
)

n_rows = len(self.query_instances)
Expand All @@ -168,27 +216,28 @@ def reduce(
col_inds = [i for i in range(len(self.ref_instances))]
row_inds = [i for i in range(len(self.query_instances))]

if "tra" in col_to:
col_tracks = self.get_tracks(self.ref_instances, col_by)
if col_dims.lower() in self.AVAILABLE_REDUCTIONS["track"]:
col_tracks = self.get_tracks(self.ref_instances, col_grouping)
col_inds = list(col_tracks.keys())
n_cols = len(col_inds)
if "tra" in row_to:
row_tracks = self.get_tracks(self.query_instances, row_by)

if row_dims.lower() in self.AVAILABLE_REDUCTIONS["track"]:
row_tracks = self.get_tracks(self.query_instances, row_grouping)
row_inds = list(row_tracks.keys())
n_rows = len(row_inds)

reduced_matrix = []
for row_track, row_instances in row_tracks.items():
# print(row_instances)

for col_track, col_instances in col_tracks.items():
asso_matrix = self[row_instances, col_instances]
# print(col_instances)
# print(asso_matrix)
if "tra" in col_to:

if col_dims.lower() in self.AVAILABLE_REDUCTIONS["track"]:
asso_matrix = reduce_method(asso_matrix, axis=1)
if "tra" in row_to:

if row_dims.lower() in self.AVAILABLE_REDUCTIONS["track"]:
asso_matrix = reduce_method(asso_matrix, axis=0)

reduced_matrix.append(asso_matrix)

reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T
Expand All @@ -211,8 +260,10 @@ def __getitem__(self, inds) -> np.ndarray:

query_ind = self.__getindices__(query_inst, self.query_instances)
ref_ind = self.__getindices__(ref_inst, self.ref_instances)

try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()

except IndexError as e:
print(f"Query_insts: {type(query_inst)}")
print(f"Query_inds: {query_ind}")
Expand All @@ -237,10 +288,13 @@ def __getindices__(
"""
if isinstance(instance, Instance):
ind = np.array([instance_lookup.index(instance)])

elif instance is None:
ind = np.arange(len(instance_lookup))

elif np.isscalar(instance):
ind = np.array([instance])

else:
instances = instance
if not [isinstance(inst, (Instance, int)) for inst in instance]:
Expand Down Expand Up @@ -272,7 +326,7 @@ def get_tracks(
Returns:
A dictionary of track_id:instances
"""
if "pred" in label.lower():
if label.lower() == "pred":
traj_ids = set([instance.pred_track_id.item() for instance in instances])
traj = {
track_id: [
Expand All @@ -282,7 +336,8 @@ def get_tracks(
]
for track_id in traj_ids
}
elif "gt" in label.lower():

elif label.lower() == "gt":
traj_ids = set(
[instance.gt_track_id.item() for instance in self.ref_instances]
)
Expand All @@ -294,9 +349,8 @@ def get_tracks(
]
for track_id in traj_ids
}

else:
raise ValueError(
f"Unsupported label '{label}'. Expected 'pred' or 'gt'."
)
raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")

return traj
4 changes: 2 additions & 2 deletions tests/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def test_association_matrix():

traj_score = pd.concat(
[
query_matrix.to_dataframe(row_label="inst").drop(1, axis=1).sum(1),
query_matrix.to_dataframe(row_label="inst").drop(0, axis=1).sum(1),
query_matrix.to_dataframe(row_labels="inst").drop(1, axis=1).sum(1),
query_matrix.to_dataframe(row_labels="inst").drop(0, axis=1).sum(1),
],
axis=1,
)
Expand Down

0 comments on commit 74c4889

Please sign in to comment.