diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 213914c2..5f6fce84 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -128,8 +128,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frames: A list of Frames (See `biogtr.io.data_structures.Frame` for more info). - + frames: A list of Frames (See `biogtr.io.Frame` for more info). Returns: Frames: A list of Frames populated with pred_track_ids and asso_matrices diff --git a/biogtr/io/association_matrix.py b/biogtr/io/association_matrix.py index a5f4a3a7..1447249c 100644 --- a/biogtr/io/association_matrix.py +++ b/biogtr/io/association_matrix.py @@ -173,7 +173,7 @@ def reduce( Either "instance" (remains unchanged), or "track" (n_cols=n_traj) row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt". col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt". - method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing. + reduce_method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing. Returns: The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe. @@ -199,7 +199,6 @@ def reduce( reduced_matrix = [] for row_track, row_instances in row_tracks.items(): - for col_track, col_instances in col_tracks.items(): asso_matrix = self[row_instances, col_instances] @@ -208,7 +207,6 @@ def reduce( if row_dims == "track": asso_matrix = reduce_method(asso_matrix, axis=0) - reduced_matrix.append(asso_matrix) reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T @@ -234,7 +232,6 @@ def __getitem__(self, inds) -> np.ndarray: try: return self.numpy()[query_ind[:, None], ref_ind].squeeze() - except IndexError as e: print(f"Query_insts: {type(query_inst)}") print(f"Query_inds: {query_ind}") diff --git a/biogtr/io/config.py b/biogtr/io/config.py index f30d3128..2d83be38 100644 --- a/biogtr/io/config.py +++ b/biogtr/io/config.py @@ -78,6 +78,11 @@ def get_model(self) -> GlobalTrackingTransformer: A global tracking transformer with parameters indicated by cfg """ model_params = self.cfg.model + ckpt_path = model_params.pop("ckpt_path", None) + + if ckpt_path is not None and len(ckpt_path) > 0: + return GTRRunner.load_from_checkpoint(ckpt_path).model + return GlobalTrackingTransformer(**model_params) def get_tracker_cfg(self) -> dict: @@ -100,9 +105,14 @@ def get_gtr_runner(self): loss_params = self.cfg.loss gtr_runner_params = self.cfg.runner - if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "": + model_params = self.cfg.model + + ckpt_path = model_params.pop("ckpt_path", None) + + if ckpt_path is not None and ckpt_path != "": + model = GTRRunner.load_from_checkpoint( - self.cfg.model.ckpt_path, + ckpt_path, tracker_cfg=tracker_params, train_metrics=self.cfg.runner.metrics.train, val_metrics=self.cfg.runner.metrics.val, @@ -110,7 +120,6 @@ def get_gtr_runner(self): ) else: - model_params = self.cfg.model model = GTRRunner( model_params, tracker_params, diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index b556ee1e..4a74d6b5 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -27,7 +27,6 @@ def __init__( embedding_meta: dict = None, return_embedding: bool = False, decoder_self_attn: bool = False, - **kwargs, ): """Initialize GTR. diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index b37ed718..fc9afeb4 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -18,23 +18,26 @@ class GTRRunner(LightningModule): Used for training, validation and inference. """ + DEFAULT_METRICS = { + "train": [], + "val": ["num_switches"], + "test": ["num_switches"], + } + DEFAULT_TRACKING = { + "train": False, + "val": True, + "test": True, + } + def __init__( self, - model_cfg: dict = {}, - tracker_cfg: dict = {}, - loss_cfg: dict = {}, + model_cfg: dict = None, + tracker_cfg: dict = None, + loss_cfg: dict = None, optimizer_cfg: dict = None, scheduler_cfg: dict = None, - metrics: dict[str, list[str]] = { - "train": [], - "val": ["num_switches"], - "test": ["num_switches"], - }, - persistent_tracking: dict[str, bool] = { - "train": False, - "val": True, - "test": True, - }, + metrics: dict[str, list[str]] = None, + persistent_tracking: dict[str, bool] = None, ): """Initialize a lightning module for GTR. @@ -51,6 +54,11 @@ def __init__( super().__init__() self.save_hyperparameters() + model_cfg = model_cfg if model_cfg else {} + loss_cfg = loss_cfg if loss_cfg else {} + tracker_cfg = tracker_cfg if tracker_cfg else {} + + _ = model_cfg.pop("ckpt_path", None) self.model = GlobalTrackingTransformer(**model_cfg) self.loss = AssoLoss(**loss_cfg) self.tracker = Tracker(**tracker_cfg) @@ -58,8 +66,12 @@ def __init__( self.optimizer_cfg = optimizer_cfg self.scheduler_cfg = scheduler_cfg - self.metrics = metrics - self.persistent_tracking = persistent_tracking + self.metrics = metrics if metrics is not None else self.DEFAULT_METRICS + self.persistent_tracking = ( + persistent_tracking + if persistent_tracking is not None + else self.DEFAULT_TRACKING + ) def forward( self, ref_instances: list[Instance], query_instances: list[Instance] = None @@ -159,6 +171,7 @@ def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]: """ try: instances = [instance for frame in frames for instance in frame.instances] + if len(instances) == 0: return None