Remove kwargs and mutable defaults #48

Merged
merged 28 commits into from
Jun 3, 2024

Commits
8ddc557
add back-reference to frame as attribute for instance
aaprasad May 15, 2024
7887d56
separate get_boxes_times into two functions and use `Instances` as input
aaprasad May 15, 2024
9bab7bc
use instances as input into model instead of frames
aaprasad May 15, 2024
82293d2
create io module, move config, visualize there. abstract `Frame` and …
aaprasad May 16, 2024
b4049b8
refactor `Frame` and `Instance` initialization to use `attrs` instead…
aaprasad May 16, 2024
2b0bc55
add doc strings, fix small bugs
aaprasad May 16, 2024
ef26012
Implement AssociationMatrix class for handling model output
aaprasad May 16, 2024
94a0e61
create io module, move config, visualize there. abstract `Frame` and …
aaprasad May 16, 2024
c4bc0fb
refactor `Frame` and `Instance` initialization to use `attrs` instead…
aaprasad May 16, 2024
42f8a8c
add doc strings, fix small bugs
aaprasad May 16, 2024
b5f39b4
Implement AssociationMatrix class for handling model output
aaprasad May 16, 2024
ccd523a
Merge remote-tracking branch 'origin/aadi/refactor-data-structures' i…
aaprasad May 16, 2024
0f535af
fix overwrites from merge
aaprasad May 17, 2024
56e038a
store model outputs in association matrix
aaprasad May 17, 2024
8a71d6d
add track object for storing tracklets
aaprasad May 17, 2024
766820b
add reduction function to association matrix
aaprasad May 20, 2024
c0ceac1
add doc_strings
aaprasad May 20, 2024
92095e0
fix tests, docstrings
aaprasad May 20, 2024
a6a6ace
add spatial/temporal embeddings as attribute to `Instance`
aaprasad May 20, 2024
56e0555
fix typo
aaprasad May 20, 2024
80400fc
add `from_slp` converters
aaprasad May 21, 2024
7ff22e7
fix docstrings
aaprasad May 21, 2024
89007a9
store embeddings in Instance object instead of returning
aaprasad May 21, 2024
cbf915a
only keep visualize in io
aaprasad May 21, 2024
b3c5661
remove mutable types from default arguments. Don't use kwargs unless …
aaprasad May 21, 2024
adb3715
handle edge case where ckpt_path is not in config
aaprasad May 21, 2024
5d021d4
Merge branch 'main' into aadi/remove-kwargs-mutable-defaults
aaprasad Jun 3, 2024
79de38a
fix errors from merge conflict resolution
aaprasad Jun 3, 2024
3 changes: 1 addition & 2 deletions biogtr/inference/tracker.py
@@ -128,8 +128,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame]

Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: A list of Frames (See `biogtr.io.data_structures.Frame` for more info).

frames: A list of Frames (See `biogtr.io.Frame` for more info).

Returns:
Frames: A list of Frames populated with pred_track_ids and asso_matrices
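The docstring update reflects the new `biogtr.io` package: data structures are now referenced through the io namespace rather than the old `data_structures` module. A minimal sketch of the corresponding import for callers, assuming `Frame` and `Instance` are re-exported from `biogtr.io` (implied by the docstring change, not shown in this diff):

```python
# Sketch only: assumes biogtr.io re-exports the data structures in its __init__.py
# (the removed docstring line referenced biogtr.io.data_structures.Frame instead).
from biogtr.io import Frame, Instance

# sliding_inference keeps its signature; only the docstring's import path changed:
#   tracker.sliding_inference(model: GlobalTrackingTransformer, frames: list[Frame])
```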
5 changes: 1 addition & 4 deletions biogtr/io/association_matrix.py
@@ -173,7 +173,7 @@ def reduce(
Either "instance" (remains unchanged), or "track" (n_cols=n_traj)
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
reduce_method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.

Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
@@ -199,7 +199,6 @@

reduced_matrix = []
for row_track, row_instances in row_tracks.items():

for col_track, col_instances in col_tracks.items():
asso_matrix = self[row_instances, col_instances]

@@ -208,7 +207,6 @@

if row_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=0)

reduced_matrix.append(asso_matrix)

reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T
@@ -234,7 +232,6 @@ def __getitem__(self, inds) -> np.ndarray:

try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()

except IndexError as e:
print(f"Query_insts: {type(query_inst)}")
print(f"Query_inds: {query_ind}")
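The renamed `reduce_method` argument is just a numpy-style reduction: any callable that accepts a matrix and an `axis` keyword (e.g. `np.sum`, `np.mean`, `np.max`). A standalone sketch of that contract, pooling an instance-by-instance score matrix into tracks; this is an illustration, not the class's implementation:

```python
import numpy as np


def pool_tracks(
    asso: np.ndarray,
    row_groups: list[list[int]],
    col_groups: list[list[int]],
    reduce_method=np.max,
) -> np.ndarray:
    """Collapse an instance x instance score matrix into a track x track matrix.

    reduce_method can be any callable that accepts a numpy array and an `axis`
    keyword, matching the contract documented for AssociationMatrix.reduce.
    """
    out = np.zeros((len(row_groups), len(col_groups)))
    for i, rows in enumerate(row_groups):
        for j, cols in enumerate(col_groups):
            # Slice the block of scores between this row group and column group.
            block = asso[np.asarray(rows)[:, None], np.asarray(cols)]
            # Reduce over both dimensions of the block to a single score.
            out[i, j] = reduce_method(reduce_method(block, axis=0))
    return out


scores = np.arange(16, dtype=float).reshape(4, 4)
# Two row (e.g. gt) groups and two column (e.g. pred) groups of instances.
print(pool_tracks(scores, [[0, 1], [2, 3]], [[0], [1, 2, 3]], reduce_method=np.mean))
```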
15 changes: 12 additions & 3 deletions biogtr/io/config.py
@@ -78,6 +78,11 @@ def get_model(self) -> GlobalTrackingTransformer:
A global tracking transformer with parameters indicated by cfg
"""
model_params = self.cfg.model
ckpt_path = model_params.pop("ckpt_path", None)

if ckpt_path is not None and len(ckpt_path) > 0:
return GTRRunner.load_from_checkpoint(ckpt_path).model

return GlobalTrackingTransformer(**model_params)

def get_tracker_cfg(self) -> dict:
@@ -100,17 +105,21 @@ def get_gtr_runner(self):
loss_params = self.cfg.loss
gtr_runner_params = self.cfg.runner

if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "":
model_params = self.cfg.model

ckpt_path = model_params.pop("ckpt_path", None)

if ckpt_path is not None and ckpt_path != "":

model = GTRRunner.load_from_checkpoint(
self.cfg.model.ckpt_path,
ckpt_path,
tracker_cfg=tracker_params,
train_metrics=self.cfg.runner.metrics.train,
val_metrics=self.cfg.runner.metrics.val,
test_metrics=self.cfg.runner.metrics.test,
)

else:
model_params = self.cfg.model
model = GTRRunner(
model_params,
tracker_params,
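With `**kwargs` removed from the model constructor (next file), config-only keys such as `ckpt_path` must be popped off the model config before it is `**`-unpacked. A simplified, self-contained sketch of that pattern; the class and loader are stand-ins rather than biogtr's API, and the sketch copies the dict instead of popping the config in place as the PR does:

```python
from typing import Any


class Model:
    """Stand-in for GlobalTrackingTransformer in this sketch."""

    def __init__(self, d_model: int = 1024, nhead: int = 8):
        self.d_model, self.nhead = d_model, nhead


def load_from_checkpoint(path: str) -> Model:
    """Stand-in for GTRRunner.load_from_checkpoint(path).model."""
    return Model()


def build_model(model_cfg: dict[str, Any]) -> Model:
    cfg = dict(model_cfg)                   # copy so the caller's dict is untouched
    ckpt_path = cfg.pop("ckpt_path", None)  # config-only key, not a constructor kwarg
    if ckpt_path:                           # treats None and "" as "no checkpoint"
        return load_from_checkpoint(ckpt_path)
    return Model(**cfg)                     # safe now: only real constructor kwargs remain


print(build_model({"d_model": 512, "ckpt_path": ""}).d_model)  # -> 512
```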
1 change: 0 additions & 1 deletion biogtr/models/global_tracking_transformer.py
@@ -27,7 +27,6 @@ def __init__(
embedding_meta: dict = None,
return_embedding: bool = False,
decoder_self_attn: bool = False,
**kwargs,
):
"""Initialize GTR.

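One practical effect of dropping the `**kwargs` catch-all: a misspelled or stale config key now raises a `TypeError` instead of being silently ignored. A toy illustration (these classes are not the library's):

```python
class Permissive:
    def __init__(self, nhead: int = 8, **kwargs):
        self.nhead = nhead          # unknown keys vanish into kwargs


class Strict:
    def __init__(self, nhead: int = 8):
        self.nhead = nhead


Permissive(n_head=16)               # typo silently ignored; nhead stays 8
try:
    Strict(n_head=16)               # same typo now fails loudly
except TypeError as err:
    print(err)                      # ...got an unexpected keyword argument 'n_head'
```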
43 changes: 28 additions & 15 deletions biogtr/models/gtr_runner.py
@@ -18,23 +18,26 @@ class GTRRunner(LightningModule):
Used for training, validation and inference.
"""

DEFAULT_METRICS = {
"train": [],
"val": ["num_switches"],
"test": ["num_switches"],
}
DEFAULT_TRACKING = {
"train": False,
"val": True,
"test": True,
}

def __init__(
self,
model_cfg: dict = {},
tracker_cfg: dict = {},
loss_cfg: dict = {},
model_cfg: dict = None,
tracker_cfg: dict = None,
loss_cfg: dict = None,
optimizer_cfg: dict = None,
scheduler_cfg: dict = None,
metrics: dict[str, list[str]] = {
"train": [],
"val": ["num_switches"],
"test": ["num_switches"],
},
persistent_tracking: dict[str, bool] = {
"train": False,
"val": True,
"test": True,
},
metrics: dict[str, list[str]] = None,
persistent_tracking: dict[str, bool] = None,
):
"""Initialize a lightning module for GTR.

@@ -51,15 +54,24 @@ def __init__(self):
super().__init__()
self.save_hyperparameters()

model_cfg = model_cfg if model_cfg else {}
loss_cfg = loss_cfg if loss_cfg else {}
tracker_cfg = tracker_cfg if tracker_cfg else {}

_ = model_cfg.pop("ckpt_path", None)
self.model = GlobalTrackingTransformer(**model_cfg)
self.loss = AssoLoss(**loss_cfg)
self.tracker = Tracker(**tracker_cfg)

self.optimizer_cfg = optimizer_cfg
self.scheduler_cfg = scheduler_cfg

self.metrics = metrics
self.persistent_tracking = persistent_tracking
self.metrics = metrics if metrics is not None else self.DEFAULT_METRICS
self.persistent_tracking = (
persistent_tracking
if persistent_tracking is not None
else self.DEFAULT_TRACKING
)

def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
@@ -159,6 +171,7 @@ def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]:
"""
try:
instances = [instance for frame in frames for instance in frame.instances]

if len(instances) == 0:
return None

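The `GTRRunner` changes above replace mutable default arguments (shared `{}` dicts) with `None` sentinels and class-level `DEFAULT_*` constants. A toy illustration of the pitfall and the fix; unlike the PR, the sketch also takes a shallow copy of the class default so per-instance mutation stays isolated:

```python
class Buggy:
    """Mutable default: one dict is created at definition time and shared by all calls."""

    def __init__(self, metrics: dict = {}):
        self.metrics = metrics


a, b = Buggy(), Buggy()
a.metrics["train"] = ["num_switches"]
print(b.metrics)  # {'train': ['num_switches']} -- b silently sees a's mutation


class Fixed:
    """None sentinel plus class-level default, as in GTRRunner (with an added shallow copy)."""

    DEFAULT_METRICS = {"train": [], "val": ["num_switches"], "test": ["num_switches"]}

    def __init__(self, metrics: dict = None):
        self.metrics = metrics if metrics is not None else dict(self.DEFAULT_METRICS)


c, d = Fixed(), Fixed()
c.metrics["train"] = ["num_switches"]
print(d.metrics["train"])  # [] -- each instance now gets its own dict
```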