Remove kwargs and mutable defaults (#48)
aaprasad authored Jun 3, 2024
1 parent 98106b7 commit f8a33df
Showing 5 changed files with 42 additions and 25 deletions.
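Background on the "mutable defaults" half of this change: Python evaluates default argument values once, at function definition time, so a dict or list default is shared by every call (and by every instance when used in __init__). A minimal sketch of the pitfall and of the None-sentinel pattern this commit switches to (function names here are made up for illustration):

def register_metric_buggy(name, metrics={}):  # one dict object shared across all calls
    metrics[name] = []
    return metrics

def register_metric_safe(name, metrics=None):  # None sentinel: build a fresh dict per call
    metrics = metrics if metrics is not None else {}
    metrics[name] = []
    return metrics

register_metric_buggy("train")
print(register_metric_buggy("val"))  # {'train': [], 'val': []}  <- state leaked between calls
register_metric_safe("train")
print(register_metric_safe("val"))   # {'val': []}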
3 changes: 1 addition & 2 deletions biogtr/inference/tracker.py
@@ -128,8 +128,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: A list of Frames (See `biogtr.io.data_structures.Frame` for more info).
frames: A list of Frames (See `biogtr.io.Frame` for more info).
Returns:
Frames: A list of Frames populated with pred_track_ids and asso_matrices
5 changes: 1 addition & 4 deletions biogtr/io/association_matrix.py
@@ -173,7 +173,7 @@ def reduce(
Either "instance" (remains unchanged), or "track" (n_cols=n_traj)
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
reduce_method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
@@ -199,7 +199,6 @@ def reduce(

reduced_matrix = []
for row_track, row_instances in row_tracks.items():

for col_track, col_instances in col_tracks.items():
asso_matrix = self[row_instances, col_instances]

@@ -208,7 +207,6 @@ def reduce(

if row_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=0)

reduced_matrix.append(asso_matrix)

reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T
@@ -234,7 +232,6 @@ def __getitem__(self, inds) -> np.ndarray:

try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()

except IndexError as e:
print(f"Query_insts: {type(query_inst)}")
print(f"Query_inds: {query_ind}")
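Aside on the renamed `reduce_method` parameter above: conceptually, `reduce` groups instance rows/columns by track and applies the callable to each block. A rough standalone illustration with numpy (not the library's actual signature):

import numpy as np

def reduce_by_tracks(asso, row_tracks, col_tracks, reduce_method=np.sum):
    # asso: (n_query_instances, n_ref_instances); *_tracks map track id -> list of instance indices
    reduced = np.zeros((len(row_tracks), len(col_tracks)))
    for i, row_idx in enumerate(row_tracks.values()):
        for j, col_idx in enumerate(col_tracks.values()):
            block = asso[np.ix_(row_idx, col_idx)]  # sub-matrix for this (row track, col track) pair
            reduced[i, j] = reduce_method(block)    # e.g. np.sum, np.max, np.mean
    return reduced

# e.g. reduce_by_tracks(scores, {"a": [0, 1], "b": [2]}, {"x": [0], "y": [1, 2]}, np.max)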
15 changes: 12 additions & 3 deletions biogtr/io/config.py
@@ -78,6 +78,11 @@ def get_model(self) -> GlobalTrackingTransformer:
A global tracking transformer with parameters indicated by cfg
"""
model_params = self.cfg.model
ckpt_path = model_params.pop("ckpt_path", None)

if ckpt_path is not None and len(ckpt_path) > 0:
return GTRRunner.load_from_checkpoint(ckpt_path).model

return GlobalTrackingTransformer(**model_params)

def get_tracker_cfg(self) -> dict:
@@ -100,17 +105,21 @@ def get_gtr_runner(self):
loss_params = self.cfg.loss
gtr_runner_params = self.cfg.runner

if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "":
model_params = self.cfg.model

ckpt_path = model_params.pop("ckpt_path", None)

if ckpt_path is not None and ckpt_path != "":

model = GTRRunner.load_from_checkpoint(
self.cfg.model.ckpt_path,
ckpt_path,
tracker_cfg=tracker_params,
train_metrics=self.cfg.runner.metrics.train,
val_metrics=self.cfg.runner.metrics.val,
test_metrics=self.cfg.runner.metrics.test,
)

else:
model_params = self.cfg.model
model = GTRRunner(
model_params,
tracker_params,
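Note on the pattern above: because the model constructor no longer accepts **kwargs (next file), `ckpt_path` has to be removed from the model config before `**`-unpacking it. A condensed sketch of that flow (the defensive copy is an addition here, not part of the diff):

def build_model(model_cfg: dict):
    cfg = dict(model_cfg)                   # copy so the caller's config is left untouched
    ckpt_path = cfg.pop("ckpt_path", None)  # strip keys the constructor does not accept
    if ckpt_path:                           # non-empty string -> restore weights from checkpoint
        return GTRRunner.load_from_checkpoint(ckpt_path).model
    return GlobalTrackingTransformer(**cfg)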
1 change: 0 additions & 1 deletion biogtr/models/global_tracking_transformer.py
@@ -27,7 +27,6 @@ def __init__(
embedding_meta: dict = None,
return_embedding: bool = False,
decoder_self_attn: bool = False,
**kwargs,
):
"""Initialize GTR.
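Dropping **kwargs means unexpected config keys now fail loudly instead of being silently ignored, which is why the config code above pops `ckpt_path` first. A toy illustration (hypothetical class, not from the repo):

class StrictModel:
    def __init__(self, d_model: int = 1024):  # no **kwargs catch-all
        self.d_model = d_model

StrictModel(d_model=512)                      # fine
StrictModel(d_model=512, ckpt_path="x.ckpt")  # TypeError: unexpected keyword argument 'ckpt_path'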
43 changes: 28 additions & 15 deletions biogtr/models/gtr_runner.py
@@ -18,23 +18,26 @@ class GTRRunner(LightningModule):
Used for training, validation and inference.
"""

DEFAULT_METRICS = {
"train": [],
"val": ["num_switches"],
"test": ["num_switches"],
}
DEFAULT_TRACKING = {
"train": False,
"val": True,
"test": True,
}

def __init__(
self,
model_cfg: dict = {},
tracker_cfg: dict = {},
loss_cfg: dict = {},
model_cfg: dict = None,
tracker_cfg: dict = None,
loss_cfg: dict = None,
optimizer_cfg: dict = None,
scheduler_cfg: dict = None,
metrics: dict[str, list[str]] = {
"train": [],
"val": ["num_switches"],
"test": ["num_switches"],
},
persistent_tracking: dict[str, bool] = {
"train": False,
"val": True,
"test": True,
},
metrics: dict[str, list[str]] = None,
persistent_tracking: dict[str, bool] = None,
):
"""Initialize a lightning module for GTR.
@@ -51,15 +54,24 @@ def __init__(self):
super().__init__()
self.save_hyperparameters()

model_cfg = model_cfg if model_cfg else {}
loss_cfg = loss_cfg if loss_cfg else {}
tracker_cfg = tracker_cfg if tracker_cfg else {}

_ = model_cfg.pop("ckpt_path", None)
self.model = GlobalTrackingTransformer(**model_cfg)
self.loss = AssoLoss(**loss_cfg)
self.tracker = Tracker(**tracker_cfg)

self.optimizer_cfg = optimizer_cfg
self.scheduler_cfg = scheduler_cfg

self.metrics = metrics
self.persistent_tracking = persistent_tracking
self.metrics = metrics if metrics is not None else self.DEFAULT_METRICS
self.persistent_tracking = (
persistent_tracking
if persistent_tracking is not None
else self.DEFAULT_TRACKING
)

def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
@@ -159,6 +171,7 @@ def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]:
"""
try:
instances = [instance for frame in frames for instance in frame.instances]

if len(instances) == 0:
return None

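With the new None defaults, omitting the dict arguments falls back to the class-level constants defined above. A usage sketch (assumes empty configs are valid for the model, loss, and tracker constructors):

runner = GTRRunner()  # dict args default to None and are resolved inside __init__
print(runner.metrics)              # {'train': [], 'val': ['num_switches'], 'test': ['num_switches']}
print(runner.persistent_tracking)  # {'train': False, 'val': True, 'test': True}

runner = GTRRunner(metrics={"train": ["num_switches"]})  # an explicit dict still overrides the default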
