diff --git a/biogtr/data_structures.py b/biogtr/data_structures.py
index 7aa706b7..81b2f551 100644
--- a/biogtr/data_structures.py
+++ b/biogtr/data_structures.py
@@ -124,6 +124,8 @@ def __init__(
         self._device = device
         self.to(self._device)
 
+        self._frame = None
+
     def __repr__(self) -> str:
         """Return string representation of the Instance."""
         return (
@@ -421,6 +423,26 @@ def has_features(self) -> bool:
         else:
             return True
 
+    @property
+    def frame(self) -> "Frame":
+        """Get the frame the instance belongs to.
+
+        Returns:
+            The back reference to the `Frame` that this `Instance` belongs to.
+        """
+        return self._frame
+
+    @frame.setter
+    def frame(self, frame: "Frame") -> None:
+        """Set the back reference to the `Frame` that this `Instance` belongs to.
+
+        This field is set when instances are added to a `Frame` object.
+
+        Args:
+            frame: A `Frame` object containing the metadata for the frame that the instance belongs to.
+        """
+        self._frame = frame
+
     @property
     def pose(self) -> dict[str, ArrayLike]:
         """Get the pose of the instance.
@@ -580,9 +602,12 @@ def __init__(
             self._img_shape = img_shape
         else:
             self._img_shape = torch.tensor([img_shape])
+
         if instances is None:
             self.instances = []
         else:
+            for instance in instances:
+                instance.frame = self
             self._instances = instances
 
         self._asso_output = asso_output
@@ -612,7 +637,7 @@ def __repr__(self) -> str:
             f"img_shape={self._img_shape}, "
             f"num_detected={self.num_detected}, "
             f"asso_output={self._asso_output}, "
-            f"traj_score={self._traj_score}, "
+            f"traj_score={list(self._traj_score.keys())}, "
             f"matches={self._matches}, "
             f"instances={self._instances}, "
             f"device={self._device}"
@@ -796,6 +821,9 @@ def instances(self, instances: List[Instance]) -> None:
         Args:
             instances: A list of Instances that appear in the frame.
         """
+        for instance in instances:
+            instance.frame = self
+
         self._instances = instances
 
     def has_instances(self) -> bool:
diff --git a/biogtr/inference/tracker.py b/biogtr/inference/tracker.py
index 84fffb73..3e3405ba 100644
--- a/biogtr/inference/tracker.py
+++ b/biogtr/inference/tracker.py
@@ -141,7 +141,9 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame]
         # W: width.
 
         for batch_idx, frame_to_track in enumerate(frames):
-            tracked_frames = self.track_queue.collate_tracks()
+            tracked_frames = self.track_queue.collate_tracks(
+                device=frame_to_track.frame_id.device
+            )
             if self.verbose:
                 warnings.warn(
                     f"Current number of tracks is {self.track_queue.n_tracks}"
@@ -229,8 +231,12 @@ def _run_global_tracker(
         # E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window.
 
         _ = model.eval()
 
+        query_frame = frames[query_ind]
+        query_instances = query_frame.instances
+        all_instances = [instance for frame in frames for instance in frame.instances]
+
         if self.verbose:
             print(f"Frame {query_frame.frame_id.item()}")
@@ -253,7 +259,7 @@
         # (L=1, n_query, total_instances)
         with torch.no_grad():
-            asso_output, embed = model(frames, query_frame=query_ind)
+            asso_output, embed = model(all_instances, query_instances)
         # if model.transformer.return_embedding:
         #     query_frame.embeddings = embed TODO add embedding to Instance Object
         # if query_frame == 1:
@@ -321,6 +327,7 @@
         ]
         nonquery_inds = [i for i in range(total_instances) if i not in query_inds]
+        # instead should we do model(nonquery_instances, query_instances)?
         asso_nonquery = asso_output[:, nonquery_inds]  # (n_query, n_nonquery)
 
         asso_nonquery_df = pd.DataFrame(
@@ -332,10 +339,9 @@
 
         query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)
 
-        pred_boxes, _ = model_utils.get_boxes_times(frames)
+        pred_boxes = model_utils.get_boxes(all_instances)
         query_boxes = pred_boxes[query_inds]  # n_k x 4
         nonquery_boxes = pred_boxes[nonquery_inds]  # n_nonquery x 4
-        # TODO: Insert postprocessing.
 
         unique_ids = torch.unique(instance_ids)  # (n_nonquery,)
diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py
index c02aec5b..59206c17 100644
--- a/biogtr/models/global_tracking_transformer.py
+++ b/biogtr/models/global_tracking_transformer.py
@@ -2,13 +2,13 @@
 
 from biogtr.models.transformer import Transformer
 from biogtr.models.visual_encoder import VisualEncoder
-from biogtr.data_structures import Frame
-from torch import nn
+from biogtr.data_structures import Instance
+import torch
 
 # todo: do we want to handle params with configs already here?
 
 
-class GlobalTrackingTransformer(nn.Module):
+class GlobalTrackingTransformer(torch.nn.Module):
     """Modular GTR model composed of visual encoder + transformer used for tracking."""
 
     def __init__(
@@ -79,26 +79,54 @@ def __init__(
             decoder_self_attn=decoder_self_attn,
         )
 
-    def forward(self, frames: list[Frame], query_frame: int = None):
+    def forward(
+        self, ref_instances: list[Instance], query_instances: list[Instance] = None
+    ):
         """Execute forward pass of GTR Model to get asso matrix.
 
         Args:
-            frames: List of Frames from chunk containing crops of objects + gt label info
-            query_frame: Frame index used as query for self attention. Only used in sliding inference where query frame is the last frame in the window.
+            ref_instances: List of instances from chunk containing crops of objects + gt label info
+            query_instances: List of instances used as queries in the decoder.
 
         Returns:
             An N_T x N association matrix
         """
         # Extract feature representations with pre-trained encoder.
-        for frame in filter(
-            lambda f: f.has_instances() and not f.has_features(), frames
-        ):
-            crops = frame.get_crops()
-            z = self.visual_encoder(crops)
+        self.extract_features(ref_instances)
 
-            for i, z_i in enumerate(z):
-                frame.instances[i].features = z_i
+        if query_instances:
+            self.extract_features(query_instances)
 
-        asso_preds, emb = self.transformer(frames, query_frame=query_frame)
+        asso_preds, emb = self.transformer(ref_instances, query_instances)
 
         return asso_preds, emb
+
+    def extract_features(
+        self, instances: list["Instance"], force_recompute: bool = False
+    ) -> None:
+        """Extract features from instances using visual encoder backbone.
+
+        Args:
+            instances: A list of instances to compute features for.
+            force_recompute: indicate whether to compute features for all instances regardless of whether they already have features.
+        """
+        if not force_recompute:
+            instances_to_compute = [
+                instance
+                for instance in instances
+                if instance.has_crop() and not instance.has_features()
+            ]
+        else:
+            instances_to_compute = instances
+
+        if len(instances_to_compute) == 0:
+            return
+        elif len(instances_to_compute) == 1:  # handle batch norm error when B=1
+            instances_to_compute = instances
+
+        crops = torch.concatenate([instance.crop for instance in instances_to_compute])
+
+        features = self.visual_encoder(crops)
+
+        for i, z_i in enumerate(features):
+            instances_to_compute[i].features = z_i
diff --git a/biogtr/models/gtr_runner.py b/biogtr/models/gtr_runner.py
index e7e6a577..98955d23 100644
--- a/biogtr/models/gtr_runner.py
+++ b/biogtr/models/gtr_runner.py
@@ -8,6 +8,7 @@
 from biogtr.training.losses import AssoLoss
 from biogtr.models.model_utils import init_optimizer, init_scheduler
 from pytorch_lightning import LightningModule
+from biogtr.data_structures import Frame, Instance
 
 
 class GTRRunner(LightningModule):
@@ -59,28 +60,29 @@ def __init__(
         self.metrics = metrics
         self.persistent_tracking = persistent_tracking
 
-    def forward(self, instances) -> torch.Tensor:
+    def forward(
+        self, ref_instances: list[Instance], query_instances: list[Instance] = None
+    ) -> torch.Tensor:
         """Execute forward pass of the lightning module.
 
         Args:
-            instances: a list of dicts where each dict is a frame with gt data
+            ref_instances: a list of `Instance` objects containing crops and other data needed for the transformer model
+            query_instances: a list of `Instance` objects used as queries in the decoder. Mostly used for inference.
 
         Returns:
             An association matrix between objects
         """
-        if sum([frame.num_detected for frame in instances]) > 0:
-            asso_preds, _ = self.model(instances)
-            return asso_preds
-        return None
+        asso_preds, _ = self.model(ref_instances, query_instances)
+        return asso_preds
 
     def training_step(
-        self, train_batch: list[dict], batch_idx: int
+        self, train_batch: list[list[Frame]], batch_idx: int
     ) -> dict[str, float]:
         """Execute single training step for model.
 
         Args:
-            train_batch: A single batch from the dataset which is a list of dicts
-                with length `clip_length` where each dict is a frame
+            train_batch: A single batch from the dataset which is a list of `Frame` objects
+                with length `clip_length` containing Instances and other metadata.
             batch_idx: the batch number used by lightning
 
         Returns:
@@ -92,13 +94,13 @@ def training_step(
         return result
 
     def validation_step(
-        self, val_batch: list[dict], batch_idx: int
+        self, val_batch: list[list[Frame]], batch_idx: int
    ) -> dict[str, float]:
         """Execute single val step for model.
 
         Args:
-            val_batch: A single batch from the dataset which is a list of dicts
-                with length `clip_length` where each dict is a frame
+            val_batch: A single batch from the dataset which is a list of `Frame` objects
+                with length `clip_length` containing Instances and other metadata.
             batch_idx: the batch number used by lightning
 
         Returns:
@@ -109,12 +111,14 @@ def validation_step(
         return result
 
-    def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:
+    def test_step(
+        self, test_batch: list[list[Frame]], batch_idx: int
+    ) -> dict[str, float]:
         """Execute single test step for model.
 
         Args:
-            val_batch: A single batch from the dataset which is a list of dicts
-                with length `clip_length` where each dict is a frame
+            test_batch: A single batch from the dataset which is a list of `Frame` objects
+                with length `clip_length` containing Instances and other metadata.
             batch_idx: the batch number used by lightning
 
         Returns:
@@ -125,57 +129,57 @@ def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:
         return result
 
-    def predict_step(self, batch: list[dict], batch_idx: int) -> dict:
+    def predict_step(self, batch: list[list[Frame]], batch_idx: int) -> list[Frame]:
         """Run inference for model. Computes association + assignment.
 
         Args:
-            batch: A single batch from the dataset which is a list of dicts
-                with length `clip_length` where each dict is a frame
+            batch: A single batch from the dataset which is a list of `Frame` objects
+                with length `clip_length` containing Instances and other metadata.
             batch_idx: the batch number used by lightning
 
         Returns:
             A list of dicts where each dict is a frame containing the predicted track ids
         """
         self.tracker.persistent_tracking = True
-        instances_pred = self.tracker(self.model, batch[0])
-        return instances_pred
+        frames_pred = self.tracker(self.model, batch[0])
+        return frames_pred
 
-    def _shared_eval_step(self, instances, mode):
+    def _shared_eval_step(self, frames: list[Frame], mode: str) -> dict[str, float]:
         """Run evaluation used by train, test, and val steps.
 
         Args:
-            instances: A list of dicts where each dict is a frame containing gt data
+            frames: A list of `Frame` objects with length `clip_length` containing Instances and other metadata.
             mode: which metrics to compute and whether to use persistent tracking or not
 
         Returns:
             a dict containing the loss and any other metrics specified by `eval_metrics`
         """
         try:
-            instances = [frame for frame in instances if frame.has_instances()]
+            instances = [instance for frame in frames for instance in frame.instances]
+            if len(instances) == 0:
+                return None
+
             eval_metrics = self.metrics[mode]
             persistent_tracking = self.persistent_tracking[mode]
 
             logits = self(instances)
-
-            if not logits:
-                return None
-
-            loss = self.loss(logits, instances)
+            loss = self.loss(logits, frames)
 
             return_metrics = {"loss": loss}
 
             if eval_metrics is not None and len(eval_metrics) > 0:
                 self.tracker.persistent_tracking = persistent_tracking
-                instances_pred = self.tracker(self.model, instances)
-                instances_mm = metrics.to_track_eval(instances_pred)
-                clearmot = metrics.get_pymotmetrics(instances_mm, eval_metrics)
+
+                frames_pred = self.tracker(self.model, frames)
+
+                frames_mm = metrics.to_track_eval(frames_pred)
+                clearmot = metrics.get_pymotmetrics(frames_mm, eval_metrics)
+
                 return_metrics.update(clearmot.to_dict())
-            return_metrics["batch_size"] = len(instances)
+            return_metrics["batch_size"] = len(frames)
         except Exception as e:
-            print(
-                f"Failed on frame {instances[0].frame_id} of video {instances[0].video_id}"
-            )
+            print(f"Failed on frame {frames[0].frame_id} of video {frames[0].video_id}")
             raise (e)
 
         return return_metrics
diff --git a/biogtr/models/model_utils.py b/biogtr/models/model_utils.py
index 14e2948f..d92a822b 100644
--- a/biogtr/models/model_utils.py
+++ b/biogtr/models/model_utils.py
@@ -2,35 +2,70 @@
 
 from typing import List, Tuple, Iterable
 from pytorch_lightning import loggers
-from biogtr.data_structures import Frame
+from biogtr.data_structures import Instance
 import torch
 
 
-def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Extract the bounding boxes and frame indices from the input list of instances.
+def get_boxes(instances: List[Instance]) -> torch.Tensor:
+    """Extract the bounding boxes from the input list of instances.
 
     Args:
-        frames (List[Frame]): List of frame objects containing metadata and instances.
+        instances: List of Instance objects.
 
     Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors containing the
-            bounding boxes normalized by the height and width of the image
-            and corresponding frame indices, respectively.
+        An (n_instances, n_points, 4) float tensor containing the bounding boxes
+        normalized by the height and width of the image.
     """
-    boxes, times = [], []
-    _, h, w = frames[0].img_shape.flatten()
-
-    for fidx, frame in enumerate(frames):
-        bbox = frame.get_bboxes().clone()
+    boxes = []
+    for i, instance in enumerate(instances):
+        _, h, w = instance.frame.img_shape.flatten()
+        bbox = instance.bbox.clone()
         bbox[:, :, [0, 2]] /= w
         bbox[:, :, [1, 3]] /= h
-
         boxes.append(bbox)
-        times.append(torch.full((bbox.shape[0],), fidx))
 
     boxes = torch.cat(boxes, dim=0)  # N, n_anchors, 4
-    times = torch.cat(times, dim=0).to(boxes.device)  # N
-    return boxes, times
+
+    return boxes
+
+
+def get_times(
+    ref_instances: list[Instance], query_instances: list[Instance] = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Extract the time indices of each instance relative to the window length.
+
+    Args:
+        ref_instances: Set of instances to query against.
+        query_instances: Set of query instances to look up using the decoder.
+
+    Returns:
+        Tuple of corresponding frame indices, e.g. [0, 0, 1, 1, ..., T, T], for the ref and query instances.
+    """
+    try:
+        ref_inds = torch.concat([instance.frame.frame_id for instance in ref_instances])
+    except RuntimeError as e:
+        print([instance.frame.frame_id.device for instance in ref_instances])
+        raise (e)
+    if query_instances is not None:
+        query_inds = torch.concat(
+            [instance.frame.frame_id for instance in query_instances]
+        )
+    else:
+        query_inds = torch.tensor([], device=ref_inds.device)
+
+    frame_inds = torch.concat([ref_inds, query_inds])
+    window_length = len(frame_inds.unique())
+
+    frame_idx_mapping = {frame_inds.unique()[i].item(): i for i in range(window_length)}
+    ref_t = torch.tensor(
+        [frame_idx_mapping[ind.item()] for ind in ref_inds], device=ref_inds.device
+    )
+
+    query_t = torch.tensor(
+        [frame_idx_mapping[ind.item()] for ind in query_inds], device=ref_inds.device
+    )
+
+    return ref_t, query_t
 
 
 def softmax_asso(asso_output: list[torch.Tensor]) -> list[torch.Tensor]:
diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py
index 7f84222e..4951c3e5 100644
--- a/biogtr/models/transformer.py
+++ b/biogtr/models/transformer.py
@@ -11,10 +11,10 @@
 * added fixed embeddings over boxes
 """
 
-from biogtr.data_structures import Frame
+from biogtr.data_structures import Instance
 from biogtr.models.attention_head import ATTWeightHead
 from biogtr.models.embedding import Embedding
-from biogtr.models.model_utils import get_boxes_times
+from biogtr.models.model_utils import get_boxes, get_times
 from torch import nn
 import copy
 import torch
@@ -140,13 +140,13 @@ def _reset_parameters(self):
             raise (e)
 
     def forward(
-        self, frames: list[Frame], query_frame: int = None
+        self, ref_instances: list[Instance], query_instances: list[Instance] = None
    ) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
         """Execute a forward pass through the transformer and attention head.
 
         Args:
-            frames: A list of Frames (See `biogtr.data_structures.Frame for more info.)
-            query_frame: An integer (k) specifying the frame within the window to be queried.
+            ref_instances: A list of instance objects (See `biogtr.data_structures.Instance` for more info.)
+            query_instances: A set of instances to be used as decoder queries.
 
         Returns:
             asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where:
@@ -156,79 +156,90 @@ def forward(
             embedding_dict: A dictionary containing the "pos" and "temp" embeddings
                 if `self.return_embeddings` is False then they are None.
         """
-        try:
-            reid_features = torch.cat(
-                [frame.get_features() for frame in frames], dim=0
-            ).unsqueeze(0)
-        except Exception as e:
-            print([[f.device for f in frame.get_features()] for frame in frames])
-            raise (e)
-
-        window_length = len(frames)
-        instances_per_frame = [frame.num_detected for frame in frames]
-        total_instances = sum(instances_per_frame)
-        embed_dim = reid_features.shape[-1]
-        embeddings_dict = {"pos": None, "temp": None}
+        ref_features = torch.cat(
+            [instance.features for instance in ref_instances], dim=0
+        ).unsqueeze(0)
+
+        # window_length = len(frames)
+        # instances_per_frame = [frame.num_detected for frame in frames]
+        total_instances = len(ref_instances)
+        embed_dim = ref_features.shape[-1]
+        embeddings_dict = {
+            "ref": {"pos": None, "temp": None},
+            "query": {"pos": None, "temp": None},
+        }
         # print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
-        pred_box, pred_time = get_boxes_times(frames)  # total_instances, 4
-        pred_box = torch.nan_to_num(pred_box, -1.0)
+        ref_boxes = get_boxes(ref_instances)  # total_instances, 4
+        ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
+        ref_times, query_times = get_times(ref_instances, query_instances)
+
+        window_length = len(ref_times.unique())
 
-        temp_emb = self.temp_emb(pred_time / window_length)
+        ref_temp_emb = self.temp_emb(ref_times / window_length)
         if self.return_embedding:
-            embeddings_dict["temp"] = temp_emb
+            embeddings_dict["ref"]["temp"] = ref_temp_emb
 
-        pos_emb = self.pos_emb(pred_box)
+        ref_pos_emb = self.pos_emb(ref_boxes)
         if self.return_embedding:
-            embeddings_dict["pos"] = pos_emb
+            embeddings_dict["ref"]["pos"] = ref_pos_emb
 
-        try:
-            emb = (pos_emb + temp_emb) / 2.0
-        except RuntimeError as e:
-            print(self.pos_emb.features, self.temp_emb.features)
-            print(pos_emb.shape, temp_emb.shape)
-            raise (e)
+        ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0
 
-        emb = emb.view(1, total_instances, embed_dim)
+        ref_emb = ref_emb.view(1, total_instances, embed_dim)
 
-        emb = emb.permute(1, 0, 2)  # (total_instances, batch_size, embed_dim)
+        ref_emb = ref_emb.permute(1, 0, 2)  # (total_instances, batch_size, embed_dim)
 
-        batch_size, total_instances, embed_dim = reid_features.shape
+        batch_size, total_instances, embed_dim = ref_features.shape
 
-        reid_features = reid_features.permute(
+        ref_features = ref_features.permute(
             1, 0, 2
         )  # (total_instances, batch_size, embed_dim)
 
-        encoder_queries = reid_features
+        encoder_queries = ref_features
 
         encoder_features = self.encoder(
-            encoder_queries, pos_emb=emb
+            encoder_queries, pos_emb=ref_emb
         )  # (total_instances, batch_size, embed_dim)
 
         n_query = total_instances
 
-        decoder_queries = reid_features
-        decoder_query_emb = emb
+        query_features = ref_features
+        query_pos_emb = ref_pos_emb
+        query_temp_emb = ref_temp_emb
+        query_emb = ref_emb
 
-        if query_frame is not None:
-            query_inds = [
-                x
-                for x in range(
-                    sum(instances_per_frame[:query_frame]),
-                    sum(instances_per_frame[: query_frame + 1]),
-                )
-            ]
-            n_query = len(query_inds)
+        if query_instances is not None:
+            n_query = len(query_instances)
+
+            query_features = torch.cat(
+                [instance.features for instance in query_instances], dim=0
+            ).unsqueeze(0)
+
+            query_features = query_features.permute(
+                1, 0, 2
+            )  # (n_query, batch_size, embed_dim)
 
-            decoder_queries = decoder_queries[
-                query_inds
-            ]  # decoder_queries: (n_query, batch_size, embed_dim)
-            decoder_query_emb = decoder_query_emb[query_inds]
+            query_boxes = get_boxes(query_instances)
+
+            query_temp_emb = self.temp_emb(query_times / window_length)
+            if self.return_embedding:
+                embeddings_dict["query"]["temp"] = query_temp_emb
+
+            query_pos_emb = self.pos_emb(query_boxes)
+            if self.return_embedding:
+                embeddings_dict["query"]["pos"] = query_pos_emb
+
+            query_emb = (query_pos_emb + query_temp_emb) / 2.0
+
+            query_emb = query_emb.view(1, n_query, embed_dim)
+
+            query_emb = query_emb.permute(1, 0, 2)  # (n_query, batch_size, embed_dim)
 
         decoder_features = self.decoder(
-            decoder_queries,
+            query_features,
             encoder_features,
-            pos_emb=emb,
-            query_pos_emb=decoder_query_emb,
+            ref_pos_emb=ref_emb,
+            query_pos_emb=query_emb,
         )  # (L, n_query, batch_size, embed_dim)
 
         decoder_features = decoder_features.transpose(
@@ -372,7 +383,7 @@ def forward(
         self,
         decoder_queries: torch.Tensor,
         encoder_features: torch.Tensor,
-        pos_emb: torch.Tensor = None,
+        ref_pos_emb: torch.Tensor = None,
         query_pos_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         """Execute forward pass of decoder layer.
@@ -381,7 +392,7 @@ def forward(
             decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim).
             encoder_features: Output from encoder, that decoder uses to attend to relevant
                 parts of input sequence (total_instances, batch_size, embed_dim)
-            pos_emb: The input positional embedding tensor of shape (n_query, embed_dim).
+            ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
             query_pos_emb: The target positional embedding of shape (n_query, embed_dim)
 
         Returns:
@@ -389,11 +400,11 @@
         """
         if query_pos_emb is None:
             query_pos_emb = torch.zeros_like(decoder_queries)
-        if pos_emb is None:
-            pos_emb = torch.zeros_like(encoder_features)
+        if ref_pos_emb is None:
+            ref_pos_emb = torch.zeros_like(encoder_features)
 
         decoder_queries = decoder_queries + query_pos_emb
-        encoder_features = encoder_features + pos_emb
+        encoder_features = encoder_features + ref_pos_emb
 
         if self.decoder_self_attn:
             self_attn_features = self.self_attn(
@@ -496,7 +507,7 @@ def forward(
         self,
         decoder_queries: torch.Tensor,
         encoder_features: torch.Tensor,
-        pos_emb: torch.Tensor = None,
+        ref_pos_emb: torch.Tensor = None,
         query_pos_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         """Execute a forward pass of the decoder block.
@@ -505,7 +516,7 @@ def forward(
             decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim).
             encoder_features: Output from encoder, that decoder uses to attend to relevant
                 parts of input sequence (total_instances, batch_size, embed_dim)
-            pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
+            ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
             query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim)
 
         Returns:
@@ -519,7 +530,7 @@
             decoder_features = layer(
                 decoder_features,
                 encoder_features,
-                pos_emb=pos_emb,
+                ref_pos_emb=ref_pos_emb,
                 query_pos_emb=query_pos_emb,
             )
             if self.return_intermediate:
diff --git a/biogtr/training/losses.py b/biogtr/training/losses.py
index 7dfc15f2..b6f1d5e8 100644
--- a/biogtr/training/losses.py
+++ b/biogtr/training/losses.py
@@ -1,7 +1,7 @@
 """Module containing different loss functions to be optimized."""
 
 from biogtr.data_structures import Frame
-from biogtr.models.model_utils import get_boxes_times
+from biogtr.models.model_utils import get_boxes, get_times
 from torch import nn
 from typing import List, Tuple
 import torch
@@ -49,9 +49,11 @@ def forward(
         # get number of detected objects and ground truth ids
         n_t = [frame.num_detected for frame in frames]
         target_inst_id = torch.cat([frame.get_gt_track_ids() for frame in frames])
+        instances = [instance for frame in frames for instance in frame.instances]
 
         # for now set equal since detections are fixed
-        pred_box, pred_time = get_boxes_times(frames)
+        pred_box = get_boxes(instances)
+        pred_time, _ = get_times(instances)
         pred_box = torch.nanmean(pred_box, axis=1)
         target_box, target_time = pred_box, pred_time
diff --git a/tests/test_models.py b/tests/test_models.py
index c19f8187..08412e49 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -382,7 +382,10 @@ def test_transformer_decoder():
     pos_emb = query_pos_emb = torch.ones_like(encoder_features)
 
     decoder_features = transformer_decoder(
-        decoder_queries, encoder_features, pos_emb=pos_emb, query_pos_emb=query_pos_emb
+        decoder_queries,
+        encoder_features,
+        ref_pos_emb=pos_emb,
+        query_pos_emb=query_pos_emb,
     )
 
     assert decoder_features.size() == decoder_queries.size()
@@ -411,7 +414,8 @@ def test_transformer_basic():
             Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances)
         )
 
-    asso_preds, _ = transformer(frames)
+    instances = [instance for frame in frames for instance in frame.instances]
+    asso_preds, _ = transformer(instances)
 
     assert asso_preds[0].size() == (num_detected * num_frames,) * 2
@@ -435,6 +439,8 @@ def test_transformer_embedding():
         )
         frames.append(Frame(video_id=0, frame_id=i, instances=instances))
 
+    instances = [instance for frame in frames for instance in frame.instances]
+
     embedding_meta = {
         "pos": {"mode": "learned", "emb_num": 16, "normalize": True},
         "temp": {"mode": "learned", "emb_num": 16, "normalize": True},
@@ -451,11 +457,11 @@
     assert transformer.pos_emb.mode == "learned"
     assert transformer.temp_emb.mode == "learned"
 
-    asso_preds, embeddings = transformer(frames)
+    asso_preds, embeddings = transformer(instances)
 
     assert asso_preds[0].size() == (num_detected * num_frames,) * 2
 
-    for emb_type, embedding in embeddings.items():
+    for emb_type, embedding in embeddings["ref"].items():
         assert embedding.size() == (
             num_detected * num_frames,
             feats,
@@ -503,12 +509,12 @@ def test_tracking_transformer():
         embedding_meta=embedding_meta,
         return_embedding=True,
     )
-
-    asso_preds, embeddings = tracking_transformer(frames)
+    instances = [instance for frame in frames for instance in frame.instances]
+    asso_preds, embeddings = tracking_transformer(instances)
 
     assert asso_preds[0].size() == (num_detected * num_frames,) * 2
 
-    for emb_type, embedding in embeddings.items():
+    for emb_type, embedding in embeddings["ref"].items():
         assert embedding.size() == (
             num_detected * num_frames,
             feats,
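For reference, a minimal usage sketch of the refactored Instance-based forward pass, adapted from `test_transformer_basic` above. Only the flat-instance call pattern is taken from the updated tests; the `Transformer` and `Instance` constructor arguments shown here are not part of this patch and are assumptions.

```python
# Hypothetical usage sketch (not part of this patch).
import torch

from biogtr.data_structures import Frame, Instance
from biogtr.models.transformer import Transformer

feats = 256
num_frames, num_detected = 4, 3
img_shape = (1, 128, 128)

# Assumed constructor arguments; see biogtr.models.transformer.Transformer for the full signature.
transformer = Transformer(d_model=feats, num_encoder_layers=1, num_decoder_layers=1)

frames = []
for i in range(num_frames):
    # Assumed Instance kwargs: bbox is (1, n_anchors, 4), features is (1, d_model).
    instances = [
        Instance(bbox=torch.rand(1, 1, 4), features=torch.rand(1, feats))
        for _ in range(num_detected)
    ]
    # Adding instances to a Frame sets each Instance.frame back reference,
    # which get_boxes() and get_times() now rely on.
    frames.append(
        Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances)
    )

# The model now consumes a flat list of instances instead of a list of frames.
ref_instances = [instance for frame in frames for instance in frame.instances]
asso_preds, _ = transformer(ref_instances)

assert asso_preds[0].size() == (num_detected * num_frames,) * 2
```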