Skip to content

Commit

Permalink
add doc_strings
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed May 20, 2024
1 parent 766820b commit c0ceac1
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions biogtr/io/association_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def reduce(
by: tuple[str] = (None, "pred"),
reduce_method: callable = np.sum,
) -> pd.DataFrame:
"""Reduce association matrix rows/columns to inst/traj x traj/inst
"""Reduce association matrix rows/columns to inst/traj x traj/inst.
Args:
to:
by:
method:
to: A tuple indicating how to reduce rows/columns. Either inst (remains unchanged), or traj (reduces matrix from n_query/n_ref to n_traj
by: A tuple indicating how to group rows/columns when aggregating. Either "pred" or "gt"
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
Expand Down Expand Up @@ -260,7 +260,18 @@ def __getindices__(

return ind

def get_tracks(self, instances, label="pred"):
def get_tracks(
self, instances: list["Instance"], label: str = "pred"
) -> dict[int, list["Instance"]]:
"""Group instances by track.
Args:
instances: The list of instances to group
label: the track id type to group by. Either `pred` or `gt`.
Returns:
A dictionary of track_id:instances
"""
if "pred" in label.lower():
traj_ids = set([instance.pred_track_id.item() for instance in instances])
traj = {
Expand Down

0 comments on commit c0ceac1

Please sign in to comment.