From dc3096759587f4defe106bd7df734dc78a783e58 Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Wed, 29 May 2024 14:58:14 -0700 Subject: [PATCH 1/7] Use instances as model input (#46) --- biogtr/data_structures.py | 30 ++++- biogtr/inference/tracker.py | 14 +- biogtr/models/global_tracking_transformer.py | 56 ++++++-- biogtr/models/gtr_runner.py | 74 +++++----- biogtr/models/model_utils.py | 67 ++++++--- biogtr/models/transformer.py | 135 ++++++++++--------- biogtr/training/losses.py | 6 +- tests/test_models.py | 20 ++- 8 files changed, 261 insertions(+), 141 deletions(-) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py index 7aa706b7..81b2f551 100644 --- a/biogtr/data_structures.py +++ b/biogtr/data_structures.py @@ -124,6 +124,8 @@ def __init__( self._device = device self.to(self._device) + self._frame = None + def __repr__(self) -> str: """Return string representation of the Instance.""" return ( @@ -421,6 +423,26 @@ def has_features(self) -> bool: else: return True + @property + def frame(self) -> "Frame": + """Get the frame the instance belongs to. + + Returns: + The back reference to the `Frame` that this `Instance` belongs to. + """ + return self._frame + + @frame.setter + def frame(self, frame: "Frame") -> None: + """Set the back reference to the `Frame` that this `Instance` belongs to. + + This field is set when instances are added to `Frame` object. + + Args: + frame: A `Frame` object containing the metadata for the frame that the instance belongs to + """ + self._frame = frame + @property def pose(self) -> dict[str, ArrayLike]: """Get the pose of the instance. @@ -580,9 +602,12 @@ def __init__( self._img_shape = img_shape else: self._img_shape = torch.tensor([img_shape]) + if instances is None: self.instances = [] else: + for instance in instances: + instance.frame = self self._instances = instances self._asso_output = asso_output @@ -612,7 +637,7 @@ def __repr__(self) -> str: f"img_shape={self._img_shape}, " f"num_detected={self.num_detected}, " f"asso_output={self._asso_output}, " - f"traj_score={self._traj_score}, " + f"traj_score={list(self._traj_score.keys())}, " f"matches={self._matches}, " f"instances={self._instances}, " f"device={self._device}" @@ -796,6 +821,9 @@ def instances(self, instances: List[Instance]) -> None: Args: instances: A list of Instances that appear in the frame. """ + for instance in instances: + instance.frame = self + self._instances = instances def has_instances(self) -> bool: diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 84fffb73..3e3405ba 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -141,7 +141,9 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame # W: width. for batch_idx, frame_to_track in enumerate(frames): - tracked_frames = self.track_queue.collate_tracks() + tracked_frames = self.track_queue.collate_tracks( + device=frame_to_track.frame_id.device + ) if self.verbose: warnings.warn( f"Current number of tracks is {self.track_queue.n_tracks}" @@ -229,8 +231,12 @@ def _run_global_tracker( # E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window. 
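        # A rough sketch of how the `Instance.frame` back reference added earlier in this
        # patch is consumed downstream: once the window is flattened into a plain list of
        # instances (see `all_instances` a few lines below), helpers such as
        # `model_utils.get_boxes` and `model_utils.get_times` can recover per-instance
        # frame metadata without being handed the frames themselves, roughly:
        #
        #     _, h, w = instance.frame.img_shape.flatten()
        #     frame_ids = torch.concat([inst.frame.frame_id for inst in all_instances])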
_ = model.eval() + query_frame = frames[query_ind] + query_instances = query_frame.instances + all_instances = [instance for frame in frames for instance in frame.instances] + if self.verbose: print(f"Frame {query_frame.frame_id.item()}") @@ -253,7 +259,7 @@ def _run_global_tracker( # (L=1, n_query, total_instances) with torch.no_grad(): - asso_output, embed = model(frames, query_frame=query_ind) + asso_output, embed = model(all_instances, query_instances) # if model.transformer.return_embedding: # query_frame.embeddings = embed TODO add embedding to Instance Object # if query_frame == 1: @@ -321,6 +327,7 @@ def _run_global_tracker( ] nonquery_inds = [i for i in range(total_instances) if i not in query_inds] + # instead should we do model(nonquery_instances, query_instances)? asso_nonquery = asso_output[:, nonquery_inds] # (n_query, n_nonquery) asso_nonquery_df = pd.DataFrame( @@ -332,10 +339,9 @@ def _run_global_tracker( query_frame.add_traj_score("asso_nonquery", asso_nonquery_df) - pred_boxes, _ = model_utils.get_boxes_times(frames) + pred_boxes = model_utils.get_boxes(all_instances) query_boxes = pred_boxes[query_inds] # n_k x 4 nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4 - # TODO: Insert postprocessing. unique_ids = torch.unique(instance_ids) # (n_nonquery,) diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index c02aec5b..59206c17 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -2,13 +2,13 @@ from biogtr.models.transformer import Transformer from biogtr.models.visual_encoder import VisualEncoder -from biogtr.data_structures import Frame -from torch import nn +from biogtr.data_structures import Instance +import torch # todo: do we want to handle params with configs already here? -class GlobalTrackingTransformer(nn.Module): +class GlobalTrackingTransformer(torch.nn.Module): """Modular GTR model composed of visual encoder + transformer used for tracking.""" def __init__( @@ -79,26 +79,54 @@ def __init__( decoder_self_attn=decoder_self_attn, ) - def forward(self, frames: list[Frame], query_frame: int = None): + def forward( + self, ref_instances: list[Instance], query_instances: list[Instance] = None + ): """Execute forward pass of GTR Model to get asso matrix. Args: - frames: List of Frames from chunk containing crops of objects + gt label info - query_frame: Frame index used as query for self attention. Only used in sliding inference where query frame is the last frame in the window. + ref_instances: List of instances from chunk containing crops of objects + gt label info + query_instances: list of instances used as query in decoder. Returns: An N_T x N association matrix """ # Extract feature representations with pre-trained encoder. - for frame in filter( - lambda f: f.has_instances() and not f.has_features(), frames - ): - crops = frame.get_crops() - z = self.visual_encoder(crops) + self.extract_features(ref_instances) - for i, z_i in enumerate(z): - frame.instances[i].features = z_i + if query_instances: + self.extract_features(query_instances) - asso_preds, emb = self.transformer(frames, query_frame=query_frame) + asso_preds, emb = self.transformer(ref_instances, query_instances) return asso_preds, emb + + def extract_features( + self, instances: list["Instance"], force_recompute: bool = False + ) -> None: + """Extract features from instances using visual encoder backbone. 
+ + Args: + instances: A list of instances to compute features for + force_recompute: indicate whether to compute features for all instances regardless of if they have instances + """ + if not force_recompute: + instances_to_compute = [ + instance + for instance in instances + if instance.has_crop() and not instance.has_features() + ] + else: + instances_to_compute = instances + + if len(instances_to_compute) == 0: + return + elif len(instances_to_compute) == 1: # handle batch norm error when B=1 + instances_to_compute = instances + + crops = torch.concatenate([instance.crop for instance in instances_to_compute]) + + features = self.visual_encoder(crops) + + for i, z_i in enumerate(features): + instances_to_compute[i].features = z_i diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index e7e6a577..98955d23 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -8,6 +8,7 @@ from biogtr.training.losses import AssoLoss from biogtr.models.model_utils import init_optimizer, init_scheduler from pytorch_lightning import LightningModule +from biogtr.data_structures import Frame, Instance class GTRRunner(LightningModule): @@ -59,28 +60,29 @@ def __init__( self.metrics = metrics self.persistent_tracking = persistent_tracking - def forward(self, instances) -> torch.Tensor: + def forward( + self, ref_instances: list[Instance], query_instances: list[Instance] = None + ) -> torch.Tensor: """Execute forward pass of the lightning module. Args: - instances: a list of dicts where each dict is a frame with gt data + ref_instances: a list of `Instance` objects containing crops and other data needed for transformer model + query_instances: a list of `Instance` objects used as queries in the decoder. Mostly used for inference. Returns: An association matrix between objects """ - if sum([frame.num_detected for frame in instances]) > 0: - asso_preds, _ = self.model(instances) - return asso_preds - return None + asso_preds, _ = self.model(ref_instances, query_instances) + return asso_preds def training_step( - self, train_batch: list[dict], batch_idx: int + self, train_batch: list[list[Frame]], batch_idx: int ) -> dict[str, float]: """Execute single training step for model. Args: - train_batch: A single batch from the dataset which is a list of dicts - with length `clip_length` where each dict is a frame + train_batch: A single batch from the dataset which is a list of `Frame` objects + with length `clip_length` containing Instances and other metadata. batch_idx: the batch number used by lightning Returns: @@ -92,13 +94,13 @@ def training_step( return result def validation_step( - self, val_batch: list[dict], batch_idx: int + self, val_batch: list[list[Frame]], batch_idx: int ) -> dict[str, float]: """Execute single val step for model. Args: - val_batch: A single batch from the dataset which is a list of dicts - with length `clip_length` where each dict is a frame + val_batch: A single batch from the dataset which is a list of `Frame` objects + with length `clip_length` containing Instances and other metadata. batch_idx: the batch number used by lightning Returns: @@ -109,12 +111,14 @@ def validation_step( return result - def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]: + def test_step( + self, test_batch: list[list[Frame]], batch_idx: int + ) -> dict[str, float]: """Execute single test step for model. 
Args: - val_batch: A single batch from the dataset which is a list of dicts - with length `clip_length` where each dict is a frame + test_batch: A single batch from the dataset which is a list of `Frame` objects + with length `clip_length` containing Instances and other metadata. batch_idx: the batch number used by lightning Returns: @@ -125,57 +129,57 @@ def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]: return result - def predict_step(self, batch: list[dict], batch_idx: int) -> dict: + def predict_step(self, batch: list[list[Frame]], batch_idx: int) -> list[Frame]: """Run inference for model. Computes association + assignment. Args: - batch: A single batch from the dataset which is a list of dicts - with length `clip_length` where each dict is a frame + batch: A single batch from the dataset which is a list of `Frame` objects + with length `clip_length` containing Instances and other metadata. batch_idx: the batch number used by lightning Returns: A list of dicts where each dict is a frame containing the predicted track ids """ self.tracker.persistent_tracking = True - instances_pred = self.tracker(self.model, batch[0]) - return instances_pred + frames_pred = self.tracker(self.model, batch[0]) + return frames_pred - def _shared_eval_step(self, instances, mode): + def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]: """Run evaluation used by train, test, and val steps. Args: - instances: A list of dicts where each dict is a frame containing gt data + frames: A list of `Frame` objects with length `clip_length` containing Instances and other metadata. mode: which metrics to compute and whether to use persistent tracking or not Returns: a dict containing the loss and any other metrics specified by `eval_metrics` """ try: - instances = [frame for frame in instances if frame.has_instances()] + instances = [instance for frame in frames for instance in frame.instances] + if len(instances) == 0: + return None + eval_metrics = self.metrics[mode] persistent_tracking = self.persistent_tracking[mode] logits = self(instances) - - if not logits: - return None - - loss = self.loss(logits, instances) + loss = self.loss(logits, frames) return_metrics = {"loss": loss} if eval_metrics is not None and len(eval_metrics) > 0: self.tracker.persistent_tracking = persistent_tracking - instances_pred = self.tracker(self.model, instances) - instances_mm = metrics.to_track_eval(instances_pred) - clearmot = metrics.get_pymotmetrics(instances_mm, eval_metrics) + + frames_pred = self.tracker(self.model, frames) + + frames_mm = metrics.to_track_eval(frames_pred) + clearmot = metrics.get_pymotmetrics(frames_mm, eval_metrics) + return_metrics.update(clearmot.to_dict()) - return_metrics["batch_size"] = len(instances) + return_metrics["batch_size"] = len(frames) except Exception as e: - print( - f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}" - ) + print(f"Failed on frame {frames[0].frame_id} of video {frames[0].video_id}") raise (e) return return_metrics diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index 14e2948f..d92a822b 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -2,35 +2,70 @@ from typing import List, Tuple, Iterable from pytorch_lightning import loggers -from biogtr.data_structures import Frame +from biogtr.data_structures import Instance import torch -def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]: - """Extract the bounding boxes and 
frame indices from the input list of instances. +def get_boxes(instances: List[Instance]) -> torch.Tensor: + """Extract the bounding boxes from the input list of instances. Args: - frames (List[Frame]): List of frame objects containing metadata and instances. + instances: List of Instance objects. Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors containing the - bounding boxes normalized by the height and width of the image - and corresponding frame indices, respectively. + An (n_instances, n_points, 4) float tensor containing the bounding boxes + normalized by the height and width of the image """ - boxes, times = [], [] - _, h, w = frames[0].img_shape.flatten() - - for fidx, frame in enumerate(frames): - bbox = frame.get_bboxes().clone() + boxes = [] + for i, instance in enumerate(instances): + _, h, w = instance.frame.img_shape.flatten() + bbox = instance.bbox.clone() bbox[:, :, [0, 2]] /= w bbox[:, :, [1, 3]] /= h - boxes.append(bbox) - times.append(torch.full((bbox.shape[0],), fidx)) boxes = torch.cat(boxes, dim=0) # N, n_anchors, 4 - times = torch.cat(times, dim=0).to(boxes.device) # N - return boxes, times + + return boxes + + +def get_times( + ref_instances: list[Instance], query_instances: list[Instance] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Extract the time indices of each instance relative to the window length. + + Args: + ref_instances: Set of instances to query against + query_instances: Set of query instances to look up using decoder. + + Returns: + Tuple of Corresponding frame indices eg [0, 0, 1, 1, ..., T, T] for ref and query instances. + """ + try: + ref_inds = torch.concat([instance.frame.frame_id for instance in ref_instances]) + except RuntimeError as e: + print([instance.frame.frame_id.device for instance in ref_instances]) + raise (e) + if query_instances is not None: + query_inds = torch.concat( + [instance.frame.frame_id for instance in query_instances] + ) + else: + query_inds = torch.tensor([], device=ref_inds.device) + + frame_inds = torch.concat([ref_inds, query_inds]) + window_length = len(frame_inds.unique()) + + frame_idx_mapping = {frame_inds.unique()[i].item(): i for i in range(window_length)} + ref_t = torch.tensor( + [frame_idx_mapping[ind.item()] for ind in ref_inds], device=ref_inds.device + ) + + query_t = torch.tensor( + [frame_idx_mapping[ind.item()] for ind in query_inds], device=ref_inds.device + ) + + return ref_t, query_t def softmax_asso(asso_output: list[torch.Tensor]) -> list[torch.Tensor]: diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 7f84222e..4951c3e5 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -11,10 +11,10 @@ * added fixed embeddings over boxes """ -from biogtr.data_structures import Frame +from biogtr.data_structures import Instance from biogtr.models.attention_head import ATTWeightHead from biogtr.models.embedding import Embedding -from biogtr.models.model_utils import get_boxes_times +from biogtr.models.model_utils import get_boxes, get_times from torch import nn import copy import torch @@ -140,13 +140,13 @@ def _reset_parameters(self): raise (e) def forward( - self, frames: list[Frame], query_frame: int = None + self, ref_instances: list[Instance], query_instances: list[Instance] = None ) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]: """Execute a forward pass through the transformer and attention head. Args: - frames: A list of Frames (See `biogtr.data_structures.Frame for more info.) 
- query_frame: An integer (k) specifying the frame within the window to be queried. + ref instances: A list of instance objects (See `biogtr.data_structures.Instance` for more info.) + query_instances: An set of instances to be used as decoder queries. Returns: asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where: @@ -156,79 +156,90 @@ def forward( embedding_dict: A dictionary containing the "pos" and "temp" embeddings if `self.return_embeddings` is False then they are None. """ - try: - reid_features = torch.cat( - [frame.get_features() for frame in frames], dim=0 - ).unsqueeze(0) - except Exception as e: - print([[f.device for f in frame.get_features()] for frame in frames]) - raise (e) - - window_length = len(frames) - instances_per_frame = [frame.num_detected for frame in frames] - total_instances = sum(instances_per_frame) - embed_dim = reid_features.shape[-1] - embeddings_dict = {"pos": None, "temp": None} + ref_features = torch.cat( + [instance.features for instance in ref_instances], dim=0 + ).unsqueeze(0) + + # window_length = len(frames) + # instances_per_frame = [frame.num_detected for frame in frames] + total_instances = len(ref_instances) + embed_dim = ref_features.shape[-1] + embeddings_dict = { + "ref": {"pos": None, "temp": None}, + "query": {"pos": None, "temp": None}, + } # print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}') - pred_box, pred_time = get_boxes_times(frames) # total_instances, 4 - pred_box = torch.nan_to_num(pred_box, -1.0) + ref_boxes = get_boxes(ref_instances) # total_instances, 4 + ref_boxes = torch.nan_to_num(ref_boxes, -1.0) + ref_times, query_times = get_times(ref_instances, query_instances) + + window_length = len(ref_times.unique()) - temp_emb = self.temp_emb(pred_time / window_length) + ref_temp_emb = self.temp_emb(ref_times / window_length) if self.return_embedding: - embeddings_dict["temp"] = temp_emb + embeddings_dict["ref"]["temp"] = ref_temp_emb - pos_emb = self.pos_emb(pred_box) + ref_pos_emb = self.pos_emb(ref_boxes) if self.return_embedding: - embeddings_dict["pos"] = pos_emb + embeddings_dict["ref"]["pos"] = ref_pos_emb - try: - emb = (pos_emb + temp_emb) / 2.0 - except RuntimeError as e: - print(self.pos_emb.features, self.temp_emb.features) - print(pos_emb.shape, temp_emb.shape) - raise (e) + ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0 - emb = emb.view(1, total_instances, embed_dim) + ref_emb = ref_emb.view(1, total_instances, embed_dim) - emb = emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim) + ref_emb = ref_emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim) - batch_size, total_instances, embed_dim = reid_features.shape + batch_size, total_instances, embed_dim = ref_features.shape - reid_features = reid_features.permute( + ref_features = ref_features.permute( 1, 0, 2 ) # (total_instances, batch_size, embed_dim) - encoder_queries = reid_features + encoder_queries = ref_features encoder_features = self.encoder( - encoder_queries, pos_emb=emb + encoder_queries, pos_emb=ref_emb ) # (total_instances, batch_size, embed_dim) n_query = total_instances - decoder_queries = reid_features - decoder_query_emb = emb + query_features = ref_features + query_pos_emb = ref_pos_emb + query_temp_emb = ref_temp_emb + query_emb = ref_emb - if query_frame is not None: - query_inds = [ - x - for x in range( - sum(instances_per_frame[:query_frame]), - sum(instances_per_frame[: query_frame + 1]), - ) - ] - n_query = len(query_inds) + if 
query_instances is not None: + n_query = len(query_instances) + + query_features = torch.cat( + [instance.features for instance in query_instances], dim=0 + ).unsqueeze(0) + + query_features = query_features.permute( + 1, 0, 2 + ) # (n_query, batch_size, embed_dim) - decoder_queries = decoder_queries[ - query_inds - ] # decoder_queries: (n_query, batch_size, embed_dim) - decoder_query_emb = decoder_query_emb[query_inds] + query_boxes = get_boxes(query_instances) + + query_temp_emb = self.temp_emb(query_times / window_length) + if self.return_embedding: + embeddings_dict["query"]["temp"] = query_temp_emb + + query_pos_emb = self.pos_emb(query_boxes) + if self.return_embedding: + embeddings_dict["query"]["pos"] = query_pos_emb + + query_emb = (query_pos_emb + query_temp_emb) / 2.0 + + query_emb = query_emb.view(1, n_query, embed_dim) + + query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim) decoder_features = self.decoder( - decoder_queries, + query_features, encoder_features, - pos_emb=emb, - query_pos_emb=decoder_query_emb, + ref_pos_emb=ref_emb, + query_pos_emb=query_emb, ) # (L, n_query, batch_size, embed_dim) decoder_features = decoder_features.transpose( @@ -372,7 +383,7 @@ def forward( self, decoder_queries: torch.Tensor, encoder_features: torch.Tensor, - pos_emb: torch.Tensor = None, + ref_pos_emb: torch.Tensor = None, query_pos_emb: torch.Tensor = None, ) -> torch.Tensor: """Execute forward pass of decoder layer. @@ -381,7 +392,7 @@ def forward( decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim). encoder_features: Output from encoder, that decoder uses to attend to relevant parts of input sequence (total_instances, batch_size, embed_dim) - pos_emb: The input positional embedding tensor of shape (n_query, embed_dim). + ref_pos_emb: The input positional embedding tensor of shape (n_query, embed_dim). query_pos_emb: The target positional embedding of shape (n_query, embed_dim) Returns: @@ -389,11 +400,11 @@ def forward( """ if query_pos_emb is None: query_pos_emb = torch.zeros_like(decoder_queries) - if pos_emb is None: - pos_emb = torch.zeros_like(encoder_features) + if ref_pos_emb is None: + ref_pos_emb = torch.zeros_like(encoder_features) decoder_queries = decoder_queries + query_pos_emb - encoder_features = encoder_features + pos_emb + encoder_features = encoder_features + ref_pos_emb if self.decoder_self_attn: self_attn_features = self.self_attn( @@ -496,7 +507,7 @@ def forward( self, decoder_queries: torch.Tensor, encoder_features: torch.Tensor, - pos_emb: torch.Tensor = None, + ref_pos_emb: torch.Tensor = None, query_pos_emb: torch.Tensor = None, ) -> torch.Tensor: """Execute a forward pass of the decoder block. @@ -505,7 +516,7 @@ def forward( decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim). encoder_features: Output from encoder, that decoder uses to attend to relevant parts of input sequence (total_instances, batch_size, embed_dim) - pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim). + ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim). 
query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim) Returns: @@ -519,7 +530,7 @@ def forward( decoder_features = layer( decoder_features, encoder_features, - pos_emb=pos_emb, + ref_pos_emb=ref_pos_emb, query_pos_emb=query_pos_emb, ) if self.return_intermediate: diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py index 7dfc15f2..b6f1d5e8 100644 --- a/biogtr/training/losses.py +++ b/biogtr/training/losses.py @@ -1,7 +1,7 @@ """Module containing different loss functions to be optimized.""" from biogtr.data_structures import Frame -from biogtr.models.model_utils import get_boxes_times +from biogtr.models.model_utils import get_boxes, get_times from torch import nn from typing import List, Tuple import torch @@ -49,9 +49,11 @@ def forward( # get number of detected objects and ground truth ids n_t = [frame.num_detected for frame in frames] target_inst_id = torch.cat([frame.get_gt_track_ids() for frame in frames]) + instances = [instance for frame in frames for instance in frame.instances] # for now set equal since detections are fixed - pred_box, pred_time = get_boxes_times(frames) + pred_box = get_boxes(instances) + pred_time, _ = get_times(instances) pred_box = torch.nanmean(pred_box, axis=1) target_box, target_time = pred_box, pred_time diff --git a/tests/test_models.py b/tests/test_models.py index c19f8187..08412e49 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -382,7 +382,10 @@ def test_transformer_decoder(): pos_emb = query_pos_emb = torch.ones_like(encoder_features) decoder_features = transformer_decoder( - decoder_queries, encoder_features, pos_emb=pos_emb, query_pos_emb=query_pos_emb + decoder_queries, + encoder_features, + ref_pos_emb=pos_emb, + query_pos_emb=query_pos_emb, ) assert decoder_features.size() == decoder_queries.size() @@ -411,7 +414,8 @@ def test_transformer_basic(): Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) ) - asso_preds, _ = transformer(frames) + instances = [instance for frame in frames for instance in frame.instances] + asso_preds, _ = transformer(instances) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 @@ -435,6 +439,8 @@ def test_transformer_embedding(): ) frames.append(Frame(video_id=0, frame_id=i, instances=instances)) + instances = [instance for frame in frames for instance in frame.instances] + embedding_meta = { "pos": {"mode": "learned", "emb_num": 16, "normalize": True}, "temp": {"mode": "learned", "emb_num": 16, "normalize": True}, @@ -451,11 +457,11 @@ def test_transformer_embedding(): assert transformer.pos_emb.mode == "learned" assert transformer.temp_emb.mode == "learned" - asso_preds, embeddings = transformer(frames) + asso_preds, embeddings = transformer(instances) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 - for emb_type, embedding in embeddings.items(): + for emb_type, embedding in embeddings["ref"].items(): assert embedding.size() == ( num_detected * num_frames, feats, @@ -503,12 +509,12 @@ def test_tracking_transformer(): embedding_meta=embedding_meta, return_embedding=True, ) - - asso_preds, embeddings = tracking_transformer(frames) + instances = [instance for frame in frames for instance in frame.instances] + asso_preds, embeddings = tracking_transformer(instances) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 - for emb_type, embedding in embeddings.items(): + for emb_type, embedding in embeddings["ref"].items(): assert embedding.size() == ( num_detected * num_frames, feats, From 
98106b79b6a85a98ca0a5c02bd22649226fa4161 Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:14:29 -0700 Subject: [PATCH 2/7] Refactor data structures (#47) Co-authored-by: Talmo Pereira Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- biogtr/data_structures.py | 1075 ----------------- biogtr/datasets/base_dataset.py | 2 +- biogtr/datasets/cell_tracking_dataset.py | 5 +- biogtr/datasets/eval_dataset.py | 3 +- biogtr/datasets/microscopy_dataset.py | 5 +- biogtr/datasets/sleap_dataset.py | 3 +- biogtr/inference/metrics.py | 4 +- biogtr/inference/track.py | 4 +- biogtr/inference/track_queue.py | 2 +- biogtr/inference/tracker.py | 12 +- biogtr/io/__init__.py | 8 + biogtr/io/association_matrix.py | 327 +++++ biogtr/{ => io}/config.py | 0 biogtr/io/frame.py | 554 +++++++++ biogtr/io/instance.py | 650 ++++++++++ biogtr/io/track.py | 94 ++ biogtr/{ => io}/visualize.py | 0 biogtr/models/global_tracking_transformer.py | 8 +- biogtr/models/gtr_runner.py | 6 +- biogtr/models/model_utils.py | 2 +- biogtr/models/transformer.py | 42 +- biogtr/training/losses.py | 2 +- biogtr/training/train.py | 2 +- tests/test_config.py | 2 +- ..._data_structures.py => test_data_model.py} | 157 +-- tests/test_inference.py | 82 +- tests/test_models.py | 57 +- tests/test_training.py | 5 +- 28 files changed, 1892 insertions(+), 1221 deletions(-) delete mode 100644 biogtr/data_structures.py create mode 100644 biogtr/io/__init__.py create mode 100644 biogtr/io/association_matrix.py rename biogtr/{ => io}/config.py (100%) create mode 100644 biogtr/io/frame.py create mode 100644 biogtr/io/instance.py create mode 100644 biogtr/io/track.py rename biogtr/{ => io}/visualize.py (100%) rename tests/{test_data_structures.py => test_data_model.py} (56%) diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py deleted file mode 100644 index 81b2f551..00000000 --- a/biogtr/data_structures.py +++ /dev/null @@ -1,1075 +0,0 @@ -"""Module containing data classes such as Instances and Frames.""" - -import torch -import sleap_io as sio -import numpy as np -from numpy.typing import ArrayLike -from typing import Union, List - - -class Instance: - """Class representing a single instance to be tracked.""" - - def __init__( - self, - gt_track_id: int = -1, - pred_track_id: int = -1, - bbox: ArrayLike = None, - crop: ArrayLike = None, - centroid: dict[str, ArrayLike] = None, - features: ArrayLike = None, - track_score: float = -1.0, - point_scores: ArrayLike = None, - instance_score: float = -1.0, - skeleton: sio.Skeleton = None, - pose: dict[str, ArrayLike] = None, - device: str = None, - ): - """Initialize Instance. - - Args: - gt_track_id: Ground truth track id - only used for train/eval. - pred_track_id: Predicted track id. Untracked instance is represented by -1. - bbox: The bounding box coordinate of the instance. Defaults to an empty tensor. - crop: The crop of the instance. - centroid: the centroid around which the bbox was cropped. - features: The reid features extracted from the CNN backbone used in the transformer. - track_score: The track score output from the association matrix. - point_scores: The point scores from sleap. - instance_score: The instance scores from sleap. - skeleton: The sleap skeleton used for the instance. - pose: A dictionary containing the node name and corresponding point. - device: String representation of the device the instance should be on. 
- """ - if gt_track_id is not None: - self._gt_track_id = torch.tensor([gt_track_id]) - else: - self._gt_track_id = torch.tensor([-1]) - - if pred_track_id is not None: - self._pred_track_id = torch.tensor([pred_track_id]) - else: - self._pred_track_id = torch.tensor([]) - - if skeleton is None: - self._skeleton = sio.Skeleton(["centroid"]) - else: - self._skeleton = skeleton - - if bbox is None: - self._bbox = torch.empty(1, 0, 4) - - elif not isinstance(bbox, torch.Tensor): - self._bbox = torch.tensor(bbox) - - else: - self._bbox = bbox - - if self._bbox.shape[0] and len(self._bbox.shape) == 1: - self._bbox = self._bbox.unsqueeze(0) # (n_anchors, 4) - - if self._bbox.shape[1] and len(self._bbox.shape) == 2: - self._bbox = self._bbox.unsqueeze(0) # (1, n_anchors, 4) - - if centroid is not None: - self._centroid = centroid - - elif self.bbox.shape[1]: - y1, x1, y2, x2 = self.bbox.squeeze(dim=0).nanmean(dim=0) - self._centroid = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} - - else: - self._centroid = {} - - if crop is None: - self._crop = torch.tensor([]) - elif not isinstance(crop, torch.Tensor): - self._crop = torch.tensor(crop) - else: - self._crop = crop - - if len(self._crop.shape) == 2: # (h, w) - self._crop = self._crop.unsqueeze(0) # (c, h, w) - if len(self._crop.shape) == 3: - self._crop = self._crop.unsqueeze(0) # (1, c, h, w) - - if features is None: - self._features = torch.tensor([]) - elif not isinstance(features, torch.Tensor): - self._features = torch.tensor(features) - else: - self._features = features - - if self._features.shape[0] and len(self._features.shape) == 1: # (d,) - self._features = self._features.unsqueeze(0) # (1, d) - - if pose is not None: - self._pose = pose - - elif self.bbox.shape[1]: - y1, x1, y2, x2 = self.bbox.squeeze(dim=0).mean(dim=0) - self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} - - else: - self._pose = {} - - self._track_score = track_score - self._instance_score = instance_score - - if point_scores is not None: - self._point_scores = point_scores - else: - self._point_scores = np.zeros_like(self.pose) - - self._device = device - self.to(self._device) - - self._frame = None - - def __repr__(self) -> str: - """Return string representation of the Instance.""" - return ( - "Instance(" - f"gt_track_id={self._gt_track_id.item()}, " - f"pred_track_id={self._pred_track_id.item()}, " - f"bbox={self._bbox}, " - f"centroid={self._centroid}, " - f"crop={self._crop.shape}, " - f"features={self._features.shape}, " - f"device={self._device}" - ")" - ) - - def to(self, map_location): - """Move instance to different device or change dtype. (See `torch.to` for more info). - - Args: - map_location: Either the device or dtype for the instance to be moved. - - Returns: - self: reference to the instance moved to correct device/dtype. - """ - if map_location is not None and map_location != "": - self._gt_track_id = self._gt_track_id.to(map_location) - self._pred_track_id = self._pred_track_id.to(map_location) - self._bbox = self._bbox.to(map_location) - self._crop = self._crop.to(map_location) - self._features = self._features.to(map_location) - self.device = map_location - - return self - - def to_slp( - self, track_lookup: dict[int, sio.Track] = {} - ) -> tuple[sio.PredictedInstance, dict[int, sio.Track]]: - """Convert instance to sleap_io.PredictedInstance object. - - Args: - track_lookup: A track look up dictionary containing track_id:sio.Track. 
- Returns: A sleap_io.PredictedInstance with necessary metadata - and a track_lookup dictionary to persist tracks. - """ - try: - track_id = self.pred_track_id.item() - if track_id not in track_lookup: - track_lookup[track_id] = sio.Track(name=self.pred_track_id.item()) - - track = track_lookup[track_id] - - return ( - sio.PredictedInstance.from_numpy( - points=self.pose, - skeleton=self.skeleton, - point_scores=self.point_scores, - instance_score=self.instance_score, - tracking_score=self.track_score, - track=track, - ), - track_lookup, - ) - except Exception as e: - print( - f"Pose shape: {self.pose.shape}, Pose score shape {self.point_scores.shape}" - ) - raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}") - - @property - def device(self) -> str: - """The device the instance is on. - - Returns: - The str representation of the device the gpu is on. - """ - return self._device - - @device.setter - def device(self, device) -> None: - """Set for the device property. - - Args: - device: The str representation of the device. - """ - self._device = device - - @property - def gt_track_id(self) -> torch.Tensor: - """The ground truth track id of the instance. - - Returns: - A tensor containing the ground truth track id - """ - return self._gt_track_id - - @gt_track_id.setter - def gt_track_id(self, track: int): - """Set the instance ground-truth track id. - - Args: - track: An int representing the ground-truth track id. - """ - if track is not None: - self._gt_track_id = torch.tensor([track]) - else: - self._gt_track_id = torch.tensor([]) - - def has_gt_track_id(self) -> bool: - """Determine if instance has a gt track assignment. - - Returns: - True if the gt track id is set, otherwise False. - """ - if self._gt_track_id.shape[0] == 0: - return False - else: - return True - - @property - def pred_track_id(self) -> torch.Tensor: - """The track id predicted by the tracker using asso_output from model. - - Returns: - A tensor containing the predicted track id. - """ - return self._pred_track_id - - @pred_track_id.setter - def pred_track_id(self, track: int) -> None: - """Set predicted track id. - - Args: - track: an int representing the predicted track id. - """ - if track is not None: - self._pred_track_id = torch.tensor([track]) - else: - self._pred_track_id = torch.tensor([]) - - def has_pred_track_id(self) -> bool: - """Determine whether instance has predicted track id. - - Returns: - True if instance has a pred track id, False otherwise. - """ - if self._pred_track_id.item() == -1 or self._pred_track_id.shape[0] == 0: - return False - else: - return True - - @property - def bbox(self) -> torch.Tensor: - """The bounding box coordinates of the instance in the original frame. - - Returns: - A (1,4) tensor containing the bounding box coordinates. - """ - return self._bbox - - @bbox.setter - def bbox(self, bbox: ArrayLike) -> None: - """Set the instance bounding box. - - Args: - bbox: an arraylike object containing the bounding box coordinates. - """ - if bbox is None or len(bbox) == 0: - self._bbox = torch.empty((0, 4)) - else: - if not isinstance(bbox, torch.Tensor): - self._bbox = torch.tensor(bbox) - else: - self._bbox = bbox - - if self._bbox.shape[0] and len(self._bbox.shape) == 1: - self._bbox = self._bbox.unsqueeze(0) - if self._bbox.shape[1] and len(self._bbox.shape) == 2: - self._bbox = self._bbox.unsqueeze(0) - - def has_bbox(self) -> bool: - """Determine if the instance has a bbox. - - Returns: - True if the instance has a bounding box, false otherwise. 
- """ - if self._bbox.shape[1] == 0: - return False - else: - return True - - @property - def centroid(self) -> dict[str, ArrayLike]: - """The centroid around which the crop was formed. - - Returns: - A dict containing the anchor name and the x, y bbox midpoint. - """ - return self._centroid - - @centroid.setter - def centroid(self, centroid: dict[str, ArrayLike]) -> None: - """Set the centroid of the instance. - - Args: - centroid: A dict containing the anchor name and points. - """ - self._centroid = centroid - - @property - def anchor(self) -> list[str]: - """The anchor node name around which the crop was formed. - - Returns: - the list of anchors around which each crop was formed - the list of anchors around which each crop was formed - """ - if self.centroid: - return list(self.centroid.keys()) - return "" - - @property - def crop(self) -> torch.Tensor: - """The crop of the instance. - - Returns: - A (1, c, h , w) tensor containing the cropped image centered around the instance. - """ - return self._crop - - @crop.setter - def crop(self, crop: ArrayLike) -> None: - """Set the crop of the instance. - - Args: - crop: an arraylike object containing the cropped image of the centered instance. - """ - if crop is None or len(crop) == 0: - self._crop = torch.tensor([]) - else: - if not isinstance(crop, torch.Tensor): - self._crop = torch.tensor(crop) - else: - self._crop = crop - - if len(self._crop.shape) == 2: - self._crop = self._crop.unsqueeze(0) - if len(self._crop.shape) == 3: - self._crop = self._crop.unsqueeze(0) - - def has_crop(self) -> bool: - """Determine if the instance has a crop. - - Returns: - True if the instance has an image otherwise False. - """ - if self._crop.shape[0] == 0: - return False - else: - return True - - @property - def features(self) -> torch.Tensor: - """Re-ID feature vector from backbone model to be used as input to transformer. - - Returns: - a (1, d) tensor containing the reid feature vector. - """ - return self._features - - @features.setter - def features(self, features: ArrayLike) -> None: - """Set the reid feature vector of the instance. - - Args: - features: a (1,d) array like object containing the reid features for the instance. - """ - if features is None or len(features) == 0: - self._features = torch.tensor([]) - - elif not isinstance(features, torch.Tensor): - self._features = torch.tensor(features) - else: - self._features = features - - if self._features.shape[0] and len(self._features.shape) == 1: - self._features = self._features.unsqueeze(0) - - def has_features(self) -> bool: - """Determine if the instance has computed reid features. - - Returns: - True if the instance has reid features, False otherwise. - """ - if self._features.shape[0] == 0: - return False - else: - return True - - @property - def frame(self) -> "Frame": - """Get the frame the instance belongs to. - - Returns: - The back reference to the `Frame` that this `Instance` belongs to. - """ - return self._frame - - @frame.setter - def frame(self, frame: "Frame") -> None: - """Set the back reference to the `Frame` that this `Instance` belongs to. - - This field is set when instances are added to `Frame` object. - - Args: - frame: A `Frame` object containing the metadata for the frame that the instance belongs to - """ - self._frame = frame - - @property - def pose(self) -> dict[str, ArrayLike]: - """Get the pose of the instance. 
- - Returns: - A dictionary containing the node and corresponding x,y points - """ - return self._pose - - @pose.setter - def pose(self, pose: dict[str, ArrayLike]) -> None: - """Set the pose of the instance. - - Args: - pose: A nodes x 2 array containing the pose coordinates. - """ - if pose is not None: - self._pose = pose - - elif self.bbox.shape[0]: - y1, x1, y2, x2 = self.bbox.squeeze() - self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} - - else: - self._pose = {} - - def has_pose(self) -> bool: - """Check if the instance has a pose. - - Returns True if the instance has a pose. - """ - if len(self.pose): - return True - return False - - @property - def shown_pose(self) -> dict[str, ArrayLike]: - """Get the pose with shown nodes only. - - Returns: A dictionary filtered by nodes that are shown (points are not nan). - """ - pose = self.pose - return {node: point for node, point in pose.items() if not np.isna(point).any()} - - @property - def skeleton(self) -> sio.Skeleton: - """Get the skeleton associated with the instance. - - Returns: The sio.Skeleton associated with the instance. - """ - return self._skeleton - - @skeleton.setter - def skeleton(self, skeleton: sio.Skeleton) -> None: - """Set the skeleton associated with the instance. - - Args: - skeleton: The sio.Skeleton associated with the instance. - """ - self._skeleton = skeleton - - @property - def point_scores(self) -> ArrayLike: - """Get the point scores associated with the pose prediction. - - Returns: a vector of shape n containing the point scores outputed from sleap associated with pose predictions. - """ - return self._point_scores - - @point_scores.setter - def point_scores(self, point_scores: ArrayLike) -> None: - """Set the point scores associated with the pose prediction. - - Args: - point_scores: a vector of shape n containing the point scores - outputted from sleap associated with pose predictions. - """ - self._point_scores = point_scores - - @property - def instance_score(self) -> float: - """Get the pose prediction score associated with the instance. - - Returns: a float from 0-1 representing an instance_score. - """ - return self._instance_score - - @instance_score.setter - def instance_score(self, instance_score: float) -> None: - """Set the pose prediction score associated with the instance. - - Args: - instance_score: a float from 0-1 representing an instance_score. - """ - self._instance_score = instance_score - - @property - def track_score(self) -> float: - """Get the track_score of the instance. - - Returns: A float from 0-1 representing the output used in the tracker for assignment. - """ - return self._track_score - - @track_score.setter - def track_score(self, track_score: float) -> None: - """Set the track_score of the instance. - - Args: - track_score: A float from 0-1 representing the output used in the tracker for assignment. - """ - self._track_score = track_score - - -class Frame: - """Data structure containing metadata for a single frame of a video.""" - - def __init__( - self, - video_id: int, - frame_id: int, - vid_file: str = "", - img_shape: ArrayLike = None, - instances: List[Instance] = None, - asso_output: ArrayLike = None, - matches: tuple = None, - traj_score: Union[ArrayLike, dict] = None, - device=None, - ): - """Initialize Frame. - - Args: - video_id: The video index in the dataset. - frame_id: The index of the frame in a video. - vid_file: The path to the video the frame is from. - img_shape: The shape of the original frame (not the crop). 
- instances: A list of Instance objects that appear in the frame. - asso_output: The association matrix between instances - output directly from the transformer. - matches: matches from LSA algorithm between the instances and - available trajectories during tracking. - traj_score: Either a dict containing the association matrix - between instances and trajectories along postprocessing pipeline - or a single association matrix. - device: The device the frame should be moved to. - """ - self._video_id = torch.tensor([video_id]) - self._frame_id = torch.tensor([frame_id]) - - try: - self._video = sio.Video(vid_file) - except ValueError: - self._video = vid_file - if img_shape is None: - self._img_shape = torch.tensor([0, 0, 0]) - elif isinstance(img_shape, torch.Tensor): - self._img_shape = img_shape - else: - self._img_shape = torch.tensor([img_shape]) - - if instances is None: - self.instances = [] - else: - for instance in instances: - instance.frame = self - self._instances = instances - - self._asso_output = asso_output - self._matches = matches - - if traj_score is None: - self._traj_score = {} - elif isinstance(traj_score, dict): - self._traj_score = traj_score - else: - self._traj_score = {"initial": traj_score} - - self._device = device - self.to(device) - - def __repr__(self) -> str: - """Return String representation of the Frame. - - Returns: - The string representation of the frame. - """ - return ( - "Frame(" - f"video={self._video.filename if isinstance(self._video, sio.Video) else self._video}, " - f"video_id={self._video_id.item()}, " - f"frame_id={self._frame_id.item()}, " - f"img_shape={self._img_shape}, " - f"num_detected={self.num_detected}, " - f"asso_output={self._asso_output}, " - f"traj_score={list(self._traj_score.keys())}, " - f"matches={self._matches}, " - f"instances={self._instances}, " - f"device={self._device}" - ")" - ) - - def to(self, map_location: str): - """Move frame to different device or dtype (See `torch.to` for more info). - - Args: - map_location: A string representing the device to move to. - - Returns: - The frame moved to a different device/dtype. - """ - self._video_id = self._video_id.to(map_location) - self._frame_id = self._frame_id.to(map_location) - self._img_shape = self._img_shape.to(map_location) - - if isinstance(self._asso_output, torch.Tensor): - self._asso_output = self._asso_output.to(map_location) - - if isinstance(self._matches, torch.Tensor): - self._matches = self._matches.to(map_location) - - for key, val in self._traj_score.items(): - if isinstance(val, torch.Tensor): - self._traj_score[key] = val.to(map_location) - - for instance in self._instances: - instance = instance.to(map_location) - - self._device = map_location - return self - - def to_slp( - self, track_lookup: dict[int, sio.Track] = {} - ) -> tuple[sio.LabeledFrame, dict[int, sio.Track]]: - """Convert Frame to sleap_io.LabeledFrame object. - - Args: - track_lookup: A lookup dictionary containing the track_id and sio.Track for persistence - - Returns: A tuple containing a LabeledFrame object with necessary metadata and - a lookup dictionary containing the track_id and sio.Track for persistence - """ - slp_instances = [] - for instance in self.instances: - slp_instance, track_lookup = instance.to_slp(track_lookup=track_lookup) - slp_instances.append(slp_instance) - return ( - sio.LabeledFrame( - video=self.video, - frame_idx=self.frame_id.item(), - instances=slp_instances, - ), - track_lookup, - ) - - @property - def device(self) -> str: - """The device the frame is on. 
- - Returns: - The string representation of the device the frame is on. - """ - return self._device - - @device.setter - def device(self, device: str) -> None: - """Set the device. - - Note: Do not set `frame.device = device` normally. Use `frame.to(device)` instead. - - Args: - device: the device the function should be on. - """ - self._device = device - - @property - def video_id(self) -> torch.Tensor: - """The index of the video the frame comes from. - - Returns: - A tensor containing the video index. - """ - return self._video_id - - @video_id.setter - def video_id(self, video_id: int) -> None: - """Set the video index. - - Note: Generally the video_id should be immutable after initialization. - - Args: - video_id: an int representing the index of the video that the frame came from. - """ - self._video_id = torch.tensor([video_id]) - - @property - def frame_id(self) -> torch.Tensor: - """The index of the frame in a full video. - - Returns: - A torch tensor containing the index of the frame in the video. - """ - return self._frame_id - - @frame_id.setter - def frame_id(self, frame_id: int) -> None: - """Set the frame index of the frame. - - Note: The frame_id should generally be immutable after initialization. - - Args: - frame_id: The int index of the frame in the full video. - """ - self._frame_id = torch.tensor([frame_id]) - - @property - def video(self) -> Union[sio.Video, str]: - """Get the video associated with the frame. - - Returns: An sio.Video object representing the video or a placeholder string - if it is not possible to create the sio.Video - """ - return self._video - - @video.setter - def video(self, video_filename: str) -> None: - """Set the video associated with the frame. - - Note: we try to store the video in an sio.Video object. - However, if this is not possible (e.g. incompatible format or missing filepath) - then we simply store the string. - - Args: - video_filename: string path to video_file - """ - try: - self._video = sio.Video(video_filename) - except ValueError: - self._video = video_filename - - @property - def img_shape(self) -> torch.Tensor: - """The shape of the pre-cropped frame. - - Returns: - A torch tensor containing the shape of the frame. Should generally be (c, h, w) - """ - return self._img_shape - - @img_shape.setter - def img_shape(self, img_shape: ArrayLike) -> None: - """Set the shape of the frame image. - - Note: the img_shape should generally be immutable after initialization. - - Args: - img_shape: an ArrayLike object containing the shape of the frame image. - """ - if isinstance(img_shape, torch.Tensor): - self._img_shape = img_shape - else: - self._img_shape = torch.tensor([img_shape]) - - @property - def instances(self) -> List[Instance]: - """A list of instances in the frame. - - Returns: - The list of instances that appear in the frame. - """ - return self._instances - - @instances.setter - def instances(self, instances: List[Instance]) -> None: - """Set the frame's instance. - - Args: - instances: A list of Instances that appear in the frame. - """ - for instance in instances: - instance.frame = self - - self._instances = instances - - def has_instances(self) -> bool: - """Determine whether there are instances in the frame. - - Returns: - True if there are instances in the frame, otherwise False. - """ - if self.num_detected == 0: - return False - return True - - @property - def num_detected(self) -> int: - """The number of instances in the frame. - - Returns: - the number of instances in the frame. 
- """ - return len(self.instances) - - @property - def asso_output(self) -> ArrayLike: - """The association matrix between instances outputed directly by transformer. - - Returns: - An arraylike (n_query, n_nonquery) association matrix between instances. - """ - return self._asso_output - - def has_asso_output(self) -> bool: - """Determine whether the frame has an association matrix computed. - - Returns: - True if the frame has an association matrix otherwise, False. - """ - if self._asso_output is None or len(self._asso_output) == 0: - return False - return True - - @asso_output.setter - def asso_output(self, asso_output: ArrayLike) -> None: - """Set the association matrix of a frame. - - Args: - asso_output: An arraylike (n_query, n_nonquery) association matrix between instances. - """ - self._asso_output = asso_output - - @property - def matches(self) -> tuple: - """Matches between frame instances and availabel trajectories. - - Returns: - A tuple containing the instance idx and trajectory idx for the matched instance. - """ - return self._matches - - @matches.setter - def matches(self, matches: tuple) -> None: - """Set the frame matches. - - Args: - matches: A tuple containing the instance idx and trajectory idx for the matched instance. - """ - self._matches = matches - - def has_matches(self) -> bool: - """Check whether or not matches have been computed for frame. - - Returns: - True if frame contains matches otherwise False. - """ - if self._matches is not None and len(self._matches) > 0: - return True - return False - - def get_traj_score(self, key=None) -> Union[dict, ArrayLike, None]: - """Get dictionary containing association matrix between instances and trajectories along postprocessing pipeline. - - Args: - key: The key of the trajectory score to be accessed. - Can be one of {None, 'initial', 'decay_time', 'max_center_dist', 'iou', 'final'} - - Returns: - - dictionary containing all trajectory scores if key is None - - trajectory score associated with key - - None if the key is not found - """ - if key is None: - return self._traj_score - else: - try: - return self._traj_score[key] - except KeyError as e: - print(f"Could not access {key} traj_score due to {e}") - return None - - def add_traj_score(self, key, traj_score: ArrayLike) -> None: - """Add trajectory score to dictionary. - - Args: - key: key associated with traj score to be used in dictionary - traj_score: association matrix between instances and trajectories - """ - self._traj_score[key] = traj_score - - def has_traj_score(self) -> bool: - """Check if any trajectory association matrix has been saved. - - Returns: - True there is at least one association matrix otherwise, false. - """ - if len(self._traj_score) == 0: - return False - return True - - def has_gt_track_ids(self) -> bool: - """Check if any of frames instances has a gt track id. - - Returns: - True if at least 1 instance has a gt track id otherwise False. - """ - if self.has_instances(): - return any([instance.has_gt_track_id() for instance in self.instances]) - return False - - def get_gt_track_ids(self) -> torch.Tensor: - """Get the gt track ids of all instances in the frame. - - Returns: - an (N,) shaped tensor with the gt track ids of each instance in the frame. - """ - if not self.has_instances(): - return torch.tensor([]) - return torch.cat([instance.gt_track_id for instance in self.instances]) - - def has_pred_track_ids(self) -> bool: - """Check if any of frames instances has a pred track id. 
- - Returns: - True if at least 1 instance has a pred track id otherwise False. - """ - if self.has_instances(): - return any([instance.has_pred_track_id() for instance in self.instances]) - return False - - def get_pred_track_ids(self) -> torch.Tensor: - """Get the pred track ids of all instances in the frame. - - Returns: - an (N,) shaped tensor with the pred track ids of each instance in the frame. - """ - if not self.has_instances(): - return torch.tensor([]) - return torch.cat([instance.pred_track_id for instance in self.instances]) - - def has_bboxes(self) -> bool: - """Check if any of frames instances has a bounding box. - - Returns: - True if at least 1 instance has a bounding box otherwise False. - """ - if self.has_instances(): - return any([instance.has_bboxes() for instance in self.instances]) - return False - - def get_bboxes(self) -> torch.Tensor: - """Get the bounding boxes of all instances in the frame. - - Returns: - an (N,4) shaped tensor with bounding boxes of each instance in the frame. - """ - if not self.has_instances(): - return torch.empty(0, 4) - return torch.cat([instance.bbox for instance in self.instances], dim=0) - - def has_crops(self) -> bool: - """Check if any of frames instances has a crop. - - Returns: - True if at least 1 instance has a crop otherwise False. - """ - if self.has_instances(): - return any([instance.has_crop() for instance in self.instances]) - return False - - def get_crops(self) -> torch.Tensor: - """Get the crops of all instances in the frame. - - Returns: - an (N, C, H, W) shaped tensor with crops of each instance in the frame. - """ - if not self.has_instances(): - return torch.tensor([]) - try: - return torch.cat([instance.crop for instance in self.instances], dim=0) - except Exception as e: - print(self) - raise (e) - - def has_features(self): - """Check if any of frames instances has reid features already computed. - - Returns: - True if at least 1 instance have reid features otherwise False. - """ - if self.has_instances(): - return any([instance.has_features() for instance in self.instances]) - return False - - def get_features(self): - """Get the reid feature vectors of all instances in the frame. - - Returns: - an (N, D) shaped tensor with reid feature vectors of each instance in the frame. - """ - if not self.has_instances(): - return torch.tensor([]) - return torch.cat([instance.features for instance in self.instances], dim=0) - - def get_anchors(self) -> list[str]: - """Get the anchor names of instances in the frame. - - Returns: - A list of anchor names used by the instances to get the crop. - """ - return [instance.anchor for instance in self.instances] - - def get_centroids(self) -> tuple[list[str], ArrayLike]: - """Get the centroids around which each instance's crop was formed. 
- - Returns: - anchors: the node names for the corresponding point - points: an n_instances x 2 array containing the centroids - """ - anchors = [ - anchor for instance in self.instances for anchor in instance.centroid.keys() - ] - - points = np.array( - [ - point - for instance in self.instances - for point in instance.centroid.values() - ] - ) - - return (anchors, points) diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index ce19dc54..74ccf448 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -1,7 +1,7 @@ """Module containing logic for loading datasets.""" from biogtr.datasets import data_utils -from biogtr.data_structures import Frame +from biogtr.io.frame import Frame from torch.utils.data import Dataset from typing import List, Union import numpy as np diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 03142281..65889526 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -3,7 +3,8 @@ from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset -from biogtr.data_structures import Instance, Frame +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance from scipy.ndimage import measurements from typing import List, Optional, Union import albumentations as A @@ -122,7 +123,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram Returns: a list of Frame objects containing frame metadata and Instance Objects. - See `biogtr.data_structures` for more info. + See `biogtr.io.data_structures` for more info. """ image = self.videos[label_idx] gt = self.labels[label_idx] diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index e2cbea2b..6f52a8c9 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -1,7 +1,8 @@ """Module containing wrapper for merging gt and pred datasets for evaluation.""" from torch.utils.data import Dataset -from biogtr.data_structures import Frame, Instance +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance from typing import List diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index dfb92f09..acb50c3f 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -3,7 +3,8 @@ from PIL import Image from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset -from biogtr.data_structures import Frame, Instance +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance from typing import Union import albumentations as A import numpy as np @@ -122,7 +123,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram frame_idx: index of the frames Returns: - A list of Frames containing Instances to be tracked (See `biogtr.data_structures for more info`) + A list of Frames containing Instances to be tracked (See `biogtr.io.data_structures for more info`) """ labels = self.labels[label_idx] labels = labels.dropna(how="all") diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 82e390eb..24d0e1ec 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -7,7 +7,8 @@ import sleap_io as sio import random import warnings -from biogtr.data_structures import Frame, Instance +from biogtr.io.frame import Frame +from 
biogtr.io.instance import Instance from biogtr.datasets import data_utils from biogtr.datasets.base_dataset import BaseDataset from torchvision.transforms import functional as tvf diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index f71ff609..ed961a96 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -3,7 +3,7 @@ import numpy as np import motmetrics as mm import torch -from biogtr.data_structures import Frame +from biogtr.io.frame import Frame from typing import Union, Iterable # from biogtr.inference.post_processing import _pairwise_iou @@ -105,7 +105,7 @@ def to_track_eval(frames: list[Frame]) -> dict: """Reformats frames the output from `sliding_inference` to be used by `TrackEval`. Args: - frames: A list of Frames. `See biogtr.data_structures for more info`. + frames: A list of Frames. `See biogtr.io.data_structures for more info`. Returns: data: A dictionary. Example provided below. diff --git a/biogtr/inference/track.py b/biogtr/inference/track.py index aed32d8b..24b6fefa 100644 --- a/biogtr/inference/track.py +++ b/biogtr/inference/track.py @@ -1,8 +1,8 @@ """Script to run inference and get out tracks.""" -from biogtr.config import Config +from biogtr.io.config import Config from biogtr.models.gtr_runner import GTRRunner -from biogtr.data_structures import Frame +from biogtr.io.frame import Frame from omegaconf import DictConfig from pathlib import Path from pprint import pprint diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py index a400227b..89d8991f 100644 --- a/biogtr/inference/track_queue.py +++ b/biogtr/inference/track_queue.py @@ -1,7 +1,7 @@ """Module handling sliding window tracking.""" import warnings -from biogtr.data_structures import Frame +from biogtr.io.frame import Frame from collections import deque import numpy as np diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 3e3405ba..213914c2 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -3,7 +3,7 @@ import torch import pandas as pd import warnings -from biogtr.data_structures import Frame +from biogtr.io.frame import Frame from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer from biogtr.models import model_utils from biogtr.inference.track_queue import TrackQueue @@ -128,7 +128,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frame: A list of Frames (See `biogtr.data_structures.Frame` for more info). + frames: A list of Frames (See `biogtr.io.data_structures.Frame` for more info). Returns: @@ -209,7 +209,7 @@ def _run_global_tracker( Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frames: A list of Frames containing reid features. See `biogtr.data_structures` for more info. + frames: A list of Frames containing reid features. See `biogtr.io.data_structures` for more info. query_ind: An integer for the query frame within the window of instances. 
Returns: @@ -259,12 +259,12 @@ def _run_global_tracker( # (L=1, n_query, total_instances) with torch.no_grad(): - asso_output, embed = model(all_instances, query_instances) + asso_matrix = model(all_instances, query_instances) # if model.transformer.return_embedding: # query_frame.embeddings = embed TODO add embedding to Instance Object # if query_frame == 1: # print(asso_output) - asso_output = asso_output[-1].split( + asso_output = asso_matrix[-1].matrix.split( instances_per_frame, dim=1 ) # (window_size, n_query, N_i) asso_output = model_utils.softmax_asso( @@ -281,7 +281,7 @@ def _run_global_tracker( asso_output_df.columns.name = "Instances" query_frame.add_traj_score("asso_output", asso_output_df) - query_frame.asso_output = asso_output + query_frame.asso_output = asso_matrix try: n_query = ( diff --git a/biogtr/io/__init__.py b/biogtr/io/__init__.py new file mode 100644 index 00000000..eec945d8 --- /dev/null +++ b/biogtr/io/__init__.py @@ -0,0 +1,8 @@ +"""Module containing input/output data structures for easy storage and manipulation.""" + +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance +from biogtr.io.association_matrix import AssociationMatrix +from biogtr.io.track import Track + +# TODO: expose config without circular import error from biogtr.io.config import Config diff --git a/biogtr/io/association_matrix.py b/biogtr/io/association_matrix.py new file mode 100644 index 00000000..a5f4a3a7 --- /dev/null +++ b/biogtr/io/association_matrix.py @@ -0,0 +1,327 @@ +"""Module containing class for storing and looking up association scores.""" + +import torch +import numpy as np +import pandas as pd +import attrs +from biogtr.io import Instance +from typing import Union + + +@attrs.define +class AssociationMatrix: + """Class representing the associations between detections. + + Attributes: + matrix: the `n_query x n_ref` association matrix` + ref_instances: all instances used to associate against. + query_instances: query instances that were associated against ref instances. + """ + + matrix: Union[np.ndarray, torch.Tensor] + ref_instances: list[Instance] = attrs.field() + query_instances: list[Instance] = attrs.field() + + @ref_instances.validator + def _check_ref_instances(self, attribute, value): + """Check to ensure that the number of association matrix columns and reference instances match. + + Args: + attribute: The ref instances. + value: the list of ref instances. + + Raises: + ValueError if the number of columns and reference instances don't match. + """ + if len(value) != self.matrix.shape[-1]: + raise ValueError( + ( + "Ref instances must equal number of columns in Association matrix" + f"Found {len(value)} ref instances but {self.matrix.shape[-1]} columns." + ) + ) + + @query_instances.validator + def _check_query_instances(self, attribute, value): + """Check to ensure that the number of association matrix rows and query instances match. + + Args: + attribute: The query instances. + value: the list of query instances. + + Raises: + ValueError if the number of rows and query instances don't match. + """ + if len(value) != self.matrix.shape[0]: + raise ValueError( + ( + "Query instances must equal number of rows in Association matrix" + f"Found {len(value)} query instances but {self.matrix.shape[0]} rows." + ) + ) + + def __repr__(self) -> str: + """Get the string representation of the Association Matrix. + + Returns: + the string representation of the association matrix. 
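+
+        Example (illustrative sketch; `ref_instances` and `query_instances` are assumed
+        to be lists of 5 and 2 `Instance` objects so the shapes pass validation):
+            >>> asso = AssociationMatrix(np.random.rand(2, 5), ref_instances, query_instances)
+            >>> assert asso.matrix.shape == (2, 5)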
+ """ + return ( + f"AssociationMatrix({self.matrix}," + f"query_instances={len(self.query_instances)}," + f"ref_instances={len(self.ref_instances)})" + ) + + def numpy(self) -> np.ndarray: + """Convert association matrix to a numpy array. + + Returns: + The association matrix as a numpy array. + """ + if isinstance(self.matrix, torch.Tensor): + return self.matrix.detach().cpu().numpy() + return self.matrix + + def to_dataframe( + self, row_labels: str = "gt", col_labels: str = "gt" + ) -> pd.DataFrame: + """Convert the association matrix to a pandas DataFrame. + + Args: + row_labels: How to label the rows(queries). + If list, then must match # of rows/queries + If `"gt"` then label by gt track id. + If `"pred"` then label by pred track id. + Otherwise label by the query_instance indices + col_labels: How to label the columns(references). + If list, then must match # of columns/refs + If `"gt"` then label by gt track id. + If `"pred"` then label by pred track id. + Otherwise label by the ref_instance indices + + Returns: + The association matrix as a pandas dataframe. + """ + matrix = self.numpy() + + if not isinstance(row_labels, str): + if len(row_labels) == len(self.query_instances): + row_inds = row_labels + else: + raise ValueError( + ( + f"Mismatched # of rows and labels!", + f"Found {len(row_labels)} with {len(self.query_instances)} rows", + ) + ) + else: + if row_labels == "gt": + row_inds = [ + instance.gt_track_id.item() for instance in self.query_instances + ] + + elif row_labels == "pred": + row_inds = [ + instance.pred_track_id.item() for instance in self.query_instances + ] + + else: + row_inds = np.arange(len(self.query_instances)) + + if not isinstance(col_labels, str): + if len(col_labels) == len(self.ref_instances): + col_inds = col_labels + else: + raise ValueError( + ( + f"Mismatched # of columns and labels!", + f"Found {len(col_labels)} with {len(self.ref_instances)} columns", + ) + ) + else: + if col_labels == "gt": + col_inds = [ + instance.gt_track_id.item() for instance in self.ref_instances + ] + + elif col_labels == "pred": + col_inds = [ + instance.pred_track_id.item() for instance in self.ref_instances + ] + + else: + col_inds = np.arange(len(self.ref_instances)) + + asso_df = pd.DataFrame(matrix, index=row_inds, columns=col_inds) + + return asso_df + + def reduce( + self, + row_dims: str = "instance", + col_dims: str = "track", + row_grouping: str = None, + col_grouping: str = "pred", + reduce_method: callable = np.sum, + ) -> pd.DataFrame: + """Aggregate the association matrix by specified dimensions and grouping. + + Args: + row_dims: A str indicating how to what dimensions to reduce rows to. + Either "instance" (remains unchanged), or "track" (n_rows=n_traj). + col_dims: A str indicating how to dimensions to reduce rows to. + Either "instance" (remains unchanged), or "track" (n_cols=n_traj) + row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt". + col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt". + method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing. + + Returns: + The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe. 
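+
+        Example (a minimal sketch; assumes `asso` is an AssociationMatrix whose
+        instances already have predicted track ids assigned):
+            >>> inst_to_traj = asso.reduce(row_dims="instance", col_dims="track",
+            ...                            col_grouping="pred", reduce_method=np.sum)
+            >>> n_query, n_tracks = inst_to_traj.shape  # (n_query_instances, n_pred_tracks)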
+ """ + n_rows = len(self.query_instances) + n_cols = len(self.ref_instances) + + col_tracks = {-1: self.ref_instances} + row_tracks = {-1: self.query_instances} + + col_inds = [i for i in range(len(self.ref_instances))] + row_inds = [i for i in range(len(self.query_instances))] + + if col_dims == "track": + col_tracks = self.get_tracks(self.ref_instances, col_grouping) + col_inds = list(col_tracks.keys()) + n_cols = len(col_inds) + + if row_dims == "track": + row_tracks = self.get_tracks(self.query_instances, row_grouping) + row_inds = list(row_tracks.keys()) + n_rows = len(row_inds) + + reduced_matrix = [] + for row_track, row_instances in row_tracks.items(): + + for col_track, col_instances in col_tracks.items(): + asso_matrix = self[row_instances, col_instances] + + if col_dims == "track": + asso_matrix = reduce_method(asso_matrix, axis=1) + + if row_dims == "track": + asso_matrix = reduce_method(asso_matrix, axis=0) + + reduced_matrix.append(asso_matrix) + + reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T + + return pd.DataFrame(reduced_matrix, index=row_inds, columns=col_inds) + + def __getitem__(self, inds) -> np.ndarray: + """Get elements of the association matrix. + + Args: + inds: A tuple of query indices and reference indices. + Indices can be either: + A single instance or integer. + A list of instances or integers. + + Returns: + An np.ndarray containing the elements requested. + """ + query_inst, ref_inst = inds + + query_ind = self.__getindices__(query_inst, self.query_instances) + ref_ind = self.__getindices__(ref_inst, self.ref_instances) + + try: + return self.numpy()[query_ind[:, None], ref_ind].squeeze() + + except IndexError as e: + print(f"Query_insts: {type(query_inst)}") + print(f"Query_inds: {query_ind}") + print(f"Ref_insts: {type(ref_inst)}") + print(f"Ref_ind: {ref_ind}") + raise (e) + + def __getindices__( + self, + instance: Union[Instance, int, np.typing.ArrayLike], + instance_lookup: list[Instance], + ) -> np.ndarray: + """Get the indices of the instance for lookup. + + Args: + instance: The instance(s) to be retrieved + Can either be a single int/instance or a list of int/instances + instance_lookup: A list of Instances to be used to retrieve indices + + Returns: + A np array of indices. + """ + if isinstance(instance, Instance): + ind = np.array([instance_lookup.index(instance)]) + + elif instance is None: + ind = np.arange(len(instance_lookup)) + + elif np.isscalar(instance): + ind = np.array([instance]) + + else: + instances = instance + if not [isinstance(inst, (Instance, int)) for inst in instance]: + raise ValueError( + f"List of indices must be `int` or `Instance`. Found {set([type(inst) for inst in instance])}" + ) + ind = np.array( + [ + ( + instance_lookup.index(instance) + if isinstance(instance, Instance) + else instance + ) + for instance in instances + ] + ) + + return ind + + def get_tracks( + self, instances: list["Instance"], label: str = "pred" + ) -> dict[int, list["Instance"]]: + """Group instances by track. + + Args: + instances: The list of instances to group + label: the track id type to group by. Either `pred` or `gt`. 
+ + Returns: + A dictionary of track_id:instances + """ + if label == "pred": + traj_ids = set([instance.pred_track_id.item() for instance in instances]) + traj = { + track_id: [ + instance + for instance in instances + if instance.pred_track_id.item() == track_id + ] + for track_id in traj_ids + } + + elif label == "gt": + traj_ids = set( + [instance.gt_track_id.item() for instance in self.ref_instances] + ) + traj = { + track_id: [ + instance + for instance in self.ref_instances + if instance.gt_track_id.item() == track_id + ] + for track_id in traj_ids + } + + else: + raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.") + + return traj diff --git a/biogtr/config.py b/biogtr/io/config.py similarity index 100% rename from biogtr/config.py rename to biogtr/io/config.py diff --git a/biogtr/io/frame.py b/biogtr/io/frame.py new file mode 100644 index 00000000..248c0159 --- /dev/null +++ b/biogtr/io/frame.py @@ -0,0 +1,554 @@ +"""Module containing data classes such as Instances and Frames.""" + +import torch +import sleap_io as sio +import numpy as np +import attrs +from numpy.typing import ArrayLike +from typing import Union, List +from biogtr.io.instance import Instance + + +def _to_tensor(data: Union[float, ArrayLike]) -> torch.Tensor: + """Convert data to tensortype. + + Args: + data: A scalar or np.ndarray to be converted to a torch tensor + Returns: + A torch tensor containing `data`. + """ + if data is None: + return torch.tensor([]) + + if isinstance(data, torch.Tensor): + return data + elif np.isscalar(data): + return torch.tensor([data]) + else: + return torch.tensor(data) + + +@attrs.define(eq=False) +class Frame: + """Data structure containing metadata for a single frame of a video. + + Attributes: + video_id: The video index in the dataset. + frame_id: The index of the frame in a video. + vid_file: The path to the video the frame is from. + img_shape: The shape of the original frame (not the crop). + instances: A list of Instance objects that appear in the frame. + asso_output: The association matrix between instances + output directly from the transformer. + matches: matches from LSA algorithm between the instances and + available trajectories during tracking. + traj_score: Either a dict containing the association matrix + between instances and trajectories along postprocessing pipeline + or a single association matrix. + device: The device the frame should be moved to. + """ + + _video_id: int = attrs.field(alias="video_id", converter=_to_tensor) + _frame_id: int = attrs.field(alias="frame_id", converter=_to_tensor) + _video: str = attrs.field(alias="vid_file", default="") + _img_shape: ArrayLike = attrs.field( + alias="img_shape", converter=_to_tensor, factory=list + ) + + _instances: list["Instance"] = attrs.field(alias="instances", factory=list) + _asso_output: "AssociationMatrix" = attrs.field(alias="asso_output", default=None) + _matches: tuple = attrs.field(alias="matches", factory=tuple) + _traj_score: dict = attrs.field(alias="traj_score", factory=dict) + _device: str = attrs.field(alias="device", default=None) + + def __attrs_post_init__(self) -> None: + """Handle more intricate default initializations and moving to device.""" + if len(self.img_shape) == 0: + self.img_shape = torch.tensor([0, 0, 0]) + + for instance in self.instances: + instance.frame = self + + self.to(self.device) + + def __repr__(self) -> str: + """Return String representation of the Frame. + + Returns: + The string representation of the frame. 
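+
+        Example (illustrative; ids and shapes are hypothetical):
+            >>> frame = Frame(video_id=0, frame_id=0, img_shape=(1, 128, 128),
+            ...               instances=[Instance(gt_track_id=0), Instance(gt_track_id=1)])
+            >>> assert frame.num_detected == 2
+            >>> assert frame.instances[0].frame is frame  # back reference set on init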
+ """ + return ( + "Frame(" + f"video={self._video.filename if isinstance(self._video, sio.Video) else self._video}, " + f"video_id={self._video_id.item()}, " + f"frame_id={self._frame_id.item()}, " + f"img_shape={self._img_shape}, " + f"num_detected={self.num_detected}, " + f"asso_output={self._asso_output}, " + f"traj_score={self._traj_score}, " + f"matches={self._matches}, " + f"instances={self._instances}, " + f"device={self._device}" + ")" + ) + + def to(self, map_location: str): + """Move frame to different device or dtype (See `torch.to` for more info). + + Args: + map_location: A string representing the device to move to. + + Returns: + The frame moved to a different device/dtype. + """ + self._video_id = self._video_id.to(map_location) + self._frame_id = self._frame_id.to(map_location) + self._img_shape = self._img_shape.to(map_location) + + if isinstance(self._asso_output, torch.Tensor): + self._asso_output = self._asso_output.to(map_location) + + if isinstance(self._matches, torch.Tensor): + self._matches = self._matches.to(map_location) + + for key, val in self._traj_score.items(): + if isinstance(val, torch.Tensor): + self._traj_score[key] = val.to(map_location) + for instance in self.instances: + instance = instance.to(map_location) + + if isinstance(map_location, str): + self._device = map_location + + return self + + @classmethod + def from_slp( + cls, + lf: sio.LabeledFrame, + video_id: int = 0, + device: str = None, + **kwargs, + ) -> "Frame": + """Convert `sio.LabeledFrame` to `biogtr.io.Frame`. + + Args: + lf: A sio.LabeledFrame object + + Returns: + A biogtr.io.Frame object + """ + img_shape = lf.image.shape + if len(img_shape) == 2: + img_shape = (1, *img_shape) + elif len(img_shape) > 2 and img_shape[-1] <= 3: + img_shape = (lf.image.shape[-1], lf.image.shape[0], lf.image.shape[1]) + return cls( + video_id=video_id, + frame_id=( + lf.frame_idx.astype(np.int32) + if isinstance(lf.frame_idx, np.number) + else lf.frame_idx + ), + vid_file=lf.video.filename, + img_shape=img_shape, + instances=[Instance.from_slp(instance, **kwargs) for instance in lf], + device=device, + ) + + def to_slp( + self, track_lookup: dict[int, sio.Track] = {} + ) -> tuple[sio.LabeledFrame, dict[int, sio.Track]]: + """Convert Frame to sleap_io.LabeledFrame object. + + Args: + track_lookup: A lookup dictionary containing the track_id and sio.Track for persistence + + Returns: A tuple containing a LabeledFrame object with necessary metadata and + a lookup dictionary containing the track_id and sio.Track for persistence + """ + slp_instances = [] + for instance in self.instances: + slp_instance, track_lookup = instance.to_slp(track_lookup=track_lookup) + slp_instances.append(slp_instance) + return ( + sio.LabeledFrame( + video=self.video, + frame_idx=self.frame_id.item(), + instances=slp_instances, + ), + track_lookup, + ) + + @property + def device(self) -> str: + """The device the frame is on. + + Returns: + The string representation of the device the frame is on. + """ + return self._device + + @device.setter + def device(self, device: str) -> None: + """Set the device. + + Note: Do not set `frame.device = device` normally. Use `frame.to(device)` instead. + + Args: + device: the device the function should be on. + """ + self._device = device + + @property + def video_id(self) -> torch.Tensor: + """The index of the video the frame comes from. + + Returns: + A tensor containing the video index. 
+ """ + return self._video_id + + @video_id.setter + def video_id(self, video_id: int) -> None: + """Set the video index. + + Note: Generally the video_id should be immutable after initialization. + + Args: + video_id: an int representing the index of the video that the frame came from. + """ + self._video_id = torch.tensor([video_id]) + + @property + def frame_id(self) -> torch.Tensor: + """The index of the frame in a full video. + + Returns: + A torch tensor containing the index of the frame in the video. + """ + return self._frame_id + + @frame_id.setter + def frame_id(self, frame_id: int) -> None: + """Set the frame index of the frame. + + Note: The frame_id should generally be immutable after initialization. + + Args: + frame_id: The int index of the frame in the full video. + """ + self._frame_id = torch.tensor([frame_id]) + + @property + def video(self) -> Union[sio.Video, str]: + """Get the video associated with the frame. + + Returns: An sio.Video object representing the video or a placeholder string + if it is not possible to create the sio.Video + """ + return self._video + + @video.setter + def video(self, video_filename: str) -> None: + """Set the video associated with the frame. + + Note: we try to store the video in an sio.Video object. + However, if this is not possible (e.g. incompatible format or missing filepath) + then we simply store the string. + + Args: + video_filename: string path to video_file + """ + try: + self._video = sio.load_video(video_filename) + except ValueError: + self._video = video_filename + + @property + def img_shape(self) -> torch.Tensor: + """The shape of the pre-cropped frame. + + Returns: + A torch tensor containing the shape of the frame. Should generally be (c, h, w) + """ + return self._img_shape + + @img_shape.setter + def img_shape(self, img_shape: ArrayLike) -> None: + """Set the shape of the frame image. + + Note: the img_shape should generally be immutable after initialization. + + Args: + img_shape: an ArrayLike object containing the shape of the frame image. + """ + self._img_shape = _to_tensor(img_shape) + + @property + def instances(self) -> List["Instance"]: + """A list of instances in the frame. + + Returns: + The list of instances that appear in the frame. + """ + return self._instances + + @instances.setter + def instances(self, instances: List["Instance"]) -> None: + """Set the frame's instance. + + Args: + instances: A list of Instances that appear in the frame. + """ + for instance in instances: + instance.frame = self + self._instances = instances + + def has_instances(self) -> bool: + """Determine whether there are instances in the frame. + + Returns: + True if there are instances in the frame, otherwise False. + """ + if self.num_detected == 0: + return False + return True + + @property + def num_detected(self) -> int: + """The number of instances in the frame. + + Returns: + the number of instances in the frame. + """ + return len(self.instances) + + @property + def asso_output(self) -> "AssociationMatrix": + """The association matrix between instances outputed directly by transformer. + + Returns: + An arraylike (n_query, n_nonquery) association matrix between instances. + """ + return self._asso_output + + def has_asso_output(self) -> bool: + """Determine whether the frame has an association matrix computed. + + Returns: + True if the frame has an association matrix otherwise, False. 
+ """ + if self._asso_output is None or len(self._asso_output.matrix) == 0: + return False + return True + + @asso_output.setter + def asso_output(self, asso_output: "AssociationMatrix") -> None: + """Set the association matrix of a frame. + + Args: + asso_output: An arraylike (n_query, n_nonquery) association matrix between instances. + """ + self._asso_output = asso_output + + @property + def matches(self) -> tuple: + """Matches between frame instances and availabel trajectories. + + Returns: + A tuple containing the instance idx and trajectory idx for the matched instance. + """ + return self._matches + + @matches.setter + def matches(self, matches: tuple) -> None: + """Set the frame matches. + + Args: + matches: A tuple containing the instance idx and trajectory idx for the matched instance. + """ + self._matches = matches + + def has_matches(self) -> bool: + """Check whether or not matches have been computed for frame. + + Returns: + True if frame contains matches otherwise False. + """ + if self._matches is not None and len(self._matches) > 0: + return True + return False + + def get_traj_score(self, key=None) -> Union[dict, ArrayLike, None]: + """Get dictionary containing association matrix between instances and trajectories along postprocessing pipeline. + + Args: + key: The key of the trajectory score to be accessed. + Can be one of {None, 'initial', 'decay_time', 'max_center_dist', 'iou', 'final'} + + Returns: + - dictionary containing all trajectory scores if key is None + - trajectory score associated with key + - None if the key is not found + """ + if key is None: + return self._traj_score + else: + try: + return self._traj_score[key] + except KeyError as e: + print(f"Could not access {key} traj_score due to {e}") + return None + + def add_traj_score(self, key, traj_score: ArrayLike) -> None: + """Add trajectory score to dictionary. + + Args: + key: key associated with traj score to be used in dictionary + traj_score: association matrix between instances and trajectories + """ + self._traj_score[key] = traj_score + + def has_traj_score(self) -> bool: + """Check if any trajectory association matrix has been saved. + + Returns: + True there is at least one association matrix otherwise, false. + """ + if len(self._traj_score) == 0: + return False + return True + + def has_gt_track_ids(self) -> bool: + """Check if any of frames instances has a gt track id. + + Returns: + True if at least 1 instance has a gt track id otherwise False. + """ + if self.has_instances(): + return any([instance.has_gt_track_id() for instance in self.instances]) + return False + + def get_gt_track_ids(self) -> torch.Tensor: + """Get the gt track ids of all instances in the frame. + + Returns: + an (N,) shaped tensor with the gt track ids of each instance in the frame. + """ + if not self.has_instances(): + return torch.tensor([]) + return torch.cat([instance.gt_track_id for instance in self.instances]) + + def has_pred_track_ids(self) -> bool: + """Check if any of frames instances has a pred track id. + + Returns: + True if at least 1 instance has a pred track id otherwise False. + """ + if self.has_instances(): + return any([instance.has_pred_track_id() for instance in self.instances]) + return False + + def get_pred_track_ids(self) -> torch.Tensor: + """Get the pred track ids of all instances in the frame. + + Returns: + an (N,) shaped tensor with the pred track ids of each instance in the frame. 
+ """ + if not self.has_instances(): + return torch.tensor([]) + return torch.cat([instance.pred_track_id for instance in self.instances]) + + def has_bboxes(self) -> bool: + """Check if any of frames instances has a bounding box. + + Returns: + True if at least 1 instance has a bounding box otherwise False. + """ + if self.has_instances(): + return any([instance.has_bboxes() for instance in self.instances]) + return False + + def get_bboxes(self) -> torch.Tensor: + """Get the bounding boxes of all instances in the frame. + + Returns: + an (N,4) shaped tensor with bounding boxes of each instance in the frame. + """ + if not self.has_instances(): + return torch.empty(0, 4) + return torch.cat([instance.bbox for instance in self.instances], dim=0) + + def has_crops(self) -> bool: + """Check if any of frames instances has a crop. + + Returns: + True if at least 1 instance has a crop otherwise False. + """ + if self.has_instances(): + return any([instance.has_crop() for instance in self.instances]) + return False + + def get_crops(self) -> torch.Tensor: + """Get the crops of all instances in the frame. + + Returns: + an (N, C, H, W) shaped tensor with crops of each instance in the frame. + """ + if not self.has_instances(): + return torch.tensor([]) + try: + return torch.cat([instance.crop for instance in self.instances], dim=0) + except Exception as e: + print(self) + raise (e) + + def has_features(self): + """Check if any of frames instances has reid features already computed. + + Returns: + True if at least 1 instance have reid features otherwise False. + """ + if self.has_instances(): + return any([instance.has_features() for instance in self.instances]) + return False + + def get_features(self): + """Get the reid feature vectors of all instances in the frame. + + Returns: + an (N, D) shaped tensor with reid feature vectors of each instance in the frame. + """ + if not self.has_instances(): + return torch.tensor([]) + return torch.cat([instance.features for instance in self.instances], dim=0) + + def get_anchors(self) -> list[str]: + """Get the anchor names of instances in the frame. + + Returns: + A list of anchor names used by the instances to get the crop. + """ + return [instance.anchor for instance in self.instances] + + def get_centroids(self) -> tuple[list[str], ArrayLike]: + """Get the centroids around which each instance's crop was formed. + + Returns: + anchors: the node names for the corresponding point + points: an n_instances x 2 array containing the centroids + """ + anchors = [ + anchor for instance in self.instances for anchor in instance.centroid.keys() + ] + + points = np.array( + [ + point + for instance in self.instances + for point in instance.centroid.values() + ] + ) + + return (anchors, points) diff --git a/biogtr/io/instance.py b/biogtr/io/instance.py new file mode 100644 index 00000000..44dc2386 --- /dev/null +++ b/biogtr/io/instance.py @@ -0,0 +1,650 @@ +"""Module containing data class for storing detections.""" + +import torch +import sleap_io as sio +import numpy as np +import attrs +from numpy.typing import ArrayLike +from typing import Union + + +def _to_tensor(data: Union[float, ArrayLike]) -> torch.Tensor: + """Convert data to a torch.Tensor object. + + Args: + data: Either a scalar quantity or arraylike object + + Returns: + A torch Tensor containing data. 
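+
+    Example (illustrative):
+        >>> t = _to_tensor(2)            # tensor([2])
+        >>> t = _to_tensor(np.ones(3))   # a (3,) tensor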
+ """ + if data is None: + return torch.tensor([]) + if isinstance(data, torch.Tensor): + return data + elif np.isscalar(data): + return torch.tensor([data]) + else: + return torch.tensor(data) + + +def _expand_to_rank( + arr: Union[np.ndarray, torch.Tensor], new_rank: int +) -> Union[np.ndarray, torch.Tensor]: + """Expand n-dimensional array to appropriate dimensions by adding singleton dimensions to the front of the array. + + Args: + arr: an n-dimensional array (either np.ndarray or torch.Tensor). + + Returns: + The array expanded to the correct dimensions. + """ + curr_rank = len(arr.shape) + while curr_rank < new_rank: + if isinstance(arr, np.ndarray): + arr = np.expand_dims(arr, axis=0) + elif isinstance(arr, torch.Tensor): + arr = arr.unsqueeze(0) + else: + raise TypeError( + f"`arr` must be either an np.ndarray or torch.Tensor but found {type(arr)}" + ) + curr_rank = len(arr.shape) + return arr + + +@attrs.define(eq=False) +class Instance: + """Class representing a single instance to be tracked. + + Attributes: + gt_track_id: Ground truth track id - only used for train/eval. + pred_track_id: Predicted track id. Untracked instance is represented by -1. + bbox: The bounding box coordinate of the instance. Defaults to an empty tensor. + crop: The crop of the instance. + centroid: the centroid around which the bbox was cropped. + features: The reid features extracted from the CNN backbone used in the transformer. + track_score: The track score output from the association matrix. + point_scores: The point scores from sleap. + instance_score: The instance scores from sleap. + skeleton: The sleap skeleton used for the instance. + pose: A dictionary containing the node name and corresponding point. + device: String representation of the device the instance should be on. 
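+
+    Example (illustrative sketch; the coordinates and crop are hypothetical):
+        >>> inst = Instance(
+        ...     gt_track_id=0,
+        ...     bbox=[10.0, 10.0, 74.0, 74.0],
+        ...     crop=np.zeros((1, 3, 64, 64)),
+        ... )
+        >>> assert inst.bbox.shape == (1, 1, 4)  # expanded to rank 3 on init
+        >>> assert inst.has_crop() and not inst.has_features()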
+ """ + + _gt_track_id: int = attrs.field( + alias="gt_track_id", default=-1, converter=_to_tensor + ) + _pred_track_id: int = attrs.field( + alias="pred_track_id", default=-1, converter=_to_tensor + ) + _bbox: ArrayLike = attrs.field(alias="bbox", factory=list, converter=_to_tensor) + _crop: ArrayLike = attrs.field(alias="crop", factory=list, converter=_to_tensor) + _centroid: dict[str, ArrayLike] = attrs.field(alias="centroid", factory=dict) + _features: ArrayLike = attrs.field( + alias="features", factory=list, converter=_to_tensor + ) + _embeddings: dict = attrs.field(alias="embeddings", factory=dict) + _track_score: float = attrs.field(alias="track_score", default=-1.0) + _instance_score: float = attrs.field(alias="instance_score", default=-1.0) + _point_scores: ArrayLike = attrs.field(alias="point_scores", default=None) + _skeleton: sio.Skeleton = attrs.field(alias="skeleton", default=None) + _pose: dict[str, ArrayLike] = attrs.field(alias="pose", factory=dict) + _device: str = attrs.field(alias="device", default=None) + _frame: "Frame" = None + + def __attrs_post_init__(self) -> None: + """Handle dimensionality and more intricate default initializations post-init.""" + self.bbox = _expand_to_rank(self.bbox, 3) + self.crop = _expand_to_rank(self.crop, 4) + self.features = _expand_to_rank(self.features, 2) + + if self.skeleton is None: + self.skeleton = sio.Skeleton(["centroid"]) + + if self.bbox.shape[-1] == 0: + self.bbox = torch.empty([1, 0, 4]) + + if self.crop.shape[-1] == 0 and self.bbox.shape[1] != 0: + y1, x1, y2, x2 = self.bbox.squeeze(dim=0).nanmean(dim=0) + self.centroid = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} + + if len(self.pose) == 0 and self.bbox.shape[1]: + y1, x1, y2, x2 = self.bbox.squeeze(dim=0).mean(dim=0) + self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} + + if self.point_scores is None and len(self.pose) != 0: + self._point_scores = np.zeros((len(self.pose), 2)) + + self.to(self.device) + + def __repr__(self) -> str: + """Return string representation of the Instance.""" + return ( + "Instance(" + f"gt_track_id={self._gt_track_id.item()}, " + f"pred_track_id={self._pred_track_id.item()}, " + f"bbox={self._bbox}, " + f"centroid={self._centroid}, " + f"crop={self._crop.shape}, " + f"features={self._features.shape}, " + f"device={self._device}" + ")" + ) + + def to(self, map_location): + """Move instance to different device or change dtype. (See `torch.to` for more info). + + Args: + map_location: Either the device or dtype for the instance to be moved. + + Returns: + self: reference to the instance moved to correct device/dtype. + """ + if map_location is not None and map_location != "": + self._gt_track_id = self._gt_track_id.to(map_location) + self._pred_track_id = self._pred_track_id.to(map_location) + self._bbox = self._bbox.to(map_location) + self._crop = self._crop.to(map_location) + self._features = self._features.to(map_location) + if isinstance(map_location, str): + self.device = map_location + + return self + + @classmethod + def from_slp( + cls, + slp_instance: Union[sio.PredictedInstance, sio.Instance], + bbox_size: Union[int, tuple] = 64, + crop: ArrayLike = None, + device: str = None, + ) -> None: + """Convert a slp instance to a biogtr instance. + + Args: + slp_instance: A `sleap_io.Instance` object representing a detection + bbox_size: size of the pose-centered bbox to form. 
+ crop: The corresponding crop of the bbox + device: which device to keep the instance on + Returns: + A biogtr.Instance object with a pose-centered bbox and no crop. + """ + try: + track_id = int(slp_instance.track.name) + except ValueError: + track_id = int( + "".join([str(ord(c)) for c in slp_instance.track.name]) + ) # better way to handle this? + if isinstance(bbox_size, int): + bbox_size = (bbox_size, bbox_size) + + track_score = -1.0 + point_scores = np.full(len(slp_instance.points), -1) + instance_score = -1 + if isinstance(slp_instance, sio.PredictedInstance): + track_score = slp_instance.tracking_score + point_scores = slp_instance.numpy()[:, -1] + instance_score = slp_instance.score + + centroid = np.nanmean(slp_instance.numpy(), axis=1) + bbox = [ + centroid[1] - bbox_size[1], + centroid[0] - bbox_size[0], + centroid[1] + bbox_size[1], + centroid[0] + bbox_size[0], + ] + return cls( + gt_track_id=track_id, + bbox=bbox, + crop=crop, + centroid={"centroid": centroid}, + track_score=track_score, + point_scores=point_scores, + instance_score=instance_score, + skeleton=slp_instance.skeleton, + pose={ + node.name: point.numpy() for node, point in slp_instance.points.items() + }, + device=device, + ) + + def to_slp( + self, track_lookup: dict[int, sio.Track] = {} + ) -> tuple[sio.PredictedInstance, dict[int, sio.Track]]: + """Convert instance to sleap_io.PredictedInstance object. + + Args: + track_lookup: A track look up dictionary containing track_id:sio.Track. + Returns: A sleap_io.PredictedInstance with necessary metadata + and a track_lookup dictionary to persist tracks. + """ + try: + track_id = self.pred_track_id.item() + if track_id not in track_lookup: + track_lookup[track_id] = sio.Track(name=self.pred_track_id.item()) + + track = track_lookup[track_id] + + return ( + sio.PredictedInstance.from_numpy( + points=self.pose, + skeleton=self.skeleton, + point_scores=self.point_scores, + instance_score=self.instance_score, + tracking_score=self.track_score, + track=track, + ), + track_lookup, + ) + except Exception as e: + print( + f"Pose shape: {self.pose.shape}, Pose score shape {self.point_scores.shape}" + ) + raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}") + + @property + def device(self) -> str: + """The device the instance is on. + + Returns: + The str representation of the device the gpu is on. + """ + return self._device + + @device.setter + def device(self, device) -> None: + """Set for the device property. + + Args: + device: The str representation of the device. + """ + self._device = device + + @property + def gt_track_id(self) -> torch.Tensor: + """The ground truth track id of the instance. + + Returns: + A tensor containing the ground truth track id + """ + return self._gt_track_id + + @gt_track_id.setter + def gt_track_id(self, track: int): + """Set the instance ground-truth track id. + + Args: + track: An int representing the ground-truth track id. + """ + if track is not None: + self._gt_track_id = torch.tensor([track]) + else: + self._gt_track_id = torch.tensor([]) + + def has_gt_track_id(self) -> bool: + """Determine if instance has a gt track assignment. + + Returns: + True if the gt track id is set, otherwise False. + """ + if self._gt_track_id.shape[0] == 0: + return False + else: + return True + + @property + def pred_track_id(self) -> torch.Tensor: + """The track id predicted by the tracker using asso_output from model. + + Returns: + A tensor containing the predicted track id. 
+ """ + return self._pred_track_id + + @pred_track_id.setter + def pred_track_id(self, track: int) -> None: + """Set predicted track id. + + Args: + track: an int representing the predicted track id. + """ + if track is not None: + self._pred_track_id = torch.tensor([track]) + else: + self._pred_track_id = torch.tensor([]) + + def has_pred_track_id(self) -> bool: + """Determine whether instance has predicted track id. + + Returns: + True if instance has a pred track id, False otherwise. + """ + if self._pred_track_id.item() == -1 or self._pred_track_id.shape[0] == 0: + return False + else: + return True + + @property + def bbox(self) -> torch.Tensor: + """The bounding box coordinates of the instance in the original frame. + + Returns: + A (1,4) tensor containing the bounding box coordinates. + """ + return self._bbox + + @bbox.setter + def bbox(self, bbox: ArrayLike) -> None: + """Set the instance bounding box. + + Args: + bbox: an arraylike object containing the bounding box coordinates. + """ + if bbox is None or len(bbox) == 0: + self._bbox = torch.empty((0, 4)) + else: + if not isinstance(bbox, torch.Tensor): + self._bbox = torch.tensor(bbox) + else: + self._bbox = bbox + + if self._bbox.shape[0] and len(self._bbox.shape) == 1: + self._bbox = self._bbox.unsqueeze(0) + if self._bbox.shape[1] and len(self._bbox.shape) == 2: + self._bbox = self._bbox.unsqueeze(0) + + def has_bbox(self) -> bool: + """Determine if the instance has a bbox. + + Returns: + True if the instance has a bounding box, false otherwise. + """ + if self._bbox.shape[1] == 0: + return False + else: + return True + + @property + def centroid(self) -> dict[str, ArrayLike]: + """The centroid around which the crop was formed. + + Returns: + A dict containing the anchor name and the x, y bbox midpoint. + """ + return self._centroid + + @centroid.setter + def centroid(self, centroid: dict[str, ArrayLike]) -> None: + """Set the centroid of the instance. + + Args: + centroid: A dict containing the anchor name and points. + """ + self._centroid = centroid + + @property + def anchor(self) -> list[str]: + """The anchor node name around which the crop was formed. + + Returns: + the list of anchors around which each crop was formed + the list of anchors around which each crop was formed + """ + if self.centroid: + return list(self.centroid.keys()) + return "" + + @property + def crop(self) -> torch.Tensor: + """The crop of the instance. + + Returns: + A (1, c, h , w) tensor containing the cropped image centered around the instance. + """ + return self._crop + + @crop.setter + def crop(self, crop: ArrayLike) -> None: + """Set the crop of the instance. + + Args: + crop: an arraylike object containing the cropped image of the centered instance. + """ + if crop is None or len(crop) == 0: + self._crop = torch.tensor([]) + else: + if not isinstance(crop, torch.Tensor): + self._crop = torch.tensor(crop) + else: + self._crop = crop + + if len(self._crop.shape) == 2: + self._crop = self._crop.unsqueeze(0) + if len(self._crop.shape) == 3: + self._crop = self._crop.unsqueeze(0) + + def has_crop(self) -> bool: + """Determine if the instance has a crop. + + Returns: + True if the instance has an image otherwise False. + """ + if self._crop.shape[-1] == 0: + return False + else: + return True + + @property + def features(self) -> torch.Tensor: + """Re-ID feature vector from backbone model to be used as input to transformer. + + Returns: + a (1, d) tensor containing the reid feature vector. 
+ """ + return self._features + + @features.setter + def features(self, features: ArrayLike) -> None: + """Set the reid feature vector of the instance. + + Args: + features: a (1,d) array like object containing the reid features for the instance. + """ + if features is None or len(features) == 0: + self._features = torch.tensor([]) + + elif not isinstance(features, torch.Tensor): + self._features = torch.tensor(features) + else: + self._features = features + + if self._features.shape[0] and len(self._features.shape) == 1: + self._features = self._features.unsqueeze(0) + + def has_features(self) -> bool: + """Determine if the instance has computed reid features. + + Returns: + True if the instance has reid features, False otherwise. + """ + if self._features.shape[-1] == 0: + return False + else: + return True + + def has_embedding(self, emb_type: str = None) -> bool: + """Determine if the instance has embedding type requested. + + Args: + emb_type: The key to check in the embedding dictionary. + + Returns: + True if `emb_type` in embedding_dict else false + """ + return emb_type in self._embeddings + + def get_embedding( + self, emb_type: str = "all" + ) -> Union[dict[str, torch.Tensor], torch.Tensor, None]: + """Retrieve instance's spatial/temporal embedding. + + Args: + emb_type: The string key of the embedding to retrieve. Should be "pos", "temp" + + Returns: + * A torch tensor representing the spatial/temporal location of the instance. + * None if the embedding is not stored + """ + if emb_type.lower() == "all": + return self._embeddings + else: + try: + return self._embeddings[emb_type] + except KeyError: + print( + f"{emb_type} not saved! Only {list(self._embeddings.keys())} are available" + ) + return None + + def add_embedding(self, emb_type: str, embedding: torch.Tensor) -> None: + """Save embedding to instance embedding dictionary. + + Args: + emb_type: Key/embedding type to be saved to dictionary + embedding: The actual torch tensor embedding. + """ + embedding = _expand_to_rank(embedding, 2) + self._embeddings[emb_type] = embedding + + @property + def frame(self) -> "Frame": + """Get the frame the instance belongs to. + + Returns: + The back reference to the `Frame` that this `Instance` belongs to. + """ + return self._frame + + @frame.setter + def frame(self, frame: "Frame") -> None: + """Set the back reference to the `Frame` that this `Instance` belongs to. + + This field is set when instances are added to `Frame` object. + + Args: + frame: A `Frame` object containing the metadata for the frame that the instance belongs to + """ + self._frame = frame + + @property + def pose(self) -> dict[str, ArrayLike]: + """Get the pose of the instance. + + Returns: + A dictionary containing the node and corresponding x,y points + """ + return self._pose + + @pose.setter + def pose(self, pose: dict[str, ArrayLike]) -> None: + """Set the pose of the instance. + + Args: + pose: A nodes x 2 array containing the pose coordinates. + """ + if pose is not None: + self._pose = pose + + elif self.bbox.shape[0]: + y1, x1, y2, x2 = self.bbox.squeeze() + self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])} + + else: + self._pose = {} + + def has_pose(self) -> bool: + """Check if the instance has a pose. + + Returns True if the instance has a pose. + """ + if len(self.pose): + return True + return False + + @property + def shown_pose(self) -> dict[str, ArrayLike]: + """Get the pose with shown nodes only. + + Returns: A dictionary filtered by nodes that are shown (points are not nan). 
+ """ + pose = self.pose + return {node: point for node, point in pose.items() if not np.isna(point).any()} + + @property + def skeleton(self) -> sio.Skeleton: + """Get the skeleton associated with the instance. + + Returns: The sio.Skeleton associated with the instance. + """ + return self._skeleton + + @skeleton.setter + def skeleton(self, skeleton: sio.Skeleton) -> None: + """Set the skeleton associated with the instance. + + Args: + skeleton: The sio.Skeleton associated with the instance. + """ + self._skeleton = skeleton + + @property + def point_scores(self) -> ArrayLike: + """Get the point scores associated with the pose prediction. + + Returns: a vector of shape n containing the point scores outputed from sleap associated with pose predictions. + """ + return self._point_scores + + @point_scores.setter + def point_scores(self, point_scores: ArrayLike) -> None: + """Set the point scores associated with the pose prediction. + + Args: + point_scores: a vector of shape n containing the point scores + outputted from sleap associated with pose predictions. + """ + self._point_scores = point_scores + + @property + def instance_score(self) -> float: + """Get the pose prediction score associated with the instance. + + Returns: a float from 0-1 representing an instance_score. + """ + return self._instance_score + + @instance_score.setter + def instance_score(self, instance_score: float) -> None: + """Set the pose prediction score associated with the instance. + + Args: + instance_score: a float from 0-1 representing an instance_score. + """ + self._instance_score = instance_score + + @property + def track_score(self) -> float: + """Get the track_score of the instance. + + Returns: A float from 0-1 representing the output used in the tracker for assignment. + """ + return self._track_score + + @track_score.setter + def track_score(self, track_score: float) -> None: + """Set the track_score of the instance. + + Args: + track_score: A float from 0-1 representing the output used in the tracker for assignment. + """ + self._track_score = track_score diff --git a/biogtr/io/track.py b/biogtr/io/track.py new file mode 100644 index 00000000..8ee27f3f --- /dev/null +++ b/biogtr/io/track.py @@ -0,0 +1,94 @@ +"""Module containing data structures for storing instances of the same Track.""" + +import attrs +from typing import Union + + +@attrs.define(eq=False) +class Track: + """Object for storing instances of the same track. + + Attributes: + id: the track label. + instances: A list of instances belonging to the track. + """ + + _id: int = attrs.field(alias="id") + _instances: list["Instance"] = attrs.field(alias="instances", factory=list) + + def __repr__(self) -> str: + """Get the string representation of the track. + + Returns: + the string representation of the Track. + """ + return f"Track(id={self.id}, len={len(self)})" + + @property + def track_id(self) -> int: + """Get the id of the track. + + Returns: + The integer id of the track. + """ + return self._id + + @track_id.setter + def track_id(self, track_id: int) -> None: + """Set the id of the track. + + Args: + track_id: the int id of the track. + """ + self._id = track_id + + @property + def instances(self) -> list["Instances"]: + """Get the instances belonging to this track. + + Returns: + A list of instances with this track id. + """ + return self._instances + + @instances.setter + def instances(self, instances) -> None: + """Set the instances belonging to this track. + + Args: + instances: A list of instances that belong to the same track. 
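+
+        Example (illustrative; both instances are assumed to share a pred track id):
+            >>> track = Track(id=0)
+            >>> track.instances = [Instance(pred_track_id=0), Instance(pred_track_id=0)]
+            >>> assert len(track) == 2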
+ """ + self._instances = instances + + @property + def frames(self) -> set["Frame"]: + """Get the frames where this track appears. + + Returns: + A set of `Frame` objects where this track appears. + """ + return set([instance.frame for instance in self.instances]) + + def __len__(self) -> int: + """Get the length of the track. + + Returns: + The number of instances/frames in the track. + """ + return len(self.instances) + + def __getitem__(self, ind) -> Union["Instance", list["Instance"]]: + """Get an instance from the track. + + Args: + ind: Either a single int or list of int indices. + + Returns: + the instance at that index of the track.instances. + """ + if isinstance(ind, int): + return self.instances[ind] + elif isinstance(ind, list[int]): + return [self.instances[i] for i in ind] + else: + raise ValueError(f"Ind must be an int or list of ints, found {type(ind)}") diff --git a/biogtr/visualize.py b/biogtr/io/visualize.py similarity index 100% rename from biogtr/visualize.py rename to biogtr/io/visualize.py diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index 59206c17..b556ee1e 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -2,7 +2,7 @@ from biogtr.models.transformer import Transformer from biogtr.models.visual_encoder import VisualEncoder -from biogtr.data_structures import Instance +from biogtr.io.instance import Instance import torch # todo: do we want to handle params with configs already here? @@ -81,7 +81,7 @@ def __init__( def forward( self, ref_instances: list[Instance], query_instances: list[Instance] = None - ): + ) -> list["AssociationMatrix"]: """Execute forward pass of GTR Model to get asso matrix. Args: @@ -97,9 +97,9 @@ def forward( if query_instances: self.extract_features(query_instances) - asso_preds, emb = self.transformer(ref_instances, query_instances) + asso_preds = self.transformer(ref_instances, query_instances) - return asso_preds, emb + return asso_preds def extract_features( self, instances: list["Instance"], force_recompute: bool = False diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index 98955d23..b37ed718 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -8,7 +8,8 @@ from biogtr.training.losses import AssoLoss from biogtr.models.model_utils import init_optimizer, init_scheduler from pytorch_lightning import LightningModule -from biogtr.data_structures import Frame, Instance +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance class GTRRunner(LightningModule): @@ -72,7 +73,7 @@ def forward( Returns: An association matrix between objects """ - asso_preds, _ = self.model(ref_instances, query_instances) + asso_preds = self.model(ref_instances, query_instances) return asso_preds def training_step( @@ -165,6 +166,7 @@ def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]: persistent_tracking = self.persistent_tracking[mode] logits = self(instances) + logits = [asso.matrix for asso in logits] loss = self.loss(logits, frames) return_metrics = {"loss": loss} diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index d92a822b..11b3f168 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -2,7 +2,7 @@ from typing import List, Tuple, Iterable from pytorch_lightning import loggers -from biogtr.data_structures import Instance +from biogtr.io.instance import Instance import torch diff --git 
a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 4951c3e5..68c499a7 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -11,7 +11,8 @@ * added fixed embeddings over boxes """ -from biogtr.data_structures import Instance +from biogtr.io.instance import Instance +from biogtr.io.association_matrix import AssociationMatrix from biogtr.models.attention_head import ATTWeightHead from biogtr.models.embedding import Embedding from biogtr.models.model_utils import get_boxes, get_times @@ -141,7 +142,7 @@ def _reset_parameters(self): def forward( self, ref_instances: list[Instance], query_instances: list[Instance] = None - ) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]: + ) -> list[AssociationMatrix]: """Execute a forward pass through the transformer and attention head. Args: @@ -153,8 +154,6 @@ def forward( L: number of decoder blocks n_query: number of instances in current query/frame total_instances: number of instances in window - embedding_dict: A dictionary containing the "pos" and "temp" embeddings - if `self.return_embeddings` is False then they are None. """ ref_features = torch.cat( [instance.features for instance in ref_instances], dim=0 @@ -164,10 +163,6 @@ def forward( # instances_per_frame = [frame.num_detected for frame in frames] total_instances = len(ref_instances) embed_dim = ref_features.shape[-1] - embeddings_dict = { - "ref": {"pos": None, "temp": None}, - "query": {"pos": None, "temp": None}, - } # print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}') ref_boxes = get_boxes(ref_instances) # total_instances, 4 ref_boxes = torch.nan_to_num(ref_boxes, -1.0) @@ -176,12 +171,13 @@ def forward( window_length = len(ref_times.unique()) ref_temp_emb = self.temp_emb(ref_times / window_length) - if self.return_embedding: - embeddings_dict["ref"]["temp"] = ref_temp_emb ref_pos_emb = self.pos_emb(ref_boxes) + if self.return_embedding: - embeddings_dict["ref"]["pos"] = ref_pos_emb + for i, instance in enumerate(ref_instances): + instance.add_embedding("pos", ref_pos_emb[i]) + instance.add_embedding("temp", ref_temp_emb[i]) ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0 @@ -222,18 +218,21 @@ def forward( query_boxes = get_boxes(query_instances) query_temp_emb = self.temp_emb(query_times / window_length) - if self.return_embedding: - embeddings_dict["query"]["temp"] = query_temp_emb query_pos_emb = self.pos_emb(query_boxes) - if self.return_embedding: - embeddings_dict["query"]["pos"] = query_pos_emb query_emb = (query_pos_emb + query_temp_emb) / 2.0 query_emb = query_emb.view(1, n_query, embed_dim) query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim) + else: + query_instances = ref_instances + + if self.return_embedding: + for i, instance in enumerate(query_instances): + instance.add_embedding("pos", query_pos_emb[i]) + instance.add_embedding("temp", query_temp_emb[i]) decoder_features = self.decoder( query_features, @@ -251,16 +250,15 @@ def forward( asso_output = [] for frame_features in decoder_features: - # x: (batch_size=1, n_query, embed_dim=512) - - asso_output.append( - self.attn_head(frame_features, encoder_features).view( - n_query, total_instances - ) + asso_matrix = self.attn_head(frame_features, encoder_features).view( + n_query, total_instances ) + asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances) + + asso_output.append(asso_matrix) # (L=1, n_query, total_instances) - return (asso_output, embeddings_dict) + return asso_output 
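With this change `Transformer.forward` (and therefore `GlobalTrackingTransformer.forward`) returns a list of `AssociationMatrix` objects, one per decoder layer, instead of an `(asso_output, embeddings_dict)` tuple. A minimal consumer sketch (illustrative; assumes `model` is a `GlobalTrackingTransformer`, `ref_instances` is a list of `Instance` objects with crops, and predicted track ids have been assigned before grouping):

    asso_preds = model(ref_instances)              # list[AssociationMatrix]
    logits = [asso.matrix for asso in asso_preds]  # raw (n_query, n_ref) score tensors
    final = asso_preds[-1]                         # scores from the last decoder layer
    traj_df = final.reduce(col_grouping="pred")    # aggregate columns by predicted track id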
class TransformerEncoderLayer(nn.Module): diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py index b6f1d5e8..a9e721cd 100644 --- a/biogtr/training/losses.py +++ b/biogtr/training/losses.py @@ -1,6 +1,6 @@ """Module containing different loss functions to be optimized.""" -from biogtr.data_structures import Frame +from biogtr.io.frame import Frame from biogtr.models.model_utils import get_boxes, get_times from torch import nn from typing import List, Tuple diff --git a/biogtr/training/train.py b/biogtr/training/train.py index de4617f1..ac98b305 100644 --- a/biogtr/training/train.py +++ b/biogtr/training/train.py @@ -3,7 +3,7 @@ Used for training a single model or deploying a batch train job on RUNAI CLI """ -from biogtr.config import Config +from biogtr.io.config import Config from biogtr.datasets.tracking_dataset import TrackingDataset from biogtr.datasets.data_utils import view_training_batch from multiprocessing import cpu_count diff --git a/tests/test_config.py b/tests/test_config.py index aa10e040..1d8c64cd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,7 @@ """Tests for `config.py`""" from omegaconf import OmegaConf -from biogtr.config import Config +from biogtr.io.config import Config from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer from biogtr.models.gtr_runner import GTRRunner diff --git a/tests/test_data_structures.py b/tests/test_data_model.py similarity index 56% rename from tests/test_data_structures.py rename to tests/test_data_model.py index 4b91b889..cbe5aa8c 100644 --- a/tests/test_data_structures.py +++ b/tests/test_data_model.py @@ -1,8 +1,13 @@ -"""Tests for Instance, Frame, and TrackQueue Object""" +"""Tests for Instance, Frame, and AssociationMatrix Objects""" -from biogtr.data_structures import Instance, Frame -from biogtr.inference.track_queue import TrackQueue +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance +from biogtr.io.association_matrix import AssociationMatrix +from biogtr.io.track import Track import torch +import pytest +import numpy as np +import pandas as pd def test_instance(): @@ -68,7 +73,6 @@ def test_frame(): video_id = 0 frame_id = 0 img_shape = torch.tensor([3, 1024, 1024]) - asso_output = torch.randn(n_detected, 16) traj_score = torch.randn(n_detected, n_traj) matches = ([0, 1], [0, 1]) @@ -101,14 +105,15 @@ def test_frame(): assert not frame.has_asso_output() assert not frame.has_traj_score() - frame.asso_output = asso_output + asso_output = torch.randn(len(instances), len(instances)) + frame.asso_output = AssociationMatrix(asso_output, instances, instances) frame.add_traj_score("initial", traj_score) frame.matches = matches assert frame.has_matches() assert frame.matches == matches assert frame.has_asso_output() - assert torch.equal(frame.asso_output, asso_output) + assert torch.equal(frame.asso_output.matrix, asso_output) assert frame.has_traj_score() assert torch.equal(frame.get_traj_score("initial"), traj_score) @@ -127,79 +132,79 @@ def test_frame(): assert frame.has_traj_score() -def test_track_queue(): - window_size = 8 - max_gap = 10 - img_shape = (3, 1024, 1024) - n_instances_per_frame = [2] * window_size - - frames = [] - instances_per_frame = [] - - tq = TrackQueue(window_size, max_gap) - for i in range(window_size): - instances = [] - for j in range(n_instances_per_frame[i]): - instances.append(Instance(gt_track_id=j, pred_track_id=j)) - instances_per_frame.append(instances) - frame = Frame(video_id=0, frame_id=i, img_shape=img_shape, 
instances=instances) - frames.append(frame) - - tq.add_frame(frame) - - assert len(tq) == sum(n_instances_per_frame[1:]) - assert tq.n_tracks == max(n_instances_per_frame) - assert tq.tracks == [i for i in range(max(n_instances_per_frame))] - assert len(tq.collate_tracks()) == window_size - 1 - assert all([gap == 0 for gap in tq._curr_gap.values()]) - assert tq.curr_track == max(n_instances_per_frame) - 1 - - tq.add_frame( - Frame( - video_id=0, - frame_id=window_size + 1, - img_shape=img_shape, - instances=[Instance(gt_track_id=0, pred_track_id=0)], - ) - ) +def test_association_matrix(): - assert len(tq._queues[0]) == window_size - 1 - assert tq._curr_gap[0] == 0 - assert tq._curr_gap[max(n_instances_per_frame) - 1] == 1 - - tq.add_frame( - Frame( - video_id=0, - frame_id=window_size + 1, - img_shape=img_shape, - instances=[ - Instance(gt_track_id=1, pred_track_id=1), - Instance( - gt_track_id=max(n_instances_per_frame), - pred_track_id=max(n_instances_per_frame), - ), - ], - ) + n_traj = 2 + total_instances = 32 + n_query = 2 + + instances = [ + Instance(gt_track_id=i % n_traj, pred_track_id=i % n_traj) + for i in range(total_instances) + ] + + query_instances = instances[-n_query:] + asso_tensor = np.random.rand(total_instances, total_instances) + query_tensor = np.random.rand(n_query, total_instances) + + with pytest.raises(ValueError): + _ = AssociationMatrix(asso_tensor, instances, query_instances) + _ = AssociationMatrix(asso_tensor, query_instances, query_instances) + + asso_matrix = AssociationMatrix(asso_tensor, instances, instances) + + assert isinstance(asso_matrix.numpy(), np.ndarray) + + asso_lookup = asso_matrix[instances[0], instances[0]] + assert asso_lookup.item() == asso_tensor[0, 0].item() + + inds = (-1, -1) + asso_lookup = asso_matrix[inds] + assert asso_lookup.item() == asso_tensor[-1, -1].item() + + inds = (instances[:2], instances[:-2]) + asso_lookup = asso_matrix[inds] + assert np.equal(asso_lookup, asso_tensor[:2, :-2]).all() + + inds = ([2, 3], [2, 3]) + asso_lookup = asso_matrix[inds] + assert np.equal(asso_lookup, asso_tensor[np.array(inds[0])[:, None], inds[1]]).all() + + asso_lookup = asso_matrix[instances[:2], None] + assert np.equal(asso_lookup, asso_tensor[:2, :]).all() + + with pytest.raises(ValueError): + _ = AssociationMatrix(query_tensor, instances, instances) + _ = AssociationMatrix(query_tensor, query_instances, instances) + _ = AssociationMatrix(query_tensor, query_instances, query_instances) + + query_matrix = AssociationMatrix(query_tensor, instances, query_instances) + + with pytest.raises(ValueError): + _ = query_matrix[instances[0], instances[0]] + + query_lookup = query_matrix[query_instances[0], instances[0]] + assert query_lookup.item() == query_tensor[0, 0].item() + + traj_score = pd.concat( + [ + query_matrix.to_dataframe(row_labels="inst").drop(1, axis=1).sum(1), + query_matrix.to_dataframe(row_labels="inst").drop(0, axis=1).sum(1), + ], + axis=1, ) + assert (query_matrix.reduce() == traj_score).all().all() - assert len(tq._queues[max(n_instances_per_frame)]) == 1 - assert tq._curr_gap[1] == 0 - assert tq._curr_gap[0] == 1 - - for i in range(max_gap): - tq.add_frame( - Frame( - video_id=0, - frame_id=window_size + i + 1, - img_shape=img_shape, - instances=[Instance(gt_track_id=0, pred_track_id=0)], - ) - ) - assert tq.n_tracks == 1 - assert tq.curr_track == max(n_instances_per_frame) - assert 0 in tq._queues.keys() +def test_track(): + + instances = [Instance(gt_track_id=0, pred_track_id=0) for i in range(32)] + + track = Track(0, 
instances=instances) + + assert track.track_id == 0 + assert len(track) == len(instances) - tq.end_tracks() + instance = track[1] - assert len(tq) == 0 + assert instance is instances[1] diff --git a/tests/test_inference.py b/tests/test_inference.py index c37470be..66c5ac61 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -3,13 +3,93 @@ import torch import pytest import numpy as np -from biogtr.data_structures import Frame, Instance +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer from biogtr.inference.tracker import Tracker +from biogtr.inference.track_queue import TrackQueue from biogtr.inference import post_processing from biogtr.inference import metrics +def test_track_queue(): + window_size = 8 + max_gap = 10 + img_shape = (3, 1024, 1024) + n_instances_per_frame = [2] * window_size + + frames = [] + instances_per_frame = [] + + tq = TrackQueue(window_size, max_gap) + for i in range(window_size): + instances = [] + for j in range(n_instances_per_frame[i]): + instances.append(Instance(gt_track_id=j, pred_track_id=j)) + instances_per_frame.append(instances) + frame = Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) + frames.append(frame) + + tq.add_frame(frame) + + assert len(tq) == sum(n_instances_per_frame[1:]) + assert tq.n_tracks == max(n_instances_per_frame) + assert tq.tracks == [i for i in range(max(n_instances_per_frame))] + assert len(tq.collate_tracks()) == window_size - 1 + assert all([gap == 0 for gap in tq._curr_gap.values()]) + assert tq.curr_track == max(n_instances_per_frame) - 1 + + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + 1, + img_shape=img_shape, + instances=[Instance(gt_track_id=0, pred_track_id=0)], + ) + ) + + assert len(tq._queues[0]) == window_size - 1 + assert tq._curr_gap[0] == 0 + assert tq._curr_gap[max(n_instances_per_frame) - 1] == 1 + + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + 1, + img_shape=img_shape, + instances=[ + Instance(gt_track_id=1, pred_track_id=1), + Instance( + gt_track_id=max(n_instances_per_frame), + pred_track_id=max(n_instances_per_frame), + ), + ], + ) + ) + + assert len(tq._queues[max(n_instances_per_frame)]) == 1 + assert tq._curr_gap[1] == 0 + assert tq._curr_gap[0] == 1 + + for i in range(max_gap): + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + i + 1, + img_shape=img_shape, + instances=[Instance(gt_track_id=0, pred_track_id=0)], + ) + ) + + assert tq.n_tracks == 1 + assert tq.curr_track == max(n_instances_per_frame) + assert 0 in tq._queues.keys() + + tq.end_tracks() + + assert len(tq) == 0 + + def test_tracker(): """Test tracker module. 
diff --git a/tests/test_models.py b/tests/test_models.py index 08412e49..0c125d24 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,7 +2,8 @@ import pytest import torch -from biogtr.data_structures import Frame, Instance +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance from biogtr.models.mlp import MLP from biogtr.models.attention_head import ATTWeightHead from biogtr.models.embedding import Embedding @@ -415,9 +416,9 @@ def test_transformer_basic(): ) instances = [instance for frame in frames for instance in frame.instances] - asso_preds, _ = transformer(instances) + asso_preds = transformer(instances) - assert asso_preds[0].size() == (num_detected * num_frames,) * 2 + assert asso_preds[0].matrix.size() == (num_detected * num_frames,) * 2 def test_transformer_embedding(): @@ -457,15 +458,26 @@ def test_transformer_embedding(): assert transformer.pos_emb.mode == "learned" assert transformer.temp_emb.mode == "learned" - asso_preds, embeddings = transformer(instances) + asso_preds = transformer(instances) - assert asso_preds[0].size() == (num_detected * num_frames,) * 2 + assert asso_preds[0].matrix.size() == (num_detected * num_frames,) * 2 - for emb_type, embedding in embeddings["ref"].items(): - assert embedding.size() == ( - num_detected * num_frames, - feats, - ), f"{emb_type}, {embedding.size()}" + pos_emb = torch.concat( + [instance.get_embedding("pos") for instance in instances], axis=0 + ) + temp_emb = torch.concat( + [instance.get_embedding("pos") for instance in instances], axis=0 + ) + + assert pos_emb.size() == ( + len(instances), + feats, + ), pos_emb.shape + + assert temp_emb.size() == ( + len(instances), + feats, + ), temp_emb.shape def test_tracking_transformer(): @@ -510,12 +522,23 @@ def test_tracking_transformer(): return_embedding=True, ) instances = [instance for frame in frames for instance in frame.instances] - asso_preds, embeddings = tracking_transformer(instances) + asso_preds = tracking_transformer(instances) + + assert asso_preds[0].matrix.size() == (num_detected * num_frames,) * 2 + + pos_emb = torch.concat( + [instance.get_embedding("pos") for instance in instances], axis=0 + ) + temp_emb = torch.concat( + [instance.get_embedding("pos") for instance in instances], axis=0 + ) - assert asso_preds[0].size() == (num_detected * num_frames,) * 2 + assert pos_emb.size() == ( + len(instances), + feats, + ), pos_emb.shape - for emb_type, embedding in embeddings["ref"].items(): - assert embedding.size() == ( - num_detected * num_frames, - feats, - ), embeddings + assert temp_emb.size() == ( + len(instances), + feats, + ), temp_emb.shape diff --git a/tests/test_training.py b/tests/test_training.py index b15852a1..f5331a58 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -3,12 +3,13 @@ import os import pytest import torch -from biogtr.data_structures import Frame, Instance +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance from biogtr.training.losses import AssoLoss from biogtr.models.gtr_runner import GTRRunner from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer from omegaconf import OmegaConf, DictConfig -from biogtr.config import Config +from biogtr.io.config import Config from biogtr.training.train import main # TODO: add named tensor tests From f8a33df75f836e72e50576a9ddf16b999cb48974 Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:38:12 -0700 Subject: [PATCH 3/7] Remove kwargs and 
mutable defaults (#48) --- biogtr/inference/tracker.py | 3 +- biogtr/io/association_matrix.py | 5 +-- biogtr/io/config.py | 15 +++++-- biogtr/models/global_tracking_transformer.py | 1 - biogtr/models/gtr_runner.py | 43 +++++++++++++------- 5 files changed, 42 insertions(+), 25 deletions(-) diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 213914c2..5f6fce84 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -128,8 +128,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frames: A list of Frames (See `biogtr.io.data_structures.Frame` for more info). - + frames: A list of Frames (See `biogtr.io.Frame` for more info). Returns: Frames: A list of Frames populated with pred_track_ids and asso_matrices diff --git a/biogtr/io/association_matrix.py b/biogtr/io/association_matrix.py index a5f4a3a7..1447249c 100644 --- a/biogtr/io/association_matrix.py +++ b/biogtr/io/association_matrix.py @@ -173,7 +173,7 @@ def reduce( Either "instance" (remains unchanged), or "track" (n_cols=n_traj) row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt". col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt". - method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing. + reduce_method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing. Returns: The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe. @@ -199,7 +199,6 @@ def reduce( reduced_matrix = [] for row_track, row_instances in row_tracks.items(): - for col_track, col_instances in col_tracks.items(): asso_matrix = self[row_instances, col_instances] @@ -208,7 +207,6 @@ def reduce( if row_dims == "track": asso_matrix = reduce_method(asso_matrix, axis=0) - reduced_matrix.append(asso_matrix) reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T @@ -234,7 +232,6 @@ def __getitem__(self, inds) -> np.ndarray: try: return self.numpy()[query_ind[:, None], ref_ind].squeeze() - except IndexError as e: print(f"Query_insts: {type(query_inst)}") print(f"Query_inds: {query_ind}") diff --git a/biogtr/io/config.py b/biogtr/io/config.py index f30d3128..2d83be38 100644 --- a/biogtr/io/config.py +++ b/biogtr/io/config.py @@ -78,6 +78,11 @@ def get_model(self) -> GlobalTrackingTransformer: A global tracking transformer with parameters indicated by cfg """ model_params = self.cfg.model + ckpt_path = model_params.pop("ckpt_path", None) + + if ckpt_path is not None and len(ckpt_path) > 0: + return GTRRunner.load_from_checkpoint(ckpt_path).model + return GlobalTrackingTransformer(**model_params) def get_tracker_cfg(self) -> dict: @@ -100,9 +105,14 @@ def get_gtr_runner(self): loss_params = self.cfg.loss gtr_runner_params = self.cfg.runner - if self.cfg.model.ckpt_path is not None and self.cfg.model.ckpt_path != "": + model_params = self.cfg.model + + ckpt_path = model_params.pop("ckpt_path", None) + + if ckpt_path is not None and ckpt_path != "": + model = GTRRunner.load_from_checkpoint( - self.cfg.model.ckpt_path, + ckpt_path, tracker_cfg=tracker_params, train_metrics=self.cfg.runner.metrics.train, val_metrics=self.cfg.runner.metrics.val, @@ -110,7 +120,6 @@ def get_gtr_runner(self): ) else: - model_params = self.cfg.model model = GTRRunner( model_params, tracker_params, diff --git 
a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index b556ee1e..4a74d6b5 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -27,7 +27,6 @@ def __init__( embedding_meta: dict = None, return_embedding: bool = False, decoder_self_attn: bool = False, - **kwargs, ): """Initialize GTR. diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index b37ed718..fc9afeb4 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -18,23 +18,26 @@ class GTRRunner(LightningModule): Used for training, validation and inference. """ + DEFAULT_METRICS = { + "train": [], + "val": ["num_switches"], + "test": ["num_switches"], + } + DEFAULT_TRACKING = { + "train": False, + "val": True, + "test": True, + } + def __init__( self, - model_cfg: dict = {}, - tracker_cfg: dict = {}, - loss_cfg: dict = {}, + model_cfg: dict = None, + tracker_cfg: dict = None, + loss_cfg: dict = None, optimizer_cfg: dict = None, scheduler_cfg: dict = None, - metrics: dict[str, list[str]] = { - "train": [], - "val": ["num_switches"], - "test": ["num_switches"], - }, - persistent_tracking: dict[str, bool] = { - "train": False, - "val": True, - "test": True, - }, + metrics: dict[str, list[str]] = None, + persistent_tracking: dict[str, bool] = None, ): """Initialize a lightning module for GTR. @@ -51,6 +54,11 @@ def __init__( super().__init__() self.save_hyperparameters() + model_cfg = model_cfg if model_cfg else {} + loss_cfg = loss_cfg if loss_cfg else {} + tracker_cfg = tracker_cfg if tracker_cfg else {} + + _ = model_cfg.pop("ckpt_path", None) self.model = GlobalTrackingTransformer(**model_cfg) self.loss = AssoLoss(**loss_cfg) self.tracker = Tracker(**tracker_cfg) @@ -58,8 +66,12 @@ def __init__( self.optimizer_cfg = optimizer_cfg self.scheduler_cfg = scheduler_cfg - self.metrics = metrics - self.persistent_tracking = persistent_tracking + self.metrics = metrics if metrics is not None else self.DEFAULT_METRICS + self.persistent_tracking = ( + persistent_tracking + if persistent_tracking is not None + else self.DEFAULT_TRACKING + ) def forward( self, ref_instances: list[Instance], query_instances: list[Instance] = None @@ -159,6 +171,7 @@ def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]: """ try: instances = [instance for frame in frames for instance in frame.instances] + if len(instances) == 0: return None From 1ad8a478f195b60a831b0a9beb86f9b1a2420b50 Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:51:25 -0700 Subject: [PATCH 4/7] Expose modules (#49) --- biogtr/__init__.py | 17 +++- biogtr/datasets/__init__.py | 7 ++ biogtr/datasets/base_dataset.py | 2 +- biogtr/datasets/cell_tracking_dataset.py | 6 +- biogtr/datasets/eval_dataset.py | 3 +- biogtr/datasets/microscopy_dataset.py | 6 +- biogtr/datasets/sleap_dataset.py | 6 +- biogtr/inference/__init__.py | 2 + biogtr/inference/metrics.py | 5 +- biogtr/inference/track.py | 7 +- biogtr/inference/track_queue.py | 2 +- biogtr/inference/tracker.py | 5 +- biogtr/io/__init__.py | 3 +- biogtr/io/association_matrix.py | 6 ++ biogtr/io/config.py | 35 +++++--- biogtr/io/frame.py | 3 +- biogtr/models/__init__.py | 9 +- biogtr/models/global_tracking_transformer.py | 7 +- biogtr/models/gtr_runner.py | 24 ++++-- biogtr/models/model_utils.py | 6 +- biogtr/models/transformer.py | 13 ++- biogtr/training/__init__.py | 2 + biogtr/training/losses.py | 3 +- 
biogtr/training/train.py | 8 +- tests/test_config.py | 5 +- tests/test_data_model.py | 5 +- tests/test_datasets.py | 12 +-- tests/test_inference.py | 87 ++++++++++++++++++-- tests/test_models.py | 15 ++-- tests/test_training.py | 11 +-- 30 files changed, 213 insertions(+), 109 deletions(-) diff --git a/biogtr/__init__.py b/biogtr/__init__.py index 6e9395e1..e4c823ef 100644 --- a/biogtr/__init__.py +++ b/biogtr/__init__.py @@ -1,7 +1,18 @@ """Top-level package for BioGTR.""" from biogtr.version import __version__ -from biogtr.models.attention_head import MLP, ATTWeightHead -from biogtr.models.visual_encoder import VisualEncoder -from biogtr.models.embedding import Embedding + +from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer +from biogtr.models.gtr_runner import GTRRunner from biogtr.models.transformer import Transformer +from biogtr.models.visual_encoder import VisualEncoder + +from biogtr.io.frame import Frame +from biogtr.io.instance import Instance +from biogtr.io.association_matrix import AssociationMatrix +from biogtr.io.config import Config +from biogtr.io.visualize import annotate_video + +# from .training import run + +from biogtr.inference.tracker import Tracker diff --git a/biogtr/datasets/__init__.py b/biogtr/datasets/__init__.py index 06f2e8bf..0a1fcc2f 100644 --- a/biogtr/datasets/__init__.py +++ b/biogtr/datasets/__init__.py @@ -1 +1,8 @@ """Data loading and preprocessing.""" + +from .base_dataset import BaseDataset +from .cell_tracking_dataset import CellTrackingDataset +from .eval_dataset import EvalDataset +from .microscopy_dataset import MicroscopyDataset +from .sleap_dataset import SleapDataset +from .tracking_dataset import TrackingDataset diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index 74ccf448..6f54719d 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -1,7 +1,7 @@ """Module containing logic for loading datasets.""" from biogtr.datasets import data_utils -from biogtr.io.frame import Frame +from biogtr.io import Frame from torch.utils.data import Dataset from typing import List, Union import numpy as np diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 65889526..33aa12fa 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -1,10 +1,8 @@ """Module containing cell tracking challenge dataset.""" from PIL import Image -from biogtr.datasets import data_utils -from biogtr.datasets.base_dataset import BaseDataset -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance +from biogtr.datasets import data_utils, BaseDataset +from biogtr.io import Frame, Instance from scipy.ndimage import measurements from typing import List, Optional, Union import albumentations as A diff --git a/biogtr/datasets/eval_dataset.py b/biogtr/datasets/eval_dataset.py index 6f52a8c9..95836a51 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/biogtr/datasets/eval_dataset.py @@ -1,8 +1,7 @@ """Module containing wrapper for merging gt and pred datasets for evaluation.""" from torch.utils.data import Dataset -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance +from biogtr.io import Instance, Frame from typing import List diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index acb50c3f..2c15cc1e 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -1,10 +1,8 @@ """Module 
containing microscopy dataset.""" from PIL import Image -from biogtr.datasets import data_utils -from biogtr.datasets.base_dataset import BaseDataset -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance +from biogtr.datasets import data_utils, BaseDataset +from biogtr.io import Instance, Frame from typing import Union import albumentations as A import numpy as np diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 24d0e1ec..6a934b4e 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -7,10 +7,8 @@ import sleap_io as sio import random import warnings -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance -from biogtr.datasets import data_utils -from biogtr.datasets.base_dataset import BaseDataset +from biogtr.io import Instance, Frame +from biogtr.datasets import data_utils, BaseDataset from torchvision.transforms import functional as tvf from typing import List, Union diff --git a/biogtr/inference/__init__.py b/biogtr/inference/__init__.py index c1c53dce..6615f0a9 100644 --- a/biogtr/inference/__init__.py +++ b/biogtr/inference/__init__.py @@ -1 +1,3 @@ """Tracking Inference using GTR Model.""" + +from .tracker import Tracker diff --git a/biogtr/inference/metrics.py b/biogtr/inference/metrics.py index ed961a96..c80c15c3 100644 --- a/biogtr/inference/metrics.py +++ b/biogtr/inference/metrics.py @@ -3,14 +3,13 @@ import numpy as np import motmetrics as mm import torch -from biogtr.io.frame import Frame from typing import Union, Iterable # from biogtr.inference.post_processing import _pairwise_iou # from biogtr.inference.boxes import Boxes -def get_matches(frames: list[Frame]) -> tuple[dict, list, int]: +def get_matches(frames: list["biogtr.io.Frame"]) -> tuple[dict, list, int]: """Get comparison between predicted and gt trajectory labels. Args: @@ -101,7 +100,7 @@ def get_switch_count(switches: dict) -> int: return sw_cnt -def to_track_eval(frames: list[Frame]) -> dict: +def to_track_eval(frames: list["biogtr.io.Frame"]) -> dict: """Reformats frames the output from `sliding_inference` to be used by `TrackEval`. Args: diff --git a/biogtr/inference/track.py b/biogtr/inference/track.py index 24b6fefa..c930cf62 100644 --- a/biogtr/inference/track.py +++ b/biogtr/inference/track.py @@ -1,8 +1,7 @@ """Script to run inference and get out tracks.""" -from biogtr.io.config import Config -from biogtr.models.gtr_runner import GTRRunner -from biogtr.io.frame import Frame +from biogtr.io import Config +from biogtr.models import GTRRunner from omegaconf import DictConfig from pathlib import Path from pprint import pprint @@ -14,7 +13,7 @@ import torch -def export_trajectories(frames_pred: list[Frame], save_path: str = None): +def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None): """Convert trajectories to data frame and save as .csv. 
Args: diff --git a/biogtr/inference/track_queue.py b/biogtr/inference/track_queue.py index 89d8991f..739869fb 100644 --- a/biogtr/inference/track_queue.py +++ b/biogtr/inference/track_queue.py @@ -1,7 +1,7 @@ """Module handling sliding window tracking.""" import warnings -from biogtr.io.frame import Frame +from biogtr.io import Frame from collections import deque import numpy as np diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index 5f6fce84..da1320be 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -3,9 +3,8 @@ import torch import pandas as pd import warnings -from biogtr.io.frame import Frame -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer -from biogtr.models import model_utils +from biogtr.io import Frame +from biogtr.models import model_utils, GlobalTrackingTransformer from biogtr.inference.track_queue import TrackQueue from biogtr.inference import post_processing from biogtr.inference.boxes import Boxes diff --git a/biogtr/io/__init__.py b/biogtr/io/__init__.py index eec945d8..0cda02ae 100644 --- a/biogtr/io/__init__.py +++ b/biogtr/io/__init__.py @@ -4,5 +4,4 @@ from biogtr.io.instance import Instance from biogtr.io.association_matrix import AssociationMatrix from biogtr.io.track import Track - -# TODO: expose config without circular import error from biogtr.io.config import Config +from biogtr.io.config import Config diff --git a/biogtr/io/association_matrix.py b/biogtr/io/association_matrix.py index 1447249c..9d6d366e 100644 --- a/biogtr/io/association_matrix.py +++ b/biogtr/io/association_matrix.py @@ -107,6 +107,7 @@ def to_dataframe( if not isinstance(row_labels, str): if len(row_labels) == len(self.query_instances): row_inds = row_labels + else: raise ValueError( ( @@ -114,6 +115,7 @@ def to_dataframe( f"Found {len(row_labels)} with {len(self.query_instances)} rows", ) ) + else: if row_labels == "gt": row_inds = [ @@ -131,6 +133,7 @@ def to_dataframe( if not isinstance(col_labels, str): if len(col_labels) == len(self.ref_instances): col_inds = col_labels + else: raise ValueError( ( @@ -138,6 +141,7 @@ def to_dataframe( f"Found {len(col_labels)} with {len(self.ref_instances)} columns", ) ) + else: if col_labels == "gt": col_inds = [ @@ -200,6 +204,7 @@ def reduce( reduced_matrix = [] for row_track, row_instances in row_tracks.items(): for col_track, col_instances in col_tracks.items(): + asso_matrix = self[row_instances, col_instances] if col_dims == "track": @@ -207,6 +212,7 @@ def reduce( if row_dims == "track": asso_matrix = reduce_method(asso_matrix, axis=0) + reduced_matrix.append(asso_matrix) reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T diff --git a/biogtr/io/config.py b/biogtr/io/config.py index 2d83be38..280a1b04 100644 --- a/biogtr/io/config.py +++ b/biogtr/io/config.py @@ -1,13 +1,6 @@ # to implement - config class that handles getters/setters """Data structures for handling config parsing.""" -from biogtr.datasets.microscopy_dataset import MicroscopyDataset -from biogtr.datasets.sleap_dataset import SleapDataset -from biogtr.models.model_utils import init_optimizer, init_scheduler, init_logger -from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer -from biogtr.models.gtr_runner import GTRRunner -from biogtr.training.losses import AssoLoss from omegaconf import DictConfig, OmegaConf from pprint import pprint from typing import Union, Iterable @@ -71,12 +64,14 @@ def 
set_hparams(self, hparams: dict) -> bool: return False return True - def get_model(self) -> GlobalTrackingTransformer: + def get_model(self) -> "GlobalTrackingTransformer": """Getter for gtr model. Returns: A global tracking transformer with parameters indicated by cfg """ + from biogtr.models import GlobalTrackingTransformer + model_params = self.cfg.model ckpt_path = model_params.pop("ckpt_path", None) @@ -99,18 +94,18 @@ def get_tracker_cfg(self) -> dict: def get_gtr_runner(self): """Get lightning module for training, validation, and inference.""" + from biogtr.models import GTRRunner + tracker_params = self.cfg.tracker optimizer_params = self.cfg.optimizer scheduler_params = self.cfg.scheduler loss_params = self.cfg.loss gtr_runner_params = self.cfg.runner - model_params = self.cfg.model ckpt_path = model_params.pop("ckpt_path", None) if ckpt_path is not None and ckpt_path != "": - model = GTRRunner.load_from_checkpoint( ckpt_path, tracker_cfg=tracker_params, @@ -133,7 +128,7 @@ def get_gtr_runner(self): def get_dataset( self, mode: str - ) -> Union[SleapDataset, MicroscopyDataset, CellTrackingDataset]: + ) -> Union["SleapDataset", "MicroscopyDataset", "CellTrackingDataset"]: """Getter for datasets. Args: @@ -143,6 +138,8 @@ def get_dataset( Returns: Either a `SleapDataset` or `MicroscopyDataset` with params indicated by cfg """ + from biogtr.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset + if mode.lower() == "train": dataset_params = self.cfg.dataset.train_dataset elif mode.lower() == "val": @@ -169,7 +166,7 @@ def get_dataset( def get_dataloader( self, - dataset: Union[SleapDataset, MicroscopyDataset, CellTrackingDataset], + dataset: Union["SleapDataset", "MicroscopyDataset", "CellTrackingDataset"], mode: str, ) -> torch.utils.data.DataLoader: """Getter for dataloader. @@ -217,7 +214,10 @@ def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer: Returns: A torch Optimizer with specified params """ + from biogtr.models.model_utils import init_optimizer + optimizer_params = self.cfg.optimizer + return init_optimizer(params, optimizer_params) def get_scheduler( @@ -231,16 +231,22 @@ def get_scheduler( Returns: A torch learning rate scheduler with specified params """ + from biogtr.models.model_utils import init_scheduler + lr_scheduler_params = self.cfg.scheduler + return init_scheduler(optimizer, lr_scheduler_params) - def get_loss(self) -> AssoLoss: + def get_loss(self) -> "biogtr.training.losses.AssoLoss": """Getter for loss functions. 
Returns: An AssoLoss with specified params """ + from biogtr.training.losses import AssoLoss + loss_params = self.cfg.loss + return AssoLoss(**loss_params) def get_logger(self): @@ -249,7 +255,10 @@ def get_logger(self): Returns: A Logger with specified params """ + from biogtr.models.model_utils import init_logger + logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True) + return init_logger( logger_params, OmegaConf.to_container(self.cfg, resolve=True) ) diff --git a/biogtr/io/frame.py b/biogtr/io/frame.py index 248c0159..803970bc 100644 --- a/biogtr/io/frame.py +++ b/biogtr/io/frame.py @@ -6,7 +6,6 @@ import attrs from numpy.typing import ArrayLike from typing import Union, List -from biogtr.io.instance import Instance def _to_tensor(data: Union[float, ArrayLike]) -> torch.Tensor: @@ -138,6 +137,8 @@ def from_slp( Returns: A biogtr.io.Frame object """ + from biogtr.io import Instance + img_shape = lf.image.shape if len(img_shape) == 2: img_shape = (1, *img_shape) diff --git a/biogtr/models/__init__.py b/biogtr/models/__init__.py index 8f6968c8..bc6be168 100644 --- a/biogtr/models/__init__.py +++ b/biogtr/models/__init__.py @@ -1,7 +1,12 @@ """Model architectures and layers.""" -from .attention_head import ATTWeightHead - from .embedding import Embedding + +# from .mlp import MLP +# from .attention_head import ATTWeightHead + from .transformer import Transformer from .visual_encoder import VisualEncoder + +from .global_tracking_transformer import GlobalTrackingTransformer +from .gtr_runner import GTRRunner diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index 4a74d6b5..c746d1aa 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -1,8 +1,7 @@ """Module containing GTR model used for training.""" -from biogtr.models.transformer import Transformer -from biogtr.models.visual_encoder import VisualEncoder -from biogtr.io.instance import Instance +from biogtr.models import Transformer +from biogtr.models import VisualEncoder import torch # todo: do we want to handle params with configs already here? @@ -79,7 +78,7 @@ def __init__( ) def forward( - self, ref_instances: list[Instance], query_instances: list[Instance] = None + self, ref_instances: list["Instance"], query_instances: list["Instance"] = None ) -> list["AssociationMatrix"]: """Execute forward pass of GTR Model to get asso matrix. diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index fc9afeb4..8ddfe787 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -2,9 +2,9 @@ import torch import gc -from biogtr.inference.tracker import Tracker +from biogtr.inference import Tracker from biogtr.inference import metrics -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer +from biogtr.models import GlobalTrackingTransformer from biogtr.training.losses import AssoLoss from biogtr.models.model_utils import init_optimizer, init_scheduler from pytorch_lightning import LightningModule @@ -74,7 +74,9 @@ def __init__( ) def forward( - self, ref_instances: list[Instance], query_instances: list[Instance] = None + self, + ref_instances: list["biogtr.io.Instance"], + query_instances: list["biogtr.io.Instance"] = None, ) -> torch.Tensor: """Execute forward pass of the lightning module. 
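
The net effect of the import changes in this patch is that downstream code can import from the package roots instead of deep module paths, while Config defers its model, dataset, and loss imports to call time, which is what allows Config to be re-exported from biogtr.io without the circular-import problem flagged in the old biogtr/io/__init__.py. A short sketch of the intended import surface, for usage illustration only:

# Sketch: package-root imports enabled by the "Expose modules" patch.
from biogtr.io import Frame, Instance, AssociationMatrix, Config
from biogtr.models import GlobalTrackingTransformer, GTRRunner
from biogtr.inference import Tracker

# Config only imports GlobalTrackingTransformer, the datasets, AssoLoss, etc.
# inside its getters, so importing biogtr.io no longer pulls in biogtr.models
# (and vice versa) at package import time.
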
@@ -89,7 +91,7 @@ def forward( return asso_preds def training_step( - self, train_batch: list[list[Frame]], batch_idx: int + self, train_batch: list[list["biogtr.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single training step for model. @@ -107,7 +109,7 @@ def training_step( return result def validation_step( - self, val_batch: list[list[Frame]], batch_idx: int + self, val_batch: list[list["biogtr.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single val step for model. @@ -125,7 +127,7 @@ def validation_step( return result def test_step( - self, test_batch: list[list[Frame]], batch_idx: int + self, test_batch: list[list["biogtr.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single test step for model. @@ -142,7 +144,9 @@ def test_step( return result - def predict_step(self, batch: list[list[Frame]], batch_idx: int) -> list[Frame]: + def predict_step( + self, batch: list[list["biogtr.io.Frame"]], batch_idx: int + ) -> list["biogtr.io.Frame"]: """Run inference for model. Computes association + assignment. @@ -159,11 +163,13 @@ def predict_step(self, batch: list[list[Frame]], batch_idx: int) -> list[Frame]: frames_pred = self.tracker(self.model, batch[0]) return frames_pred - def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]: + def _shared_eval_step( + self, frames: list["biogtr.io.Frame"], mode: str + ) -> dict[str, float]: """Run evaluation used by train, test, and val steps. Args: - frames: A list of `Frame` objects with length `clip_length` containing Instances and other metadata. + frames: A list of dicts where each dict is a frame containing gt data mode: which metrics to compute and whether to use persistent tracking or not Returns: diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py index 11b3f168..a2885a0f 100644 --- a/biogtr/models/model_utils.py +++ b/biogtr/models/model_utils.py @@ -2,11 +2,10 @@ from typing import List, Tuple, Iterable from pytorch_lightning import loggers -from biogtr.io.instance import Instance import torch -def get_boxes(instances: List[Instance]) -> torch.Tensor: +def get_boxes(instances: List["biogtr.io.Instance"]) -> torch.tensor: """Extract the bounding boxes from the input list of instances. Args: @@ -30,7 +29,8 @@ def get_boxes(instances: List[Instance]) -> torch.Tensor: def get_times( - ref_instances: list[Instance], query_instances: list[Instance] = None + ref_instances: list["biogtr.io.Instance"], + query_instances: list["biogtr.io.Instance"] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Extract the time indices of each instance relative to the window length. 
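
Continuing the sketch from the transformer section above (same ref_instances list, plus an optional query subset), these two helpers are the bridge between a flat instance list and the spatial/temporal embeddings. The shapes in the comments come from the annotations in the diffs, and the usage mirrors how Transformer.forward consumes them in the file diff that follows:

# Sketch: expected shapes for the model_utils helpers (reusing ref_instances from the earlier sketch).
from biogtr.models.model_utils import get_boxes, get_times

query_instances = ref_instances[-3:]  # e.g. the instances from the last frame of the window

boxes = get_boxes(ref_instances)      # (total_instances, 4); NaN boxes are handled downstream
ref_times, query_times = get_times(ref_instances, query_instances)
window_length = len(ref_times.unique())  # number of distinct frames in the window
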
diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 68c499a7..58864e47 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -11,10 +11,9 @@ * added fixed embeddings over boxes """ -from biogtr.io.instance import Instance -from biogtr.io.association_matrix import AssociationMatrix +from biogtr.io import AssociationMatrix from biogtr.models.attention_head import ATTWeightHead -from biogtr.models.embedding import Embedding +from biogtr.models import Embedding from biogtr.models.model_utils import get_boxes, get_times from torch import nn import copy @@ -141,12 +140,14 @@ def _reset_parameters(self): raise (e) def forward( - self, ref_instances: list[Instance], query_instances: list[Instance] = None + self, + ref_instances: list["biogtr.io.Instance"], + query_instances: list["biogtr.io.Instance"] = None, ) -> list[AssociationMatrix]: """Execute a forward pass through the transformer and attention head. Args: - ref instances: A list of instance objects (See `biogtr.data_structures.Instance` for more info.) + ref instances: A list of instance objects (See `biogtr.io.Instance` for more info.) query_instances: An set of instances to be used as decoder queries. Returns: @@ -222,9 +223,7 @@ def forward( query_pos_emb = self.pos_emb(query_boxes) query_emb = (query_pos_emb + query_temp_emb) / 2.0 - query_emb = query_emb.view(1, n_query, embed_dim) - query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim) else: query_instances = ref_instances diff --git a/biogtr/training/__init__.py b/biogtr/training/__init__.py index 932d36ac..c4e96012 100644 --- a/biogtr/training/__init__.py +++ b/biogtr/training/__init__.py @@ -1 +1,3 @@ """Initialize training module.""" + +# from .train import train diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py index a9e721cd..c487092c 100644 --- a/biogtr/training/losses.py +++ b/biogtr/training/losses.py @@ -1,6 +1,5 @@ """Module containing different loss functions to be optimized.""" -from biogtr.io.frame import Frame from biogtr.models.model_utils import get_boxes, get_times from torch import nn from typing import List, Tuple @@ -35,7 +34,7 @@ def __init__( self.asso_weight = asso_weight def forward( - self, asso_preds: List[torch.Tensor], frames: List[Frame] + self, asso_preds: List[torch.Tensor], frames: List["Frame"] ) -> torch.Tensor: """Calculate association loss. diff --git a/biogtr/training/train.py b/biogtr/training/train.py index ac98b305..e252d3f5 100644 --- a/biogtr/training/train.py +++ b/biogtr/training/train.py @@ -3,8 +3,8 @@ Used for training a single model or deploying a batch train job on RUNAI CLI """ -from biogtr.io.config import Config -from biogtr.datasets.tracking_dataset import TrackingDataset +from biogtr.io import Config +from biogtr.datasets import TrackingDataset from biogtr.datasets.data_utils import view_training_batch from multiprocessing import cpu_count from omegaconf import DictConfig @@ -18,7 +18,7 @@ @hydra.main(config_path="configs", config_name=None, version_base=None) -def main(cfg: DictConfig): +def run(cfg: DictConfig): """Train model based on config. Handles all config parsing and initialization then calls `trainer.train()`. 
@@ -107,4 +107,4 @@ def main(cfg: DictConfig): # deploy batch train job: # python train.py --config-dir=./configs --config-name=base +batch_config=test_batch_train.csv - main() + train() diff --git a/tests/test_config.py b/tests/test_config.py index 1d8c64cd..4f7ebbc7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,8 @@ """Tests for `config.py`""" from omegaconf import OmegaConf -from biogtr.io.config import Config -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer -from biogtr.models.gtr_runner import GTRRunner +from biogtr.io import Config +from biogtr.models import GlobalTrackingTransformer, GTRRunner import torch diff --git a/tests/test_data_model.py b/tests/test_data_model.py index cbe5aa8c..ef5d0320 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -1,9 +1,6 @@ """Tests for Instance, Frame, and AssociationMatrix Objects""" -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance -from biogtr.io.association_matrix import AssociationMatrix -from biogtr.io.track import Track +from biogtr.io import Frame, Instance, AssociationMatrix, Track import torch import pytest import numpy as np diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8862ffad..775c1929 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,12 +1,14 @@ """Test dataset logic.""" -from biogtr.datasets.base_dataset import BaseDataset +from biogtr.datasets import ( + BaseDataset, + MicroscopyDataset, + SleapDataset, + CellTrackingDataset, + TrackingDataset, +) from biogtr.datasets.data_utils import get_max_padding, NodeDropout -from biogtr.datasets.microscopy_dataset import MicroscopyDataset -from biogtr.datasets.sleap_dataset import SleapDataset -from biogtr.datasets.tracking_dataset import TrackingDataset from biogtr.models.model_utils import get_device -from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset from torch.utils.data import DataLoader import pytest import torch diff --git a/tests/test_inference.py b/tests/test_inference.py index 66c5ac61..ddda7cd2 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -3,13 +3,88 @@ import torch import pytest import numpy as np -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer -from biogtr.inference.tracker import Tracker +from biogtr.io import Frame, Instance +from biogtr.models import GlobalTrackingTransformer +from biogtr.inference import Tracker, post_processing, metrics from biogtr.inference.track_queue import TrackQueue -from biogtr.inference import post_processing -from biogtr.inference import metrics + + +def test_track_queue(): + window_size = 8 + max_gap = 10 + img_shape = (3, 1024, 1024) + n_instances_per_frame = [2] * window_size + + frames = [] + instances_per_frame = [] + + tq = TrackQueue(window_size, max_gap) + for i in range(window_size): + instances = [] + for j in range(n_instances_per_frame[i]): + instances.append(Instance(gt_track_id=j, pred_track_id=j)) + instances_per_frame.append(instances) + frame = Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) + frames.append(frame) + + tq.add_frame(frame) + + assert len(tq) == sum(n_instances_per_frame[1:]) + assert tq.n_tracks == max(n_instances_per_frame) + assert tq.tracks == [i for i in range(max(n_instances_per_frame))] + assert len(tq.collate_tracks()) == window_size - 1 + assert all([gap == 0 for gap in 
tq._curr_gap.values()]) + assert tq.curr_track == max(n_instances_per_frame) - 1 + + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + 1, + img_shape=img_shape, + instances=[Instance(gt_track_id=0, pred_track_id=0)], + ) + ) + + assert len(tq._queues[0]) == window_size - 1 + assert tq._curr_gap[0] == 0 + assert tq._curr_gap[max(n_instances_per_frame) - 1] == 1 + + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + 1, + img_shape=img_shape, + instances=[ + Instance(gt_track_id=1, pred_track_id=1), + Instance( + gt_track_id=max(n_instances_per_frame), + pred_track_id=max(n_instances_per_frame), + ), + ], + ) + ) + + assert len(tq._queues[max(n_instances_per_frame)]) == 1 + assert tq._curr_gap[1] == 0 + assert tq._curr_gap[0] == 1 + + for i in range(max_gap): + tq.add_frame( + Frame( + video_id=0, + frame_id=window_size + i + 1, + img_shape=img_shape, + instances=[Instance(gt_track_id=0, pred_track_id=0)], + ) + ) + + assert tq.n_tracks == 1 + assert tq.curr_track == max(n_instances_per_frame) + assert 0 in tq._queues.keys() + + tq.end_tracks() + + assert len(tq) == 0 def test_track_queue(): diff --git a/tests/test_models.py b/tests/test_models.py index 0c125d24..bf4e8e47 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,23 +2,22 @@ import pytest import torch -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance +from biogtr.io import Frame, Instance from biogtr.models.mlp import MLP from biogtr.models.attention_head import ATTWeightHead -from biogtr.models.embedding import Embedding -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer -from biogtr.models.transformer import ( +from biogtr.models import ( + Embedding, + VisualEncoder, Transformer, + GlobalTrackingTransformer, +) +from biogtr.models.transformer import ( TransformerEncoderLayer, TransformerDecoderLayer, ) -from biogtr.models.visual_encoder import VisualEncoder # todo: add named tensor tests - - def test_mlp(): """Test MLP logic.""" b, n, f = 1, 10, 1024 # batch size, num instances, features diff --git a/tests/test_training.py b/tests/test_training.py index f5331a58..0baded43 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -3,14 +3,11 @@ import os import pytest import torch -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance +from biogtr.io import Frame, Instance, Config from biogtr.training.losses import AssoLoss -from biogtr.models.gtr_runner import GTRRunner -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer +from biogtr.models import GTRRunner from omegaconf import OmegaConf, DictConfig -from biogtr.io.config import Config -from biogtr.training.train import main +from biogtr.training.train import run # TODO: add named tensor tests # TODO: use temp dir and cleanup after tests (https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html) @@ -139,4 +136,4 @@ def test_config_gtr_runner(base_config, params_config, two_flies): cfg.set_hparams(hparams) with torch.autograd.set_detect_anomaly(True): - main(cfg.cfg) + run(cfg.cfg) From 63506c83d256ad1a29a0eec45d23ebc85e68caa7 Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:49:39 -0700 Subject: [PATCH 5/7] Refactor inference script (#51) Co-authored-by: Talmo Pereira --- biogtr/datasets/base_dataset.py | 10 +- biogtr/datasets/cell_tracking_dataset.py | 3 +- biogtr/datasets/microscopy_dataset.py | 3 +- biogtr/datasets/sleap_dataset.py | 3 +- .../configs/{base.yaml 
=> inference.yaml} | 2 + biogtr/inference/track.py | 103 ++++++++------- biogtr/inference/tracker.py | 11 +- biogtr/io/config.py | 104 ++++++++++++--- biogtr/io/frame.py | 13 +- biogtr/io/instance.py | 4 +- biogtr/models/gtr_runner.py | 1 - tests/configs/inference.yaml | 22 ++++ tests/fixtures/configs.py | 7 + tests/test_datasets.py | 8 +- tests/test_inference.py | 122 ++++++------------ tests/test_training.py | 9 +- 16 files changed, 260 insertions(+), 165 deletions(-) rename biogtr/inference/configs/{base.yaml => inference.yaml} (92%) create mode 100644 tests/configs/inference.yaml diff --git a/biogtr/datasets/base_dataset.py b/biogtr/datasets/base_dataset.py index 6f54719d..15b87d45 100644 --- a/biogtr/datasets/base_dataset.py +++ b/biogtr/datasets/base_dataset.py @@ -13,7 +13,8 @@ class BaseDataset(Dataset): def __init__( self, - files: list[str], + label_files: list[str], + vid_files: list[str], padding: int, crop_size: int, chunk: bool, @@ -27,7 +28,9 @@ def __init__( """Initialize Dataset. Args: - files: a list of files, file types are combined in subclasses + label_files: a list of paths to label files. + should at least contain detections for inference, detections + tracks for training. + vid_files: list of paths to video files. padding: amount of padding around object crops crop_size: the size of the object crops chunk: whether or not to chunk the dataset into batches @@ -42,7 +45,8 @@ def __init__( gt_list: An optional path to .txt file containing ground truth for cell tracking challenge datasets. """ - self.files = files + self.vid_files = vid_files + self.label_files = label_files self.padding = padding self.crop_size = crop_size self.chunk = chunk diff --git a/biogtr/datasets/cell_tracking_dataset.py b/biogtr/datasets/cell_tracking_dataset.py index 33aa12fa..9567de46 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/biogtr/datasets/cell_tracking_dataset.py @@ -55,7 +55,8 @@ def __init__( "end_frame", "parent_id" """ super().__init__( - raw_images + gt_images, + gt_images, + raw_images, padding, crop_size, chunk, diff --git a/biogtr/datasets/microscopy_dataset.py b/biogtr/datasets/microscopy_dataset.py index 2c15cc1e..9656d19d 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/biogtr/datasets/microscopy_dataset.py @@ -51,7 +51,8 @@ def __init__( seed: set a seed for reproducibility """ super().__init__( - videos + tracks, + tracks, + videos, padding, crop_size, chunk, diff --git a/biogtr/datasets/sleap_dataset.py b/biogtr/datasets/sleap_dataset.py index 6a934b4e..b23b4e70 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/biogtr/datasets/sleap_dataset.py @@ -65,7 +65,8 @@ def __init__( verbose: boolean representing whether to print """ super().__init__( - slp_files + video_files, + slp_files, + video_files, padding, crop_size, chunk, diff --git a/biogtr/inference/configs/base.yaml b/biogtr/inference/configs/inference.yaml similarity index 92% rename from biogtr/inference/configs/base.yaml rename to biogtr/inference/configs/inference.yaml index faf036a4..e48416d6 100644 --- a/biogtr/inference/configs/base.yaml +++ b/biogtr/inference/configs/inference.yaml @@ -4,6 +4,7 @@ tracker: decay_time: 0.9 iou: "mult" max_center_dist: 1.0 + persistent_tracking: True dataset: test_dataset: @@ -11,6 +12,7 @@ dataset: video_files: ["../training/190612_110405_wt_18159111_rig2.2@11730.mp4", "../training/190612_110405_wt_18159111_rig2.2@11730.mp4"] chunk: True clip_length: 32 + anchor: "centroid" dataloader: test_dataloader: diff --git a/biogtr/inference/track.py 
b/biogtr/inference/track.py index c930cf62..aa766d5d 100644 --- a/biogtr/inference/track.py +++ b/biogtr/inference/track.py @@ -11,6 +11,7 @@ import pandas as pd import pytorch_lightning as pl import torch +import sleap_io as sio def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None): @@ -50,60 +51,45 @@ def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = N return save_df -def inference( - model: GTRRunner, dataloader: torch.utils.data.DataLoader +def track( + model: GTRRunner, trainer: pl.Trainer, dataloader: torch.utils.data.DataLoader ) -> list[pd.DataFrame]: """Run Inference. Args: - model: model loaded from checkpoint used for inference + model: GTRRunner model loaded from checkpoint used for inference + trainer: lighting Trainer object used for handling inference log. dataloader: dataloader containing inference data Return: List of DataFrames containing prediction results for each video """ - num_videos = len(dataloader.dataset.slp_files) - trainer = pl.Trainer(devices=1, limit_predict_batches=3) + num_videos = len(dataloader.dataset.vid_files) preds = trainer.predict(model, dataloader) - vid_trajectories = [[] for i in range(num_videos)] + vid_trajectories = {i: [] for i in range(num_videos)} + tracks = {} for batch in preds: for frame in batch: - vid_trajectories[frame.video_id].append(frame) + lf, tracks = frame.to_slp(tracks) + if frame.frame_id.item() == 0: + print(f"Video: {lf.video}") + vid_trajectories[frame.video_id.item()].append(lf) - saved = [] - - for video in vid_trajectories: + for vid_id, video in vid_trajectories.items(): if len(video) > 0: - save_dict = {} - video_ids = [] - frame_ids = [] - X, Y = [], [] - pred_track_ids = [] - for frame in video: - for i, instance in frame.instances: - video_ids.append(frame.video_id.item()) - frame_ids.append(frame.frame_id.item()) - bbox = instance.bbox - y = (bbox[2] + bbox[0]) / 2 - x = (bbox[3] + bbox[1]) / 2 - X.append(x.item()) - Y.append(y.item()) - pred_track_ids.append(instance.pred_track_id.item()) - save_dict["Video"] = video_ids - save_dict["Frame"] = frame_ids - save_dict["X"] = X - save_dict["Y"] = Y - save_dict["Pred_track_id"] = pred_track_ids - save_df = pd.DataFrame(save_dict) - saved.append(save_df) - - return saved + try: + vid_trajectories[vid_id] = sio.Labels(video) + except AttributeError as e: + print(video[0].video) + raise (e) + + return vid_trajectories @hydra.main(config_path="configs", config_name=None, version_base=None) -def main(cfg: DictConfig): +def run(cfg: DictConfig) -> dict[int, sio.Labels]: """Run inference based on config file. Args: @@ -116,14 +102,14 @@ def main(cfg: DictConfig): index = int(os.environ["POD_INDEX"]) # For testing without deploying a job on runai except KeyError: - print("Pod Index Not found! Setting index to 0") - index = 0 + index = input("Pod Index Not found! 
Please choose a pod index: ") + print(f"Pod Index: {index}") checkpoints = pd.read_csv(cfg.checkpoints) checkpoint = checkpoints.iloc[index] else: - checkpoint = pred_cfg.get_ckpt_path() + checkpoint = pred_cfg.cfg.ckpt_path model = GTRRunner.load_from_checkpoint(checkpoint) tracker_cfg = pred_cfg.get_tracker_cfg() @@ -131,22 +117,41 @@ def main(cfg: DictConfig): model.tracker_cfg = tracker_cfg print(f"Using the following params for tracker:") pprint(model.tracker_cfg) - dataset = pred_cfg.get_dataset(mode="test") + dataset = pred_cfg.get_dataset(mode="test") dataloader = pred_cfg.get_dataloader(dataset, mode="test") - preds = inference(model, dataloader) - for i, pred in enumerate(preds): - print(pred) - outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results" - os.makedirs(outdir, exist_ok=True) + + trainer = pred_cfg.get_trainer() + + preds = track(model, trainer, dataloader) + + outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results" + os.makedirs(outdir, exist_ok=True) + + run_num = 0 + for i, pred in preds.items(): outpath = os.path.join( outdir, - f"{Path(pred_cfg.cfg.dataset.test_dataset.slp_files[i]).stem}_tracking_results", + f"{Path(dataloader.dataset.label_files[i]).stem}.biogtr_inference.v{run_num}.slp", ) - print(f"Saving to {outpath}") - # TODO: Figure out how to overwrite sleap labels instance labels w pred instance labels then save as a new slp file - pred.to_csv(outpath, index=False) + if os.path.exists(outpath): + run_num += 1 + outpath = outpath.replace(f".v{run_num-1}", f".v{run_num}") + print(f"Saving {preds} to {outpath}") + pred.save(outpath) + + return preds if __name__ == "__main__": - main() + # example calls: + + # train with base config: + # python train.py --config-dir=./configs --config-name=inference + + # override with params config: + # python train.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml + + # override with params config, and specific params: + # python train.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml dataset.train_dataset.padding=10 + run() diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index da1320be..e5c3083a 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -155,7 +155,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame self.track_queue.end_tracks() """ - Initialize tracks on first frame of video or first instance of detections. + Initialize tracks on first frame where detections appear. 
""" if len(self.track_queue) == 0: if frame_to_track.has_instances(): @@ -167,12 +167,12 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame curr_track_id = 0 for i, instance in enumerate(frames[batch_idx].instances): instance.pred_track_id = instance.gt_track_id - curr_track_id = instance.pred_track_id + curr_track_id = max(curr_track_id, instance.pred_track_id) for i, instance in enumerate(frames[batch_idx].instances): if instance.pred_track_id == -1: - instance.pred_track_id = curr_track_id curr_track += 1 + instance.pred_track_id = curr_track_id else: if ( @@ -250,6 +250,7 @@ def _run_global_tracker( overlap_thresh = self.overlap_thresh mult_thresh = self.mult_thresh n_traj = self.track_queue.n_tracks + curr_track = self.track_queue.curr_track reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[ None @@ -470,8 +471,8 @@ def _run_global_tracker( if track_ids[i] < 0: if self.verbose: print(f"Creating new track {n_traj}") - track_ids[i] = n_traj - n_traj += 1 + curr_track += 1 + track_ids[i] = curr_track query_frame.matches = (match_i, match_j) diff --git a/biogtr/io/config.py b/biogtr/io/config.py index 280a1b04..7ea8a0ac 100644 --- a/biogtr/io/config.py +++ b/biogtr/io/config.py @@ -5,6 +5,7 @@ from pprint import pprint from typing import Union, Iterable from pathlib import Path +import glob import pytorch_lightning as pl import torch @@ -12,26 +13,28 @@ class Config: """Class handling loading components based on config params.""" - def __init__(self, cfg: DictConfig): + def __init__(self, cfg: DictConfig, params_cfg: DictConfig = None): """Initialize the class with config from hydra/omega conf. First uses `base_param` file then overwrites with specific `params_config`. Args: cfg: The `DictConfig` containing all the hyperparameters needed for + training/evaluation. + params_cfg: The `DictConfig` containing subset of hyperparameters to override. training/evaluation """ base_cfg = cfg print(f"Base Config: {cfg}") if "params_config" in cfg: - # merge configs - params_config = OmegaConf.load(cfg.params_config) - pprint(f"Overwriting base config with {params_config}") - self.cfg = OmegaConf.merge(base_cfg, params_config) + params_cfg = OmegaConf.load(cfg.params_config) + + if params_cfg: + pprint(f"Overwriting base config with {params_cfg}") + self.cfg = OmegaConf.merge(base_cfg, params_cfg) # merge configs else: - # just use base config - self.cfg = base_cfg + self.cfg = cfg def __repr__(self): """Object representation of config class.""" @@ -41,6 +44,18 @@ def __str__(self): """Return a string representation of config class.""" return f"Config({self.cfg})" + @classmethod + def from_yaml(cls, base_cfg_path: str, params_cfg_path: str = None) -> None: + """Load config directly from yaml. + + Args: + base_cfg_path: path to base config file. + params_cfg_path: path to override params. + """ + base_cfg = OmegaConf.load(base_cfg_path) + params_cfg = OmegaConf.load(params_cfg_path) if params_cfg else None + return cls(base_cfg, params_cfg) + def set_hparams(self, hparams: dict) -> bool: """Setter function for overwriting specific hparams. 
@@ -92,7 +107,7 @@ def get_tracker_cfg(self) -> dict: tracker_cfg[key] = val return tracker_cfg - def get_gtr_runner(self): + def get_gtr_runner(self) -> "GTRRunner": """Get lightning module for training, validation, and inference.""" from biogtr.models import GTRRunner @@ -126,6 +141,27 @@ def get_gtr_runner(self): return model + def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]: + """Get file paths from directory. + + Args: + data_cfg: Config for the dataset containing "dir" key. + + Returns: + lists of labels file paths and video file paths respectively + """ + dir_cfg = data_cfg.pop("dir", None) + + if dir_cfg: + labels_suff = dir_cfg.labels_suffix + vid_suff = dir_cfg.vid_suffix + + label_files = glob.glob(f"{dir_cfg.path}/*.{labels_suff}") + vid_files = glob.glob(f"{dir_cfg.path}/*.{vid_suff}") + return label_files, vid_files + + return None, None + def get_dataset( self, mode: str ) -> Union["SleapDataset", "MicroscopyDataset", "CellTrackingDataset"]: @@ -151,13 +187,39 @@ def get_dataset( "`mode` must be one of ['train', 'val','test'], not '{mode}'" ) + label_files, vid_files = self.get_data_paths(dataset_params) # todo: handle this better if "slp_files" in dataset_params: + if label_files is not None: + dataset_params.slp_files = label_files + if vid_files is not None: + dataset_params.video_files = vid_files return SleapDataset(**dataset_params) + elif "tracks" in dataset_params or "source" in dataset_params: + if label_files is not None: + dataset_params.tracks = label_files + if vid_files is not None: + dataset_params.video_files = vid_files return MicroscopyDataset(**dataset_params) + elif "raw_images" in dataset_params: + if label_files is not None: + dataset_params.gt_images = label_files + if vid_files is not None: + dataset_params.raw_images = vid_files return CellTrackingDataset(**dataset_params) + + # todo: handle this better + if "slp_files" in dataset_params: + return SleapDataset(**dataset_params) + + elif "tracks" in dataset_params or "source" in dataset_params: + return MicroscopyDataset(**dataset_params) + + elif "raw_images" in dataset_params: + return CellTrackingDataset(**dataset_params) + else: raise ValueError( "Could not resolve dataset type from Config! Please include \ @@ -315,10 +377,10 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint: def get_trainer( self, - callbacks: list[pl.callbacks.Callback], - logger: pl.loggers.WandbLogger, + callbacks: list[pl.callbacks.Callback] = None, + logger: pl.loggers.WandbLogger = None, devices: int = 1, - accelerator: str = None, + accelerator: str = "auto", ) -> pl.Trainer: """Getter for the lightning trainer. 
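The `get_data_paths` helper added above assumes a `dir` sub-config holding a directory path plus label/video suffixes to glob. A minimal sketch of how it might be driven; the directory and suffix values here are hypothetical, not taken from the repo's configs:

    from omegaconf import OmegaConf
    from biogtr.io import Config

    cfg = Config(OmegaConf.create({}))  # any Config instance works; get_data_paths only reads data_cfg

    # Illustrative dataset sub-config: get_data_paths pops the "dir" key and globs that directory.
    data_cfg = OmegaConf.create(
        {
            "dir": {
                "path": "/path/to/data",  # hypothetical directory to search
                "labels_suffix": "slp",   # collects /path/to/data/*.slp as label files
                "vid_suffix": "mp4",      # collects /path/to/data/*.mp4 as video files
            }
        }
    )

    label_files, vid_files = cfg.get_data_paths(data_cfg)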
@@ -332,17 +394,23 @@ def get_trainer( Returns: A lightning Trainer with specified params """ - if "accelerator" not in self.cfg.trainer: - self.set_hparams({"trainer.accelerator": accelerator}) - if "devices" not in self.cfg.trainer: - self.set_hparams({"trainer.devices": devices}) + if "trainer" in self.cfg: + trainer_params = self.cfg.trainer - trainer_params = self.cfg.trainer - if "profiler" in trainer_params: + else: + trainer_params = {} + + profiler = trainer_params.pop("profiler", None) + if "profiler": profiler = pl.profilers.AdvancedProfiler(filename="profile.txt") - trainer_params.pop("profiler") else: profiler = None + + if "accelerator" not in trainer_params: + trainer_params["accelerator"] = accelerator + if "devices" not in trainer_params: + trainer_params["devices"] = devices + return pl.Trainer( callbacks=callbacks, logger=logger, diff --git a/biogtr/io/frame.py b/biogtr/io/frame.py index 803970bc..5607e832 100644 --- a/biogtr/io/frame.py +++ b/biogtr/io/frame.py @@ -158,7 +158,7 @@ def from_slp( ) def to_slp( - self, track_lookup: dict[int, sio.Track] = {} + self, track_lookup: dict[int, sio.Track] = None ) -> tuple[sio.LabeledFrame, dict[int, sio.Track]]: """Convert Frame to sleap_io.LabeledFrame object. @@ -168,13 +168,22 @@ def to_slp( Returns: A tuple containing a LabeledFrame object with necessary metadata and a lookup dictionary containing the track_id and sio.Track for persistence """ + if track_lookup is None: + track_lookup = {} + slp_instances = [] for instance in self.instances: slp_instance, track_lookup = instance.to_slp(track_lookup=track_lookup) slp_instances.append(slp_instance) + + video = ( + self.video + if isinstance(self.video, sio.Video) + else sio.load_video(self.video) + ) return ( sio.LabeledFrame( - video=self.video, + video=video, frame_idx=self.frame_id.item(), instances=slp_instances, ), diff --git a/biogtr/io/instance.py b/biogtr/io/instance.py index 44dc2386..5ffef867 100644 --- a/biogtr/io/instance.py +++ b/biogtr/io/instance.py @@ -227,7 +227,7 @@ def to_slp( return ( sio.PredictedInstance.from_numpy( - points=self.pose, + points=np.array(list(self.pose.values())), skeleton=self.skeleton, point_scores=self.point_scores, instance_score=self.instance_score, @@ -238,7 +238,7 @@ def to_slp( ) except Exception as e: print( - f"Pose shape: {self.pose.shape}, Pose score shape {self.point_scores.shape}" + f"Pose: {np.array(list(self.pose.values())).shape}, Pose score shape {self.point_scores.shape}" ) raise RuntimeError(f"Failed to convert to sio.PredictedInstance: {e}") diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py index 8ddfe787..7dbf4b2b 100644 --- a/biogtr/models/gtr_runner.py +++ b/biogtr/models/gtr_runner.py @@ -159,7 +159,6 @@ def predict_step( Returns: A list of dicts where each dict is a frame containing the predicted track ids """ - self.tracker.persistent_tracking = True frames_pred = self.tracker(self.model, batch[0]) return frames_pred diff --git a/tests/configs/inference.yaml b/tests/configs/inference.yaml new file mode 100644 index 00000000..53680a7c --- /dev/null +++ b/tests/configs/inference.yaml @@ -0,0 +1,22 @@ +ckpt_path: null +tracker: + overlap_thresh: 0.01 + decay_time: 0.9 + iou: "mult" + max_center_dist: 1.0 + persistent_tracking: True + +dataset: + test_dataset: + slp_files: ['tests/data/sleap/two_flies.slp', 'tests/data/sleap/two_flies.slp'] + video_files: ['tests/data/sleap/two_flies.mp4', 'tests/data/sleap/two_flies.mp4'] + clip_length: 32 + anchors: "centroid" + mode: "test" + +dataloader: + 
test_dataloader: + shuffle: False + num_workers: 0 + + \ No newline at end of file diff --git a/tests/fixtures/configs.py b/tests/fixtures/configs.py index 3cf06840..19f2797c 100644 --- a/tests/fixtures/configs.py +++ b/tests/fixtures/configs.py @@ -20,3 +20,10 @@ def base_config(config_dir): def params_config(config_dir): """Get the full path to the supplementary params config.""" return os.path.join(config_dir, "params.yaml") + + +@pytest.fixture +def inference_config(config_dir): + """Get the full path to the inference params config.""" + + return os.path.join(config_dir, "inference.yaml") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 775c1929..ab9d7640 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -21,7 +21,13 @@ class DummyDataset(BaseDataset): pass ds = DummyDataset( - files=[], padding=0, crop_size=0, chunk=False, clip_length=0, mode="" + label_files=[], + vid_files=[], + padding=0, + crop_size=0, + chunk=False, + clip_length=0, + mode="", ) with pytest.raises(NotImplementedError): diff --git a/tests/test_inference.py b/tests/test_inference.py index ddda7cd2..a5ef05f9 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -3,88 +3,13 @@ import torch import pytest import numpy as np -from biogtr.io import Frame, Instance -from biogtr.models import GlobalTrackingTransformer +from pytorch_lightning import Trainer +from omegaconf import OmegaConf, DictConfig +from biogtr.io import Frame, Instance, Config +from biogtr.models import GTRRunner, GlobalTrackingTransformer from biogtr.inference import Tracker, post_processing, metrics from biogtr.inference.track_queue import TrackQueue - - -def test_track_queue(): - window_size = 8 - max_gap = 10 - img_shape = (3, 1024, 1024) - n_instances_per_frame = [2] * window_size - - frames = [] - instances_per_frame = [] - - tq = TrackQueue(window_size, max_gap) - for i in range(window_size): - instances = [] - for j in range(n_instances_per_frame[i]): - instances.append(Instance(gt_track_id=j, pred_track_id=j)) - instances_per_frame.append(instances) - frame = Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) - frames.append(frame) - - tq.add_frame(frame) - - assert len(tq) == sum(n_instances_per_frame[1:]) - assert tq.n_tracks == max(n_instances_per_frame) - assert tq.tracks == [i for i in range(max(n_instances_per_frame))] - assert len(tq.collate_tracks()) == window_size - 1 - assert all([gap == 0 for gap in tq._curr_gap.values()]) - assert tq.curr_track == max(n_instances_per_frame) - 1 - - tq.add_frame( - Frame( - video_id=0, - frame_id=window_size + 1, - img_shape=img_shape, - instances=[Instance(gt_track_id=0, pred_track_id=0)], - ) - ) - - assert len(tq._queues[0]) == window_size - 1 - assert tq._curr_gap[0] == 0 - assert tq._curr_gap[max(n_instances_per_frame) - 1] == 1 - - tq.add_frame( - Frame( - video_id=0, - frame_id=window_size + 1, - img_shape=img_shape, - instances=[ - Instance(gt_track_id=1, pred_track_id=1), - Instance( - gt_track_id=max(n_instances_per_frame), - pred_track_id=max(n_instances_per_frame), - ), - ], - ) - ) - - assert len(tq._queues[max(n_instances_per_frame)]) == 1 - assert tq._curr_gap[1] == 0 - assert tq._curr_gap[0] == 1 - - for i in range(max_gap): - tq.add_frame( - Frame( - video_id=0, - frame_id=window_size + i + 1, - img_shape=img_shape, - instances=[Instance(gt_track_id=0, pred_track_id=0)], - ) - ) - - assert tq.n_tracks == 1 - assert tq.curr_track == max(n_instances_per_frame) - assert 0 in tq._queues.keys() - - tq.end_tracks() - 
- assert len(tq) == 0 +from biogtr.inference.track import run def test_track_queue(): @@ -339,3 +264,40 @@ def test_metrics(): sw_cnt, clear_mot["num_switches"], ) + + +def get_ckpt(ckpt_path: str): + """Save GTR Runner to checkpoint file.""" + + class DummyDataset(torch.utils.data.Dataset): + + def __len__(self): + return 0 + + def __getitem__(self, idx): + return None + + dl = torch.utils.data.DataLoader(DummyDataset()) + model = GTRRunner() + trainer = Trainer(max_steps=1, min_steps=1) + trainer.fit(model, dl) + trainer.save_checkpoint(ckpt_path) + + return ckpt_path + + +def test_track(tmp_path, inference_config): + ckpt_path = tmp_path / "model.ckpt" + get_ckpt(ckpt_path) + + out_dir = tmp_path / "preds" + out_dir.mkdir() + + inference_cfg = OmegaConf.load(inference_config) + + cfg = Config(inference_cfg) + + cfg.set_hparams({"ckpt_path": ckpt_path, "outdir": out_dir}) + + run(cfg.cfg) + assert len(list(out_dir.iterdir())) == len(cfg.cfg.dataset.test_dataset.video_files) diff --git a/tests/test_training.py b/tests/test_training.py index 0baded43..357ce267 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -123,15 +123,22 @@ def test_basic_gtr_runner(): os.environ.get("GITHUB_ACTIONS") == "true", reason="Silent fail on GitHub Actions", ) -def test_config_gtr_runner(base_config, params_config, two_flies): +def test_config_gtr_runner(tmp_path, base_config, params_config, two_flies): """Test config GTR Runner.""" base_cfg = OmegaConf.load(base_config) base_cfg["params_config"] = params_config cfg = Config(base_cfg) + model_dir = tmp_path / "models" + model_dir.mkdir() + + logs_dir = tmp_path / "logs" + logs_dir.mkdir() hparams = { "dataset.clip_length": 8, "trainer.min_epochs": 1, + "checkpointing.dirpath": model_dir, + "logging.save_dir": logs_dir, } cfg.set_hparams(hparams) From 152fca66d6920ac51832336af522ca617fff87ff Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Wed, 5 Jun 2024 14:42:37 -0700 Subject: [PATCH 6/7] Patch transformer forward by replacing nan with -1 placeholder (#59) --- biogtr/inference/tracker.py | 8 +++++++- biogtr/models/transformer.py | 3 ++- biogtr/models/visual_encoder.py | 3 --- biogtr/training/losses.py | 8 +++++++- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py index e5c3083a..4aa36f39 100644 --- a/biogtr/inference/tracker.py +++ b/biogtr/inference/tracker.py @@ -445,7 +445,13 @@ def _run_global_tracker( query_frame.add_traj_score("scaled", scaled_traj_score_df) ################################################################################ - match_i, match_j = linear_sum_assignment((-traj_score)) + try: + match_i, match_j = linear_sum_assignment((-traj_score)) + except ValueError as e: + print(reid_features.isnan().any()) + print(asso_output) + print(traj_score) + raise (e) track_ids = instance_ids.new_full((n_query,), -1) for i, j in zip(match_i, match_j): diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py index 58864e47..75579da5 100644 --- a/biogtr/models/transformer.py +++ b/biogtr/models/transformer.py @@ -217,7 +217,7 @@ def forward( ) # (n_query, batch_size, embed_dim) query_boxes = get_boxes(query_instances) - + query_boxes = torch.nan_to_num(query_boxes, -1.0) query_temp_emb = self.temp_emb(query_times / window_length) query_pos_emb = self.pos_emb(query_boxes) @@ -225,6 +225,7 @@ def forward( query_emb = (query_pos_emb + query_temp_emb) / 2.0 query_emb = query_emb.view(1, n_query, 
embed_dim) query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim) + else: query_instances = ref_instances diff --git a/biogtr/models/visual_encoder.py b/biogtr/models/visual_encoder.py index 25468032..ce5bc563 100644 --- a/biogtr/models/visual_encoder.py +++ b/biogtr/models/visual_encoder.py @@ -144,11 +144,8 @@ def forward(self, img: torch.Tensor) -> torch.Tensor: # Reshape feature vectors feats = feats.reshape([img.shape[0], -1]) # (B, out_dim) - # Map feature vectors to output dimension using linear layer. feats = self.out_layer(feats) # (B, d_model) - # Normalize output feature vectors. feats = F.normalize(feats) # (B, d_model) - return feats diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py index c487092c..ff3e6eca 100644 --- a/biogtr/training/losses.py +++ b/biogtr/training/losses.py @@ -47,7 +47,13 @@ def forward( """ # get number of detected objects and ground truth ids n_t = [frame.num_detected for frame in frames] - target_inst_id = torch.cat([frame.get_gt_track_ids() for frame in frames]) + try: + target_inst_id = torch.cat( + [frame.get_gt_track_ids().to(asso_preds[-1].device) for frame in frames] + ) + except RuntimeError as e: + print([frame.get_gt_track_ids().device for frame in frames]) + raise (e) instances = [instance for frame in frames for instance in frame.instances] # for now set equal since detections are fixed From f04eb9fa31ffe323ab63f3efb2e39eb990e319e0 Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Wed, 5 Jun 2024 15:59:06 -0700 Subject: [PATCH 7/7] Rename `biogtr` to `dreem` (#60) --- .github/workflows/ci.yml | 10 +++---- .gitignore | 4 +-- README.md | 10 +++---- biogtr/__init__.py | 18 ----------- biogtr/cli.py | 1 - biogtr/io/__init__.py | 7 ----- dreem/__init__.py | 18 +++++++++++ dreem/cli.py | 1 + {biogtr => dreem}/datasets/__init__.py | 0 {biogtr => dreem}/datasets/base_dataset.py | 4 +-- .../datasets/cell_tracking_dataset.py | 6 ++-- {biogtr => dreem}/datasets/data_utils.py | 0 {biogtr => dreem}/datasets/eval_dataset.py | 2 +- .../datasets/microscopy_dataset.py | 6 ++-- {biogtr => dreem}/datasets/sleap_dataset.py | 4 +-- .../datasets/tracking_dataset.py | 6 ++-- {biogtr => dreem}/inference/__init__.py | 0 {biogtr => dreem}/inference/boxes.py | 0 .../inference/configs/inference.yaml | 0 {biogtr => dreem}/inference/metrics.py | 10 +++---- .../inference/post_processing.py | 2 +- {biogtr => dreem}/inference/track.py | 8 ++--- {biogtr => dreem}/inference/track_queue.py | 2 +- {biogtr => dreem}/inference/tracker.py | 14 ++++----- dreem/io/__init__.py | 7 +++++ {biogtr => dreem}/io/association_matrix.py | 2 +- {biogtr => dreem}/io/config.py | 16 +++++----- {biogtr => dreem}/io/frame.py | 6 ++-- {biogtr => dreem}/io/instance.py | 4 +-- {biogtr => dreem}/io/track.py | 0 {biogtr => dreem}/io/visualize.py | 0 {biogtr => dreem}/models/__init__.py | 0 {biogtr => dreem}/models/attention_head.py | 2 +- {biogtr => dreem}/models/embedding.py | 2 +- .../models/global_tracking_transformer.py | 8 ++--- {biogtr => dreem}/models/gtr_runner.py | 30 +++++++++---------- {biogtr => dreem}/models/mlp.py | 0 {biogtr => dreem}/models/model_utils.py | 6 ++-- {biogtr => dreem}/models/transformer.py | 18 +++++------ {biogtr => dreem}/models/visual_encoder.py | 0 {biogtr => dreem}/training/__init__.py | 0 {biogtr => dreem}/training/configs/base.yaml | 0 .../training/configs/params.yaml | 0 .../training/configs/test_batch_train.csv | 0 {biogtr => dreem}/training/losses.py | 2 +- {biogtr => 
dreem}/training/train.py | 6 ++-- {biogtr => dreem}/version.py | 0 environment.yml | 2 +- environment_cpu.yml | 2 +- environment_osx-arm64.yml | 2 +- pyproject.toml | 10 +++---- tests/fixtures/datasets.py | 2 +- tests/test_config.py | 4 +-- tests/test_data_model.py | 2 +- tests/test_datasets.py | 6 ++-- tests/test_inference.py | 10 +++---- tests/test_models.py | 10 +++---- tests/test_training.py | 8 ++--- tests/test_version.py | 4 +-- 59 files changed, 152 insertions(+), 152 deletions(-) delete mode 100644 biogtr/__init__.py delete mode 100644 biogtr/cli.py delete mode 100644 biogtr/io/__init__.py create mode 100644 dreem/__init__.py create mode 100644 dreem/cli.py rename {biogtr => dreem}/datasets/__init__.py (100%) rename {biogtr => dreem}/datasets/base_dataset.py (98%) rename {biogtr => dreem}/datasets/cell_tracking_dataset.py (97%) rename {biogtr => dreem}/datasets/data_utils.py (100%) rename {biogtr => dreem}/datasets/eval_dataset.py (98%) rename {biogtr => dreem}/datasets/microscopy_dataset.py (98%) rename {biogtr => dreem}/datasets/sleap_dataset.py (99%) rename {biogtr => dreem}/datasets/tracking_dataset.py (95%) rename {biogtr => dreem}/inference/__init__.py (100%) rename {biogtr => dreem}/inference/boxes.py (100%) rename {biogtr => dreem}/inference/configs/inference.yaml (100%) rename {biogtr => dreem}/inference/metrics.py (96%) rename {biogtr => dreem}/inference/post_processing.py (99%) rename {biogtr => dreem}/inference/track.py (95%) rename {biogtr => dreem}/inference/track_queue.py (99%) rename {biogtr => dreem}/inference/tracker.py (98%) create mode 100644 dreem/io/__init__.py rename {biogtr => dreem}/io/association_matrix.py (99%) rename {biogtr => dreem}/io/config.py (96%) rename {biogtr => dreem}/io/frame.py (99%) rename {biogtr => dreem}/io/instance.py (99%) rename {biogtr => dreem}/io/track.py (100%) rename {biogtr => dreem}/io/visualize.py (100%) rename {biogtr => dreem}/models/__init__.py (100%) rename {biogtr => dreem}/models/attention_head.py (97%) rename {biogtr => dreem}/models/embedding.py (99%) rename {biogtr => dreem}/models/global_tracking_transformer.py (95%) rename {biogtr => dreem}/models/gtr_runner.py (91%) rename {biogtr => dreem}/models/mlp.py (100%) rename {biogtr => dreem}/models/model_utils.py (97%) rename {biogtr => dreem}/models/transformer.py (97%) rename {biogtr => dreem}/models/visual_encoder.py (100%) rename {biogtr => dreem}/training/__init__.py (100%) rename {biogtr => dreem}/training/configs/base.yaml (100%) rename {biogtr => dreem}/training/configs/params.yaml (100%) rename {biogtr => dreem}/training/configs/test_batch_train.csv (100%) rename {biogtr => dreem}/training/losses.py (99%) rename {biogtr => dreem}/training/train.py (96%) rename {biogtr => dreem}/version.py (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30a60b49..169c9e23 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,7 +5,7 @@ on: pull_request: types: [opened, reopened, synchronize] paths: - - "biogtr/**" + - "dreem/**" - "tests/**" - ".github/workflows/ci.yml" - "environment_cpu.yml" @@ -14,7 +14,7 @@ on: branches: - main paths: - - "biogtr/**" + - "dreem/**" - "tests/**" - ".github/workflows/ci.yml" - "environment_cpu.yml" @@ -53,11 +53,11 @@ jobs: - name: Run Black run: | - black --check biogtr tests + black --check dreem tests - name: Run pydocstyle run: | - pydocstyle --convention=google biogtr/ + pydocstyle --convention=google dreem/ # Tests with pytest tests: @@ -105,7 +105,7 @@ jobs: if: ${{ startsWith(matrix.os, 
'ubuntu') && matrix.python == 3.9 }} shell: bash -l {0} run: | - pytest --cov=biogtr --cov-report=xml tests/ + pytest --cov=dreem --cov-report=xml tests/ - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/.gitignore b/.gitignore index e9fdce8a..b69e7d35 100644 --- a/.gitignore +++ b/.gitignore @@ -137,5 +137,5 @@ logs # vscode .vscode -biogtr/training/.hydra/* -biogtr/training/models/* +dreem/training/.hydra/* +dreem/training/models/* diff --git a/README.md b/README.md index e8aa69c5..495eedc3 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,22 @@ -# BioGTR +# DREEM Reconstructs Every Entities' Motion Global Tracking Transformers for biological multi-object tracking. ## Installation ### Development 1. Clone the repository: ``` -git clone https://github.com/talmolab/biogtr && cd biogtr +git clone https://github.com/talmolab/dreem && cd dreem ``` 2. Set up in a new conda environment: ``` -conda env create -y -f environment.yml && conda activate biogtr +conda env create -y -f environment.yml && conda activate dreem ``` ### Uninstalling ``` -conda env remove -n biogtr +conda env remove -n dreem ``` \ No newline at end of file diff --git a/biogtr/__init__.py b/biogtr/__init__.py deleted file mode 100644 index e4c823ef..00000000 --- a/biogtr/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Top-level package for BioGTR.""" - -from biogtr.version import __version__ - -from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer -from biogtr.models.gtr_runner import GTRRunner -from biogtr.models.transformer import Transformer -from biogtr.models.visual_encoder import VisualEncoder - -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance -from biogtr.io.association_matrix import AssociationMatrix -from biogtr.io.config import Config -from biogtr.io.visualize import annotate_video - -# from .training import run - -from biogtr.inference.tracker import Tracker diff --git a/biogtr/cli.py b/biogtr/cli.py deleted file mode 100644 index 31db5230..00000000 --- a/biogtr/cli.py +++ /dev/null @@ -1 +0,0 @@ -"""This module contains the command line interfaces for the biogtr package.""" diff --git a/biogtr/io/__init__.py b/biogtr/io/__init__.py deleted file mode 100644 index 0cda02ae..00000000 --- a/biogtr/io/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Module containing input/output data structures for easy storage and manipulation.""" - -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance -from biogtr.io.association_matrix import AssociationMatrix -from biogtr.io.track import Track -from biogtr.io.config import Config diff --git a/dreem/__init__.py b/dreem/__init__.py new file mode 100644 index 00000000..5299cbe3 --- /dev/null +++ b/dreem/__init__.py @@ -0,0 +1,18 @@ +"""Top-level package for dreem.""" + +from dreem.version import __version__ + +from dreem.models.global_tracking_transformer import GlobalTrackingTransformer +from dreem.models.gtr_runner import GTRRunner +from dreem.models.transformer import Transformer +from dreem.models.visual_encoder import VisualEncoder + +from dreem.io.frame import Frame +from dreem.io.instance import Instance +from dreem.io.association_matrix import AssociationMatrix +from dreem.io.config import Config +from dreem.io.visualize import annotate_video + +# from .training import run + +from dreem.inference.tracker import Tracker diff --git a/dreem/cli.py b/dreem/cli.py new file mode 100644 index 00000000..1755475c --- /dev/null +++ b/dreem/cli.py @@ -0,0 +1 @@ +"""This module contains the command 
line interfaces for the dreem package.""" diff --git a/biogtr/datasets/__init__.py b/dreem/datasets/__init__.py similarity index 100% rename from biogtr/datasets/__init__.py rename to dreem/datasets/__init__.py diff --git a/biogtr/datasets/base_dataset.py b/dreem/datasets/base_dataset.py similarity index 98% rename from biogtr/datasets/base_dataset.py rename to dreem/datasets/base_dataset.py index 15b87d45..8dd6af4b 100644 --- a/biogtr/datasets/base_dataset.py +++ b/dreem/datasets/base_dataset.py @@ -1,7 +1,7 @@ """Module containing logic for loading datasets.""" -from biogtr.datasets import data_utils -from biogtr.io import Frame +from dreem.datasets import data_utils +from dreem.io import Frame from torch.utils.data import Dataset from typing import List, Union import numpy as np diff --git a/biogtr/datasets/cell_tracking_dataset.py b/dreem/datasets/cell_tracking_dataset.py similarity index 97% rename from biogtr/datasets/cell_tracking_dataset.py rename to dreem/datasets/cell_tracking_dataset.py index 9567de46..74e7182a 100644 --- a/biogtr/datasets/cell_tracking_dataset.py +++ b/dreem/datasets/cell_tracking_dataset.py @@ -1,8 +1,8 @@ """Module containing cell tracking challenge dataset.""" from PIL import Image -from biogtr.datasets import data_utils, BaseDataset -from biogtr.io import Frame, Instance +from dreem.datasets import data_utils, BaseDataset +from dreem.io import Frame, Instance from scipy.ndimage import measurements from typing import List, Optional, Union import albumentations as A @@ -122,7 +122,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Fram Returns: a list of Frame objects containing frame metadata and Instance Objects. - See `biogtr.io.data_structures` for more info. + See `dreem.io.data_structures` for more info. 
""" image = self.videos[label_idx] gt = self.labels[label_idx] diff --git a/biogtr/datasets/data_utils.py b/dreem/datasets/data_utils.py similarity index 100% rename from biogtr/datasets/data_utils.py rename to dreem/datasets/data_utils.py diff --git a/biogtr/datasets/eval_dataset.py b/dreem/datasets/eval_dataset.py similarity index 98% rename from biogtr/datasets/eval_dataset.py rename to dreem/datasets/eval_dataset.py index 95836a51..c8c5c63a 100644 --- a/biogtr/datasets/eval_dataset.py +++ b/dreem/datasets/eval_dataset.py @@ -1,7 +1,7 @@ """Module containing wrapper for merging gt and pred datasets for evaluation.""" from torch.utils.data import Dataset -from biogtr.io import Instance, Frame +from dreem.io import Instance, Frame from typing import List diff --git a/biogtr/datasets/microscopy_dataset.py b/dreem/datasets/microscopy_dataset.py similarity index 98% rename from biogtr/datasets/microscopy_dataset.py rename to dreem/datasets/microscopy_dataset.py index 9656d19d..484c453d 100644 --- a/biogtr/datasets/microscopy_dataset.py +++ b/dreem/datasets/microscopy_dataset.py @@ -1,8 +1,8 @@ """Module containing microscopy dataset.""" from PIL import Image -from biogtr.datasets import data_utils, BaseDataset -from biogtr.io import Instance, Frame +from dreem.datasets import data_utils, BaseDataset +from dreem.io import Instance, Frame from typing import Union import albumentations as A import numpy as np @@ -122,7 +122,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram frame_idx: index of the frames Returns: - A list of Frames containing Instances to be tracked (See `biogtr.io.data_structures for more info`) + A list of Frames containing Instances to be tracked (See `dreem.io.data_structures for more info`) """ labels = self.labels[label_idx] labels = labels.dropna(how="all") diff --git a/biogtr/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py similarity index 99% rename from biogtr/datasets/sleap_dataset.py rename to dreem/datasets/sleap_dataset.py index b23b4e70..3538dc60 100644 --- a/biogtr/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -7,8 +7,8 @@ import sleap_io as sio import random import warnings -from biogtr.io import Instance, Frame -from biogtr.datasets import data_utils, BaseDataset +from dreem.io import Instance, Frame +from dreem.datasets import data_utils, BaseDataset from torchvision.transforms import functional as tvf from typing import List, Union diff --git a/biogtr/datasets/tracking_dataset.py b/dreem/datasets/tracking_dataset.py similarity index 95% rename from biogtr/datasets/tracking_dataset.py rename to dreem/datasets/tracking_dataset.py index fdc54cac..960bf2d1 100644 --- a/biogtr/datasets/tracking_dataset.py +++ b/dreem/datasets/tracking_dataset.py @@ -1,8 +1,8 @@ """Module containing Lightning module wrapper around all other datasets.""" -from biogtr.datasets.cell_tracking_dataset import CellTrackingDataset -from biogtr.datasets.microscopy_dataset import MicroscopyDataset -from biogtr.datasets.sleap_dataset import SleapDataset +from dreem.datasets.cell_tracking_dataset import CellTrackingDataset +from dreem.datasets.microscopy_dataset import MicroscopyDataset +from dreem.datasets.sleap_dataset import SleapDataset from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from typing import Union diff --git a/biogtr/inference/__init__.py b/dreem/inference/__init__.py similarity index 100% rename from biogtr/inference/__init__.py rename to dreem/inference/__init__.py 
diff --git a/biogtr/inference/boxes.py b/dreem/inference/boxes.py similarity index 100% rename from biogtr/inference/boxes.py rename to dreem/inference/boxes.py diff --git a/biogtr/inference/configs/inference.yaml b/dreem/inference/configs/inference.yaml similarity index 100% rename from biogtr/inference/configs/inference.yaml rename to dreem/inference/configs/inference.yaml diff --git a/biogtr/inference/metrics.py b/dreem/inference/metrics.py similarity index 96% rename from biogtr/inference/metrics.py rename to dreem/inference/metrics.py index c80c15c3..935df49b 100644 --- a/biogtr/inference/metrics.py +++ b/dreem/inference/metrics.py @@ -5,11 +5,11 @@ import torch from typing import Union, Iterable -# from biogtr.inference.post_processing import _pairwise_iou -# from biogtr.inference.boxes import Boxes +# from dreem.inference.post_processing import _pairwise_iou +# from dreem.inference.boxes import Boxes -def get_matches(frames: list["biogtr.io.Frame"]) -> tuple[dict, list, int]: +def get_matches(frames: list["dreem.io.Frame"]) -> tuple[dict, list, int]: """Get comparison between predicted and gt trajectory labels. Args: @@ -100,11 +100,11 @@ def get_switch_count(switches: dict) -> int: return sw_cnt -def to_track_eval(frames: list["biogtr.io.Frame"]) -> dict: +def to_track_eval(frames: list["dreem.io.Frame"]) -> dict: """Reformats frames the output from `sliding_inference` to be used by `TrackEval`. Args: - frames: A list of Frames. `See biogtr.io.data_structures for more info`. + frames: A list of Frames. `See dreem.io.data_structures for more info`. Returns: data: A dictionary. Example provided below. diff --git a/biogtr/inference/post_processing.py b/dreem/inference/post_processing.py similarity index 99% rename from biogtr/inference/post_processing.py rename to dreem/inference/post_processing.py index 1aaf21cc..e4db20bc 100644 --- a/biogtr/inference/post_processing.py +++ b/dreem/inference/post_processing.py @@ -1,7 +1,7 @@ """Helper functions for post-processing association matrix pre-tracking.""" import torch -from biogtr.inference.boxes import Boxes +from dreem.inference.boxes import Boxes def weight_decay_time( diff --git a/biogtr/inference/track.py b/dreem/inference/track.py similarity index 95% rename from biogtr/inference/track.py rename to dreem/inference/track.py index aa766d5d..07b8464b 100644 --- a/biogtr/inference/track.py +++ b/dreem/inference/track.py @@ -1,7 +1,7 @@ """Script to run inference and get out tracks.""" -from biogtr.io import Config -from biogtr.models import GTRRunner +from dreem.io import Config +from dreem.models import GTRRunner from omegaconf import DictConfig from pathlib import Path from pprint import pprint @@ -14,7 +14,7 @@ import sleap_io as sio -def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None): +def export_trajectories(frames_pred: list["dreem.io.Frame"], save_path: str = None): """Convert trajectories to data frame and save as .csv. 
Args: @@ -132,7 +132,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: for i, pred in preds.items(): outpath = os.path.join( outdir, - f"{Path(dataloader.dataset.label_files[i]).stem}.biogtr_inference.v{run_num}.slp", + f"{Path(dataloader.dataset.label_files[i]).stem}.dreem_inference.v{run_num}.slp", ) if os.path.exists(outpath): run_num += 1 diff --git a/biogtr/inference/track_queue.py b/dreem/inference/track_queue.py similarity index 99% rename from biogtr/inference/track_queue.py rename to dreem/inference/track_queue.py index 739869fb..54927ac0 100644 --- a/biogtr/inference/track_queue.py +++ b/dreem/inference/track_queue.py @@ -1,7 +1,7 @@ """Module handling sliding window tracking.""" import warnings -from biogtr.io import Frame +from dreem.io import Frame from collections import deque import numpy as np diff --git a/biogtr/inference/tracker.py b/dreem/inference/tracker.py similarity index 98% rename from biogtr/inference/tracker.py rename to dreem/inference/tracker.py index 4aa36f39..93bcad8e 100644 --- a/biogtr/inference/tracker.py +++ b/dreem/inference/tracker.py @@ -3,11 +3,11 @@ import torch import pandas as pd import warnings -from biogtr.io import Frame -from biogtr.models import model_utils, GlobalTrackingTransformer -from biogtr.inference.track_queue import TrackQueue -from biogtr.inference import post_processing -from biogtr.inference.boxes import Boxes +from dreem.io import Frame +from dreem.models import model_utils, GlobalTrackingTransformer +from dreem.inference.track_queue import TrackQueue +from dreem.inference import post_processing +from dreem.inference.boxes import Boxes from scipy.optimize import linear_sum_assignment from math import inf @@ -127,7 +127,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frames: A list of Frames (See `biogtr.io.Frame` for more info). + frames: A list of Frames (See `dreem.io.Frame` for more info). Returns: Frames: A list of Frames populated with pred_track_ids and asso_matrices @@ -207,7 +207,7 @@ def _run_global_tracker( Args: model: the pretrained GlobalTrackingTransformer to be used for inference - frames: A list of Frames containing reid features. See `biogtr.io.data_structures` for more info. + frames: A list of Frames containing reid features. See `dreem.io.data_structures` for more info. query_ind: An integer for the query frame within the window of instances. 
Returns: diff --git a/dreem/io/__init__.py b/dreem/io/__init__.py new file mode 100644 index 00000000..5bd23340 --- /dev/null +++ b/dreem/io/__init__.py @@ -0,0 +1,7 @@ +"""Module containing input/output data structures for easy storage and manipulation.""" + +from dreem.io.frame import Frame +from dreem.io.instance import Instance +from dreem.io.association_matrix import AssociationMatrix +from dreem.io.track import Track +from dreem.io.config import Config diff --git a/biogtr/io/association_matrix.py b/dreem/io/association_matrix.py similarity index 99% rename from biogtr/io/association_matrix.py rename to dreem/io/association_matrix.py index 9d6d366e..84aee035 100644 --- a/biogtr/io/association_matrix.py +++ b/dreem/io/association_matrix.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import attrs -from biogtr.io import Instance +from dreem.io import Instance from typing import Union diff --git a/biogtr/io/config.py b/dreem/io/config.py similarity index 96% rename from biogtr/io/config.py rename to dreem/io/config.py index 7ea8a0ac..7cd1e6de 100644 --- a/biogtr/io/config.py +++ b/dreem/io/config.py @@ -85,7 +85,7 @@ def get_model(self) -> "GlobalTrackingTransformer": Returns: A global tracking transformer with parameters indicated by cfg """ - from biogtr.models import GlobalTrackingTransformer + from dreem.models import GlobalTrackingTransformer model_params = self.cfg.model ckpt_path = model_params.pop("ckpt_path", None) @@ -109,7 +109,7 @@ def get_tracker_cfg(self) -> dict: def get_gtr_runner(self) -> "GTRRunner": """Get lightning module for training, validation, and inference.""" - from biogtr.models import GTRRunner + from dreem.models import GTRRunner tracker_params = self.cfg.tracker optimizer_params = self.cfg.optimizer @@ -174,7 +174,7 @@ def get_dataset( Returns: Either a `SleapDataset` or `MicroscopyDataset` with params indicated by cfg """ - from biogtr.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset + from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset if mode.lower() == "train": dataset_params = self.cfg.dataset.train_dataset @@ -276,7 +276,7 @@ def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer: Returns: A torch Optimizer with specified params """ - from biogtr.models.model_utils import init_optimizer + from dreem.models.model_utils import init_optimizer optimizer_params = self.cfg.optimizer @@ -293,19 +293,19 @@ def get_scheduler( Returns: A torch learning rate scheduler with specified params """ - from biogtr.models.model_utils import init_scheduler + from dreem.models.model_utils import init_scheduler lr_scheduler_params = self.cfg.scheduler return init_scheduler(optimizer, lr_scheduler_params) - def get_loss(self) -> "biogtr.training.losses.AssoLoss": + def get_loss(self) -> "dreem.training.losses.AssoLoss": """Getter for loss functions. 
Returns: An AssoLoss with specified params """ - from biogtr.training.losses import AssoLoss + from dreem.training.losses import AssoLoss loss_params = self.cfg.loss @@ -317,7 +317,7 @@ def get_logger(self): Returns: A Logger with specified params """ - from biogtr.models.model_utils import init_logger + from dreem.models.model_utils import init_logger logger_params = OmegaConf.to_container(self.cfg.logging, resolve=True) diff --git a/biogtr/io/frame.py b/dreem/io/frame.py similarity index 99% rename from biogtr/io/frame.py rename to dreem/io/frame.py index 5607e832..67f6dc95 100644 --- a/biogtr/io/frame.py +++ b/dreem/io/frame.py @@ -129,15 +129,15 @@ def from_slp( device: str = None, **kwargs, ) -> "Frame": - """Convert `sio.LabeledFrame` to `biogtr.io.Frame`. + """Convert `sio.LabeledFrame` to `dreem.io.Frame`. Args: lf: A sio.LabeledFrame object Returns: - A biogtr.io.Frame object + A dreem.io.Frame object """ - from biogtr.io import Instance + from dreem.io import Instance img_shape = lf.image.shape if len(img_shape) == 2: diff --git a/biogtr/io/instance.py b/dreem/io/instance.py similarity index 99% rename from biogtr/io/instance.py rename to dreem/io/instance.py index 5ffef867..64e1af7d 100644 --- a/biogtr/io/instance.py +++ b/dreem/io/instance.py @@ -159,7 +159,7 @@ def from_slp( crop: ArrayLike = None, device: str = None, ) -> None: - """Convert a slp instance to a biogtr instance. + """Convert a slp instance to a dreem instance. Args: slp_instance: A `sleap_io.Instance` object representing a detection @@ -167,7 +167,7 @@ def from_slp( crop: The corresponding crop of the bbox device: which device to keep the instance on Returns: - A biogtr.Instance object with a pose-centered bbox and no crop. + A dreem.Instance object with a pose-centered bbox and no crop. 
""" try: track_id = int(slp_instance.track.name) diff --git a/biogtr/io/track.py b/dreem/io/track.py similarity index 100% rename from biogtr/io/track.py rename to dreem/io/track.py diff --git a/biogtr/io/visualize.py b/dreem/io/visualize.py similarity index 100% rename from biogtr/io/visualize.py rename to dreem/io/visualize.py diff --git a/biogtr/models/__init__.py b/dreem/models/__init__.py similarity index 100% rename from biogtr/models/__init__.py rename to dreem/models/__init__.py diff --git a/biogtr/models/attention_head.py b/dreem/models/attention_head.py similarity index 97% rename from biogtr/models/attention_head.py rename to dreem/models/attention_head.py index ed8c6f50..2b160552 100644 --- a/biogtr/models/attention_head.py +++ b/dreem/models/attention_head.py @@ -1,7 +1,7 @@ """Module containing different components of multi-head attention heads.""" import torch -from biogtr.models.mlp import MLP +from dreem.models.mlp import MLP # todo: add named tensors diff --git a/biogtr/models/embedding.py b/dreem/models/embedding.py similarity index 99% rename from biogtr/models/embedding.py rename to dreem/models/embedding.py index 222d8585..cf9a4f6f 100644 --- a/biogtr/models/embedding.py +++ b/dreem/models/embedding.py @@ -3,7 +3,7 @@ from typing import Tuple, Optional import math import torch -from biogtr.models.mlp import MLP +from dreem.models.mlp import MLP # todo: add named tensors, clean variable names diff --git a/biogtr/models/global_tracking_transformer.py b/dreem/models/global_tracking_transformer.py similarity index 95% rename from biogtr/models/global_tracking_transformer.py rename to dreem/models/global_tracking_transformer.py index c746d1aa..6114080e 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/dreem/models/global_tracking_transformer.py @@ -1,7 +1,7 @@ """Module containing GTR model used for training.""" -from biogtr.models import Transformer -from biogtr.models import VisualEncoder +from dreem.models import Transformer +from dreem.models import VisualEncoder import torch # todo: do we want to handle params with configs already here? @@ -51,8 +51,8 @@ def __init__( that no positional embeddings should be used. To use the positional embeddings pass in a dictionary containing a "pos" and "temp" key with subdictionaries for correct parameters ie: {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes: 'True'}, - "temp": {'mode': 'learned', 'emb_num': 16}}. (see `biogtr.models.embeddings.Embedding.EMB_TYPES` - and `biogtr.models.embeddings.Embedding.EMB_MODES` for embedding parameters). + "temp": {'mode': 'learned', 'emb_num': 16}}. (see `dreem.models.embeddings.Embedding.EMB_TYPES` + and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters). 
""" super().__init__() diff --git a/biogtr/models/gtr_runner.py b/dreem/models/gtr_runner.py similarity index 91% rename from biogtr/models/gtr_runner.py rename to dreem/models/gtr_runner.py index 7dbf4b2b..35965674 100644 --- a/biogtr/models/gtr_runner.py +++ b/dreem/models/gtr_runner.py @@ -2,14 +2,14 @@ import torch import gc -from biogtr.inference import Tracker -from biogtr.inference import metrics -from biogtr.models import GlobalTrackingTransformer -from biogtr.training.losses import AssoLoss -from biogtr.models.model_utils import init_optimizer, init_scheduler +from dreem.inference import Tracker +from dreem.inference import metrics +from dreem.models import GlobalTrackingTransformer +from dreem.training.losses import AssoLoss +from dreem.models.model_utils import init_optimizer, init_scheduler from pytorch_lightning import LightningModule -from biogtr.io.frame import Frame -from biogtr.io.instance import Instance +from dreem.io.frame import Frame +from dreem.io.instance import Instance class GTRRunner(LightningModule): @@ -75,8 +75,8 @@ def __init__( def forward( self, - ref_instances: list["biogtr.io.Instance"], - query_instances: list["biogtr.io.Instance"] = None, + ref_instances: list["dreem.io.Instance"], + query_instances: list["dreem.io.Instance"] = None, ) -> torch.Tensor: """Execute forward pass of the lightning module. @@ -91,7 +91,7 @@ def forward( return asso_preds def training_step( - self, train_batch: list[list["biogtr.io.Frame"]], batch_idx: int + self, train_batch: list[list["dreem.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single training step for model. @@ -109,7 +109,7 @@ def training_step( return result def validation_step( - self, val_batch: list[list["biogtr.io.Frame"]], batch_idx: int + self, val_batch: list[list["dreem.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single val step for model. @@ -127,7 +127,7 @@ def validation_step( return result def test_step( - self, test_batch: list[list["biogtr.io.Frame"]], batch_idx: int + self, test_batch: list[list["dreem.io.Frame"]], batch_idx: int ) -> dict[str, float]: """Execute single test step for model. @@ -145,8 +145,8 @@ def test_step( return result def predict_step( - self, batch: list[list["biogtr.io.Frame"]], batch_idx: int - ) -> list["biogtr.io.Frame"]: + self, batch: list[list["dreem.io.Frame"]], batch_idx: int + ) -> list["dreem.io.Frame"]: """Run inference for model. Computes association + assignment. @@ -163,7 +163,7 @@ def predict_step( return frames_pred def _shared_eval_step( - self, frames: list["biogtr.io.Frame"], mode: str + self, frames: list["dreem.io.Frame"], mode: str ) -> dict[str, float]: """Run evaluation used by train, test, and val steps. diff --git a/biogtr/models/mlp.py b/dreem/models/mlp.py similarity index 100% rename from biogtr/models/mlp.py rename to dreem/models/mlp.py diff --git a/biogtr/models/model_utils.py b/dreem/models/model_utils.py similarity index 97% rename from biogtr/models/model_utils.py rename to dreem/models/model_utils.py index a2885a0f..249170f4 100644 --- a/biogtr/models/model_utils.py +++ b/dreem/models/model_utils.py @@ -5,7 +5,7 @@ import torch -def get_boxes(instances: List["biogtr.io.Instance"]) -> torch.tensor: +def get_boxes(instances: List["dreem.io.Instance"]) -> torch.tensor: """Extract the bounding boxes from the input list of instances. 
Args: @@ -29,8 +29,8 @@ def get_boxes(instances: List["biogtr.io.Instance"]) -> torch.tensor: def get_times( - ref_instances: list["biogtr.io.Instance"], - query_instances: list["biogtr.io.Instance"] = None, + ref_instances: list["dreem.io.Instance"], + query_instances: list["dreem.io.Instance"] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Extract the time indices of each instance relative to the window length. diff --git a/biogtr/models/transformer.py b/dreem/models/transformer.py similarity index 97% rename from biogtr/models/transformer.py rename to dreem/models/transformer.py index 75579da5..1c4ff019 100644 --- a/biogtr/models/transformer.py +++ b/dreem/models/transformer.py @@ -11,10 +11,10 @@ * added fixed embeddings over boxes """ -from biogtr.io import AssociationMatrix -from biogtr.models.attention_head import ATTWeightHead -from biogtr.models import Embedding -from biogtr.models.model_utils import get_boxes, get_times +from dreem.io import AssociationMatrix +from dreem.models.attention_head import ATTWeightHead +from dreem.models import Embedding +from dreem.models.model_utils import get_boxes, get_times from torch import nn import copy import torch @@ -65,8 +65,8 @@ def __init__( that no positional embeddings should be used. To use the positional embeddings pass in a dictionary containing a "pos" and "temp" key with subdictionaries for correct parameters ie: {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes: 'True'}, - "temp": {'mode': 'learned', 'emb_num': 16}}. (see `biogtr.models.embeddings.Embedding.EMB_TYPES` - and `biogtr.models.embeddings.Embedding.EMB_MODES` for embedding parameters). + "temp": {'mode': 'learned', 'emb_num': 16}}. (see `dreem.models.embeddings.Embedding.EMB_TYPES` + and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters). """ super().__init__() @@ -141,13 +141,13 @@ def _reset_parameters(self): def forward( self, - ref_instances: list["biogtr.io.Instance"], - query_instances: list["biogtr.io.Instance"] = None, + ref_instances: list["dreem.io.Instance"], + query_instances: list["dreem.io.Instance"] = None, ) -> list[AssociationMatrix]: """Execute a forward pass through the transformer and attention head. Args: - ref instances: A list of instance objects (See `biogtr.io.Instance` for more info.) + ref instances: A list of instance objects (See `dreem.io.Instance` for more info.) query_instances: An set of instances to be used as decoder queries. 
Returns: diff --git a/biogtr/models/visual_encoder.py b/dreem/models/visual_encoder.py similarity index 100% rename from biogtr/models/visual_encoder.py rename to dreem/models/visual_encoder.py diff --git a/biogtr/training/__init__.py b/dreem/training/__init__.py similarity index 100% rename from biogtr/training/__init__.py rename to dreem/training/__init__.py diff --git a/biogtr/training/configs/base.yaml b/dreem/training/configs/base.yaml similarity index 100% rename from biogtr/training/configs/base.yaml rename to dreem/training/configs/base.yaml diff --git a/biogtr/training/configs/params.yaml b/dreem/training/configs/params.yaml similarity index 100% rename from biogtr/training/configs/params.yaml rename to dreem/training/configs/params.yaml diff --git a/biogtr/training/configs/test_batch_train.csv b/dreem/training/configs/test_batch_train.csv similarity index 100% rename from biogtr/training/configs/test_batch_train.csv rename to dreem/training/configs/test_batch_train.csv diff --git a/biogtr/training/losses.py b/dreem/training/losses.py similarity index 99% rename from biogtr/training/losses.py rename to dreem/training/losses.py index ff3e6eca..53b4289c 100644 --- a/biogtr/training/losses.py +++ b/dreem/training/losses.py @@ -1,6 +1,6 @@ """Module containing different loss functions to be optimized.""" -from biogtr.models.model_utils import get_boxes, get_times +from dreem.models.model_utils import get_boxes, get_times from torch import nn from typing import List, Tuple import torch diff --git a/biogtr/training/train.py b/dreem/training/train.py similarity index 96% rename from biogtr/training/train.py rename to dreem/training/train.py index e252d3f5..a2522817 100644 --- a/biogtr/training/train.py +++ b/dreem/training/train.py @@ -3,9 +3,9 @@ Used for training a single model or deploying a batch train job on RUNAI CLI """ -from biogtr.io import Config -from biogtr.datasets import TrackingDataset -from biogtr.datasets.data_utils import view_training_batch +from dreem.io import Config +from dreem.datasets import TrackingDataset +from dreem.datasets.data_utils import view_training_batch from multiprocessing import cpu_count from omegaconf import DictConfig from pprint import pprint diff --git a/biogtr/version.py b/dreem/version.py similarity index 100% rename from biogtr/version.py rename to dreem/version.py diff --git a/environment.yml b/environment.yml index da26b24b..0b129889 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: biogtr +name: dreem channels: - pytorch diff --git a/environment_cpu.yml b/environment_cpu.yml index 6714b9c1..bc63a14a 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -1,4 +1,4 @@ -name: biogtr +name: dreem channels: - pytorch diff --git a/environment_osx-arm64.yml b/environment_osx-arm64.yml index 124c4576..0393cb66 100644 --- a/environment_osx-arm64.yml +++ b/environment_osx-arm64.yml @@ -1,4 +1,4 @@ -name: biogtr +name: dreem channels: - pytorch diff --git a/pyproject.toml b/pyproject.toml index eeb8b59a..9b13894f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools", "setuptools-scm"] build-backend = "setuptools.build_meta" [project] -name = "biogtr" +name = "dreem" authors = [ {name = "Arlo Sheridan", email = "asheridan@salk.edu"}, {name = "Aaditya Prasad", email = "aprasad@salk.edu"}, @@ -33,7 +33,7 @@ dependencies = [ dynamic = ["version", "readme"] [tool.setuptools.dynamic] -version = {attr = "biogtr.version.__version__"} +version = {attr = "dreem.version.__version__"} readme = 
{file = ["README.md"]} [project.optional-dependencies] @@ -48,11 +48,11 @@ dev = [ ] [project.scripts] -biogtr = "biogtr.cli:cli" +dreem = "dreem.cli:cli" [project.urls] -Homepage = "https://github.com/talmolab/biogtr" -Repository = "https://github.com/talmolab/biogtr" +Homepage = "https://github.com/talmolab/dreem" +Repository = "https://github.com/talmolab/dreem" [tool.black] line-length = 88 diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index db574099..f28e0950 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,4 +1,4 @@ -"""Fixtures for testing biogtr.""" +"""Fixtures for testing dreem.""" import pytest from pathlib import Path diff --git a/tests/test_config.py b/tests/test_config.py index 4f7ebbc7..0b1c8267 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,8 +1,8 @@ """Tests for `config.py`""" from omegaconf import OmegaConf -from biogtr.io import Config -from biogtr.models import GlobalTrackingTransformer, GTRRunner +from dreem.io import Config +from dreem.models import GlobalTrackingTransformer, GTRRunner import torch diff --git a/tests/test_data_model.py b/tests/test_data_model.py index ef5d0320..aeeb09ca 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -1,6 +1,6 @@ """Tests for Instance, Frame, and AssociationMatrix Objects""" -from biogtr.io import Frame, Instance, AssociationMatrix, Track +from dreem.io import Frame, Instance, AssociationMatrix, Track import torch import pytest import numpy as np diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ab9d7640..2287c4f9 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,14 +1,14 @@ """Test dataset logic.""" -from biogtr.datasets import ( +from dreem.datasets import ( BaseDataset, MicroscopyDataset, SleapDataset, CellTrackingDataset, TrackingDataset, ) -from biogtr.datasets.data_utils import get_max_padding, NodeDropout -from biogtr.models.model_utils import get_device +from dreem.datasets.data_utils import get_max_padding, NodeDropout +from dreem.models.model_utils import get_device from torch.utils.data import DataLoader import pytest import torch diff --git a/tests/test_inference.py b/tests/test_inference.py index a5ef05f9..06b7154b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -5,11 +5,11 @@ import numpy as np from pytorch_lightning import Trainer from omegaconf import OmegaConf, DictConfig -from biogtr.io import Frame, Instance, Config -from biogtr.models import GTRRunner, GlobalTrackingTransformer -from biogtr.inference import Tracker, post_processing, metrics -from biogtr.inference.track_queue import TrackQueue -from biogtr.inference.track import run +from dreem.io import Frame, Instance, Config +from dreem.models import GTRRunner, GlobalTrackingTransformer +from dreem.inference import Tracker, post_processing, metrics +from dreem.inference.track_queue import TrackQueue +from dreem.inference.track import run def test_track_queue(): diff --git a/tests/test_models.py b/tests/test_models.py index bf4e8e47..3eaf9c22 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,16 +2,16 @@ import pytest import torch -from biogtr.io import Frame, Instance -from biogtr.models.mlp import MLP -from biogtr.models.attention_head import ATTWeightHead -from biogtr.models import ( +from dreem.io import Frame, Instance +from dreem.models.mlp import MLP +from dreem.models.attention_head import ATTWeightHead +from dreem.models import ( Embedding, VisualEncoder, Transformer, 
GlobalTrackingTransformer, ) -from biogtr.models.transformer import ( +from dreem.models.transformer import ( TransformerEncoderLayer, TransformerDecoderLayer, ) diff --git a/tests/test_training.py b/tests/test_training.py index 357ce267..bd8bbe75 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -3,11 +3,11 @@ import os import pytest import torch -from biogtr.io import Frame, Instance, Config -from biogtr.training.losses import AssoLoss -from biogtr.models import GTRRunner +from dreem.io import Frame, Instance, Config +from dreem.training.losses import AssoLoss +from dreem.models import GTRRunner from omegaconf import OmegaConf, DictConfig -from biogtr.training.train import run +from dreem.training.train import run # TODO: add named tensor tests # TODO: use temp dir and cleanup after tests (https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html) diff --git a/tests/test_version.py b/tests/test_version.py index 6bde7e48..2d43f00b 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,8 +1,8 @@ """Test version.""" -import biogtr +import dreem def test_version(): """Test version.""" - assert biogtr.__version__ == biogtr.version.__version__ + assert dreem.__version__ == dreem.version.__version__
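Taken together, a minimal end-to-end inference sketch for the renamed package, mirroring `test_track` in tests/test_inference.py; the checkpoint and output paths here are hypothetical:

    from omegaconf import OmegaConf
    from dreem.io import Config
    from dreem.inference.track import run

    # Point the inference config at a trained checkpoint and an output directory.
    cfg = Config(OmegaConf.load("tests/configs/inference.yaml"))
    cfg.set_hparams({"ckpt_path": "models/model.ckpt", "outdir": "results"})

    # run() loads the GTRRunner from the checkpoint, builds a trainer and test dataloader,
    # tracks each video, and writes one .slp file per input labels file under `outdir`.
    preds = run(cfg.cfg)  # dict[int, sio.Labels] keyed by video index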