From 041d0a4f0abe76f48b11c697450d10c4fe5a88dc Mon Sep 17 00:00:00 2001 From: Aaditya Prasad <78439225+aaprasad@users.noreply.github.com> Date: Thu, 25 Apr 2024 18:18:18 -0700 Subject: [PATCH] Refactor embeddings (#29) Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- biogtr/models/embedding.py | 274 +++++++++++-------- biogtr/models/global_tracking_transformer.py | 37 +-- biogtr/models/transformer.py | 134 +++------ biogtr/training/configs/base.yaml | 14 +- biogtr/training/configs/params.yaml | 9 +- tests/configs/base.yaml | 13 +- tests/configs/params.yaml | 12 +- tests/test_models.py | 202 ++++++++------ tests/test_training.py | 3 +- 9 files changed, 352 insertions(+), 346 deletions(-) diff --git a/biogtr/models/embedding.py b/biogtr/models/embedding.py index 364d4c8f..95a555dd 100644 --- a/biogtr/models/embedding.py +++ b/biogtr/models/embedding.py @@ -1,6 +1,6 @@ """Module containing different position and temporal embeddings.""" -from typing import Tuple +from typing import Tuple, Optional import math import torch @@ -13,12 +13,116 @@ class Embedding(torch.nn.Module): Used for both learned and fixed embeddings. """ - def __init__(self): - """Initialize embeddings.""" + EMB_TYPES = { + "temp": {}, + "pos": {"over_boxes"}, + "off": {}, + None: {}, + } # dict of valid args:keyword params + EMB_MODES = { + "fixed": {"temperature", "scale", "normalize"}, + "learned": {"emb_num"}, + "off": {}, + } # dict of valid args:keyword params + + def __init__( + self, + emb_type: str, + mode: str, + features: int, + emb_num: Optional[int] = 16, + over_boxes: Optional[bool] = True, + temperature: Optional[int] = 10000, + normalize: Optional[bool] = False, + scale: Optional[float] = None, + ): + """Initialize embeddings. + + Args: + emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}` + mode: The mode or function used to map positions to vector embeddings. + Must be one of `{"fixed", "learned", "off"}` + features: The embedding dimensions. Must match the dimension of the + input vectors for the transformer model. + emb_num: the number of embeddings in the `self.lookup` table (Only used in learned embeddings). + over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh). + temperature: the temperature constant to be used when computing the sinusoidal position embedding + normalize: whether or not to normalize the positions (Only used in fixed embeddings). + scale: factor by which to scale the positions after normalizing (Only used in fixed embeddings). 
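For concreteness, here is a minimal usage sketch of the constructor arguments listed above (the module path `biogtr.models.embedding` is taken from this diff; shapes follow the `forward` docstring below, and the values are illustrative only):

    import torch
    from biogtr.models.embedding import Embedding

    N, d_model = 10, 1024
    boxes = torch.rand(N, 4)      # [y1, x1, y2, x2] per instance
    times = torch.arange(N) / N   # normalized temporal positions

    pos_emb = Embedding(emb_type="pos", mode="learned", features=d_model, emb_num=16)
    temp_emb = Embedding(emb_type="temp", mode="learned", features=d_model, emb_num=16)

    assert pos_emb(boxes).shape == (N, d_model)
    assert temp_emb(times).shape == (N, d_model)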
+ """ + self._check_init_args(emb_type, mode) + super().__init__() - # empty init for flexibility - self.pos_lookup = None - self.temp_lookup = None + + self.emb_type = emb_type + self.mode = mode + self.features = features + self.emb_num = emb_num + self.over_boxes = over_boxes + self.temperature = temperature + self.normalize = normalize + self.scale = scale + if self.normalize and self.scale is None: + self.scale = 2 * math.pi + + self._emb_func = lambda tensor: torch.zeros( + (tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device + ) # turn off embedding by returning zeros + + self.lookup = None + + if self.mode == "learned": + if self.emb_type == "pos": + self.lookup = torch.nn.Embedding(self.emb_num * 4, self.features // 4) + self._emb_func = self._learned_pos_embedding + elif self.emb_type == "temp": + self.lookup = torch.nn.Embedding(self.emb_num, self.features) + self._emb_func = self._learned_temp_embedding + + elif self.mode == "fixed": + if self.emb_type == "pos": + self._emb_func = self._sine_box_embedding + elif self.emb_type == "temp": + pass # TODO Implement fixed sine temporal embedding + + def _check_init_args(self, emb_type: str, mode: str): + """Check whether the correct arguments were passed to initialization. + + Args: + emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", ""}` + mode: The mode or function used to map positions to vector embeddings. + Must be one of `{"fixed", "learned"}` + + Raises: + ValueError: + * if the incorrect `emb_type` or `mode` string are passed + NotImplementedError: if `emb_type` is `temp` and `mode` is `fixed`. + """ + if emb_type.lower() not in self.EMB_TYPES: + raise ValueError( + f"Embedding `emb_type` must be one of {self.EMB_TYPES} not {emb_type}" + ) + + if mode.lower() not in self.EMB_MODES: + raise ValueError( + f"Embedding `mode` must be one of {self.EMB_MODES} not {mode}" + ) + + if mode == "fixed" and emb_type == "temp": + raise NotImplementedError("TODO: Implement Fixed Sinusoidal Temp Embedding") + + def forward(self, seq_positions: torch.Tensor) -> torch.Tensor: + """Get the sequence positional embeddings. + + Args: + seq_positions: + * An `N` x 1 tensor where seq_positions[i] represents the temporal position of instance_i in the sequence. + * An `N` x 4 tensor where seq_positions[i] represents the [y1, x1, y2, x2] spatial locations of instance_i in the sequence. + + Returns: + An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding. + """ + return self._emb_func(seq_positions) def _torch_int_div( self, tensor1: torch.Tensor, tensor2: torch.Tensor @@ -34,46 +138,19 @@ def _torch_int_div( """ return torch.div(tensor1, tensor2, rounding_mode="floor") - def _sine_box_embedding( - self, - boxes, - features: int = 512, - temperature: int = 10000, - scale: float = None, - normalize: bool = False, - **kwargs, - ) -> torch.Tensor: + def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor: """Compute sine positional embeddings for boxes using given parameters. Args: - boxes: the input boxes. - features: number of position features to use. - temperature: frequency factor to control spread of pos embed values. - A higher temp (e.g 10000) gives a larger spread of values - scale: A scale factor to use if normalizing - normalize: Whether to normalize the input before computing embedding + boxes: the input boxes of shape N x 4 or B x N x 4 + where the last dimension is the bbox coords in [y1, x1, y2, x2]. + (Note currently `B=batch_size=1`). 
Returns: torch.Tensor, the sine positional embeddings. """ - # update default parameters with kwargs if available - params = { - "features": features, - "temperature": temperature, - "scale": scale, - "normalize": normalize, - **kwargs, - } - - self.features = params["features"] - self.temperature = params["temperature"] - self.scale = params["scale"] - self.normalize = params["normalize"] - if self.scale is not None and self.normalize is False: raise ValueError("normalize should be True if scale is passed") - if self.scale is None: - self.scale = 2 * math.pi if len(boxes.size()) == 2: boxes = boxes.unsqueeze(0) @@ -81,9 +158,11 @@ def _sine_box_embedding( if self.normalize: boxes = boxes / (boxes[:, -1:] + 1e-6) * self.scale - dim_t = torch.arange(self.features, dtype=torch.float32) + dim_t = torch.arange(self.features // 4, dtype=torch.float32) - dim_t = self.temperature ** (2 * self._torch_int_div(dim_t, 2) / self.features) + dim_t = self.temperature ** ( + 2 * self._torch_int_div(dim_t, 2) / (self.features // 4) + ) # (b, n_t, 4, D//4) pos_emb = boxes[:, :, :, None] / dim_t.to(boxes.device) @@ -97,41 +176,18 @@ def _sine_box_embedding( return pos_emb - def _learned_pos_embedding( - self, - boxes: torch.Tensor, - features: int = 1024, - learn_pos_emb_num: int = 16, - over_boxes: bool = True, - **kwargs, - ) -> torch.Tensor: + def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor: """Compute learned positional embeddings for boxes using given parameters. Args: - boxes: the input boxes. - features: Number of features in attention head. - learn_pos_emb_num: Size of the dictionary of embeddings. - over_boxes: If True, use box dimensions, rather than box offset and shape. + boxes: the input boxes of shape N x 4 or B x N x 4 + where the last dimension is the bbox coords in [y1, x1, y2, x2]. + (Note currently `B=batch_size=1`). Returns: torch.Tensor, the learned positional embeddings. 
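The learned embeddings linearly interpolate between adjacent rows of `self.lookup`; a small numeric sketch of the weighting computed by `_compute_weights` (values are illustrative only):

    import torch

    emb_num = 16
    x = torch.tensor([0.3])                            # a normalized position in [0, 1]
    scaled = x * emb_num                               # ~4.8
    left_ind = scaled.clamp(0, emb_num - 1).long()     # 4
    right_ind = (left_ind + 1).clamp(0, emb_num - 1)   # 5
    left_weight = scaled - left_ind.float()            # ~0.8
    right_weight = 1.0 - left_weight                   # ~0.2
    # the returned embedding is left_weight * lookup[right_ind] + right_weight * lookup[left_ind]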
""" - params = { - "features": features, - "learn_pos_emb_num": learn_pos_emb_num, - "over_boxes": over_boxes, - **kwargs, - } - - self.features = params["features"] - self.learn_pos_emb_num = params["learn_pos_emb_num"] - self.over_boxes = params["over_boxes"] - - if self.pos_lookup is None: - self.pos_lookup = torch.nn.Embedding( - self.learn_pos_emb_num * 4, self.features // 4 - ) - pos_lookup = self.pos_lookup + pos_lookup = self.lookup N = boxes.shape[0] boxes = boxes.view(N, 4) @@ -144,92 +200,70 @@ def _learned_pos_embedding( dim=1, ) - l, r, lw, rw = self._compute_weights(xywh, self.learn_pos_emb_num) + left_ind, right_ind, left_weight, right_weight = self._compute_weights(xywh) - f = pos_lookup.weight.shape[1] + f = pos_lookup.weight.shape[1] # self.features // 4 - pos_emb_table = pos_lookup.weight.view( - self.learn_pos_emb_num, 4, f - ) # T x 4 x (D * 4) + pos_emb_table = pos_lookup.weight.view(self.emb_num, 4, f) # T x 4 x (D * 4) - pos_le = pos_emb_table.gather( - 0, l[:, :, None].to(pos_emb_table.device).expand(N, 4, f) + left_emb = pos_emb_table.gather( + 0, left_ind[:, :, None].to(pos_emb_table.device).expand(N, 4, f) ) # N x 4 x d - pos_re = pos_emb_table.gather( - 0, r[:, :, None].to(pos_emb_table.device).expand(N, 4, f) + right_emb = pos_emb_table.gather( + 0, right_ind[:, :, None].to(pos_emb_table.device).expand(N, 4, f) ) # N x 4 x d - pos_emb = lw[:, :, None] * pos_re.to(lw.device) + rw[:, :, None] * pos_le.to( - rw.device - ) + pos_emb = left_weight[:, :, None] * right_emb.to( + left_weight.device + ) + right_weight[:, :, None] * left_emb.to(right_weight.device) - pos_emb = pos_emb.view(N, 4 * f) + pos_emb = pos_emb.view(N, self.features) return pos_emb - def _learned_temp_embedding( - self, - times: torch.Tensor, - features: int = 1024, - learn_temp_emb_num: int = 16, - **kwargs, - ) -> torch.Tensor: + def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor: """Compute learned temporal embeddings for times using given parameters. Args: - times: the input times. - features: Number of features in attention head. - learn_temp_emb_num: Size of the dictionary of embeddings. + times: the input times of shape (N,) or (N,1) where N = (sum(instances_per_frame)) + which is the frame index of the instance relative + to the batch size + (e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2,..., B, B, ...B])`). Returns: torch.Tensor, the learned temporal embeddings. 
""" - params = { - "features": features, - "learn_temp_emb_num": learn_temp_emb_num, - **kwargs, - } - - self.features = params["features"] - self.learn_temp_emb_num = params["learn_temp_emb_num"] - - if self.temp_lookup is None: - self.temp_lookup = torch.nn.Embedding( - self.learn_temp_emb_num, self.features - ) - - temp_lookup = self.temp_lookup + temp_lookup = self.lookup N = times.shape[0] - l, r, lw, rw = self._compute_weights(times, self.learn_temp_emb_num) + left_ind, right_ind, left_weight, right_weight = self._compute_weights(times) - le = temp_lookup.weight[l.to(temp_lookup.weight.device)] # T x D --> N x D - re = temp_lookup.weight[r.to(temp_lookup.weight.device)] + left_emb = temp_lookup.weight[ + left_ind.to(temp_lookup.weight.device) + ] # T x D --> N x D + right_emb = temp_lookup.weight[right_ind.to(temp_lookup.weight.device)] - temp_emb = lw[:, None] * re.to(lw.device) + rw[:, None] * le.to(rw.device) + temp_emb = left_weight[:, None] * right_emb.to( + left_weight.device + ) + right_weight[:, None] * left_emb.to(right_weight.device) return temp_emb.view(N, self.features) - def _compute_weights( - self, data: torch.Tensor, learn_emb_num: int = 16 - ) -> Tuple[torch.Tensor, ...]: + def _compute_weights(self, data: torch.Tensor) -> Tuple[torch.Tensor, ...]: """Compute left and right learned embedding weights. Args: data: the input data (e.g boxes or times). - learn_temp_emb_num: Size of the dictionary of embeddings. Returns: A torch.Tensor for each of the left/right indices and weights, respectively """ - data = data * learn_emb_num + data = data * self.emb_num - left_index = data.clamp(min=0, max=learn_emb_num - 1).long() # N x 4 - right_index = ( - (left_index + 1).clamp(min=0, max=learn_emb_num - 1).long() - ) # N x 4 + left_ind = data.clamp(min=0, max=self.emb_num - 1).long() # N x 4 + right_ind = (left_ind + 1).clamp(min=0, max=self.emb_num - 1).long() # N x 4 - left_weight = data - left_index.float() # N x 4 + left_weight = data - left_ind.float() # N x 4 right_weight = 1.0 - left_weight - return left_index, right_index, left_weight, right_weight + return left_ind, right_ind, left_weight, right_weight diff --git a/biogtr/models/global_tracking_transformer.py b/biogtr/models/global_tracking_transformer.py index 1766a851..4f9d99b5 100644 --- a/biogtr/models/global_tracking_transformer.py +++ b/biogtr/models/global_tracking_transformer.py @@ -19,11 +19,9 @@ def __init__( nhead: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, - dim_feedforward: int = 1024, dropout: int = 0.1, activation: str = "relu", return_intermediate_dec: bool = False, - feature_dim_attn_head: int = 1024, norm: bool = False, num_layers_attn_head: int = 2, dropout_attn_head: int = 0.1, @@ -42,40 +40,23 @@ def __init__( nhead: The number of heads in the transfomer encoder/decoder. num_encoder_layers: The number of encoder-layers in the encoder. num_decoder_layers: The number of decoder-layers in the decoder. - dim_feedforward: The dimension of the feedforward layers of the transformer. dropout: Dropout value applied to the output of transformer layers. activation: Activation function to use. return_intermediate_dec: Return intermediate layers from decoder. norm: If True, normalize output of encoder and decoder. - feature_dim_attn_head: The number of features in the attention head. num_layers_attn_head: The number of layers in the attention head. dropout_attn_head: Dropout value for the attention_head. embedding_meta: Metadata for positional embeddings. See below. 
             return_embedding: Whether to return the positional embeddings
             decoder_self_attn: If True, use decoder self attention.
 
-            embedding_meta: By default this will be an empty dict and indicate
-                that no positional embeddings should be used. To use positional
-                embeddings, a dict should be passed with the type of embedding to
-                use. Valid options are:
-                    * learned_pos: only learned position embeddings
-                    * learned_temp: only learned temporal embeddings
-                    * learned_pos_temp: learned position and temporal embeddings
-                    * fixed_pos: fixed sine position embeddings
-                    * fixed_pos_temp: fixed sine position and learned temporal embeddings
-                You can additionally pass kwargs to override the default
-                embedding values (see embedding.py function methods for relevant
-                embedding parameters). Example:
-                    embedding_meta = {
-                        'embedding_type': 'learned_pos_temp',
-                        'kwargs': {
-                            'learn_pos_emb_num': 16,
-                            'learn_temp_emb_num': 16,
-                            'over_boxes': False
-                        }
-                    }
-                Note: Embedding features are handled directly in the forward
-                pass for each case. Overriding the features through kwargs will
-                likely throw errors due to incorrect tensor shapes.
+
+            More details on `embedding_meta`:
+                By default this will be an empty dict and indicate
+                that no positional embeddings should be used. To use positional embeddings,
+                pass in a dictionary containing "pos" and "temp" keys with subdictionaries of the relevant parameters, e.g.:
+                {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
+                "temp": {'mode': 'learned', 'emb_num': 16}}. (see `biogtr.models.embedding.Embedding.EMB_TYPES`
+                and `biogtr.models.embedding.Embedding.EMB_MODES` for the valid embedding parameters).
         """
         super().__init__()
@@ -86,11 +67,9 @@ def __init__(
             nhead=nhead,
             num_encoder_layers=num_encoder_layers,
             num_decoder_layers=num_decoder_layers,
-            dim_feedforward=dim_feedforward,
             dropout=dropout,
             activation=activation,
             return_intermediate_dec=return_intermediate_dec,
-            feature_dim_attn_head=feature_dim_attn_head,
             norm=norm,
             num_layers_attn_head=num_layers_attn_head,
             dropout_attn_head=dropout_attn_head,
diff --git a/biogtr/models/transformer.py b/biogtr/models/transformer.py
index dec1fc3f..f607c239 100644
--- a/biogtr/models/transformer.py
+++ b/biogtr/models/transformer.py
@@ -33,12 +33,10 @@ def __init__(
         nhead: int = 8,
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
-        dim_feedforward: int = 1024,
         dropout: float = 0.1,
         activation: str = "relu",
         return_intermediate_dec: bool = False,
         norm: bool = False,
-        feature_dim_attn_head: int = 1024,
         num_layers_attn_head: int = 2,
         dropout_attn_head: float = 0.1,
         embedding_meta: dict = None,
@@ -52,72 +50,47 @@ def __init__(
             nhead: The number of heads in the transfomer encoder/decoder.
             num_encoder_layers: The number of encoder-layers in the encoder.
             num_decoder_layers: The number of decoder-layers in the decoder.
-            dim_feedforward: The dimension of the feedforward layers of the transformer.
             dropout: Dropout value applied to the output of transformer layers.
             activation: Activation function to use.
             return_intermediate_dec: Return intermediate layers from decoder.
             norm: If True, normalize output of encoder and decoder.
-            feature_dim_attn_head: The number of features in the attention head.
             num_layers_attn_head: The number of layers in the attention head.
             dropout_attn_head: Dropout value for the attention_head.
             embedding_meta: Metadata for positional embeddings. See below.
             return_embedding: Whether to return the positional embeddings
             decoder_self_attn: If True, use decoder self attention.
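For reference, a minimal sketch of passing the new-format `embedding_meta` to the transformer (this mirrors `test_transformer_embedding` further down in this patch; the import path follows this diff and the values are illustrative):

    from biogtr.models.transformer import Transformer

    embedding_meta = {
        "pos": {"mode": "learned", "emb_num": 16, "over_boxes": False},
        "temp": {"mode": "learned", "emb_num": 16},
    }
    transformer = Transformer(
        d_model=256,
        num_encoder_layers=1,
        num_decoder_layers=1,
        embedding_meta=embedding_meta,
        return_embedding=True,
    )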
-            embedding_meta: By default this will be an empty dict and indicate
-                that no positional embeddings should be used. To use positional
-                embeddings, a dict should be passed with the type of embedding to
-                use. Valid options are:
-                    * learned_pos: only learned position embeddings
-                    * learned_temp: only learned temporal embeddings
-                    * learned_pos_temp: learned position and temporal embeddings
-                    * fixed_pos: fixed sine position embeddings
-                    * fixed_pos_temp: fixed sine position and learned temporal embeddings
-                You can additionally pass kwargs to override the default
-                embedding values (see embedding.py function methods for relevant
-                embedding parameters). Example:
-
-                    embedding_meta = {
-                        'embedding_type': 'learned_pos_temp',
-                        'kwargs': {
-                            'learn_pos_emb_num': 16,
-                            'learn_temp_emb_num': 16,
-                            'over_boxes': False
-                        }
-                    }
-                Note: Embedding features are handled directly in the forward
-                pass for each case. Overriding the features through kwargs will
-                likely throw errors due to incorrect tensor shapes.
+            More details on `embedding_meta`:
+                By default this will be an empty dict and indicate
+                that no positional embeddings should be used. To use positional embeddings,
+                pass in a dictionary containing "pos" and "temp" keys with subdictionaries of the relevant parameters, e.g.:
+                {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
+                "temp": {'mode': 'learned', 'emb_num': 16}}. (see `biogtr.models.embedding.Embedding.EMB_TYPES`
+                and `biogtr.models.embedding.Embedding.EMB_MODES` for the valid embedding parameters).
         """
         super().__init__()
 
-        self.d_model = d_model
+        self.d_model = dim_feedforward = feature_dim_attn_head = d_model
 
         self.embedding_meta = embedding_meta
         self.return_embedding = return_embedding
 
-        if self.embedding_meta:
-            key = "embedding_type"
-
-            embedding_types = [
-                "learned_pos",
-                "learned_temp",
-                "learned_pos_temp",
-                "fixed_pos",
-                "fixed_pos_temp",
-            ]
-
-            assert (
-                key in self.embedding_meta
-            ), f"Please provide an embedding type, valid options are {embedding_types}"
-
-            provided_type = self.embedding_meta[key]
+        self.pos_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
+        self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
 
-            assert (
-                provided_type in embedding_types
-            ), f"{provided_type} is invalid.
Please choose a valid type from {embedding_types}" - - self.embedding = Embedding() + if self.embedding_meta: + if "pos" in self.embedding_meta: + pos_emb_cfg = self.embedding_meta["pos"] + if pos_emb_cfg: + self.pos_emb = Embedding( + emb_type="pos", features=self.d_model, **pos_emb_cfg + ) + if "temp" in self.embedding_meta: + temp_emb_cfg = self.embedding_meta["temp"] + if temp_emb_cfg: + self.temp_emb = Embedding( + emb_type="temp", features=self.d_model, **temp_emb_cfg + ) # Transformer Encoder encoder_layer = TransformerEncoderLayer( @@ -189,40 +162,22 @@ def forward(self, frames: list[Frame], query_frame: int = None): embed_dim = reid_features.shape[-1] # print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}') - if self.embedding_meta: - kwargs = self.embedding_meta.get("kwargs", {}) + pred_box, pred_time = get_boxes_times(frames) # total_instances x 4 - pred_box, pred_time = get_boxes_times(frames) # total_instances x 4 + temp_emb = self.temp_emb(pred_time / window_length) - embedding_type = self.embedding_meta["embedding_type"] + pos_emb = self.pos_emb(pred_box) - if "temp" in embedding_type: - temp_emb = self.embedding._learned_temp_embedding( - pred_time / window_length, features=self.d_model, **kwargs - ) - - pos_emb = temp_emb - - if "learned" in embedding_type: - if "pos" in embedding_type: - pos_emb = self.embedding._learned_pos_embedding( - pred_box, features=self.d_model, **kwargs - ) - - else: - pos_emb = self.embedding._sine_box_embedding( - pred_box, features=self.d_model // 4, **kwargs - ) + try: + emb = (pos_emb + temp_emb) / 2.0 + except RuntimeError as e: + print(self.pos_emb.features, self.temp_emb.features) + print(pos_emb.shape, temp_emb.shape) + raise (e) - if "temp" in embedding_type and embedding_type != "learned_temp": - pos_emb = (pos_emb + temp_emb) / 2.0 + emb = emb.view(1, total_instances, embed_dim) - pos_emb = pos_emb.view(1, total_instances, embed_dim) - pos_emb = pos_emb.permute( - 1, 0, 2 - ) # (total_instances, batch_size, embed_dim) - else: - pos_emb = None + emb = emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim) query_inds = None n_query = total_instances @@ -242,23 +197,18 @@ def forward(self, frames: list[Frame], query_frame: int = None): ) # (total_instances x batch_size x embed_dim) memory = self.encoder( - reid_features, pos_emb=pos_emb + reid_features, pos_emb=emb ) # (total_instances, batch_size, embed_dim) - if query_inds is not None: - tgt = reid_features[query_inds] - if pos_emb is not None: - tgt_pos_emb = pos_emb[query_inds] - else: - tgt_pos_emb = pos_emb - else: - tgt = reid_features - tgt_pos_emb = pos_emb + tgt = reid_features + tgt_emb = emb - # tgt: (n_query, batch_size, embed_dim) + if query_inds is not None: + tgt = tgt[query_inds] # tgt: (n_query, batch_size, embed_dim) + tgt_emb = tgt_emb[query_inds] hs = self.decoder( - tgt, memory, pos_emb=pos_emb, tgt_pos_emb=tgt_pos_emb + tgt, memory, pos_emb=emb, tgt_pos_emb=tgt_emb ) # (L, n_query, batch_size, embed_dim) feats = hs.transpose(1, 2) # # (L, batch_size, n_query, embed_dim) @@ -273,7 +223,7 @@ def forward(self, frames: list[Frame], query_frame: int = None): asso_output.append(self.attn_head(x, memory).view(n_query, total_instances)) # (L=1, n_query, total_instances) - return (asso_output, pos_emb) if self.return_embedding else (asso_output, None) + return (asso_output, emb) if self.return_embedding else (asso_output, None) class TransformerEncoder(nn.Module): diff --git a/biogtr/training/configs/base.yaml 
b/biogtr/training/configs/base.yaml index f7069f40..4b97cdc5 100644 --- a/biogtr/training/configs/base.yaml +++ b/biogtr/training/configs/base.yaml @@ -6,20 +6,20 @@ model: nhead: 8 num_encoder_layers: 1 num_decoder_layers: 1 - dim_feedforward: 1024 dropout: 0.1 activation: "relu" return_intermediate_dec: False - feature_dim_attn_head: 1024 norm: False num_layers_attn_head: 2 dropout_attn_head: 0.1 embedding_meta: - embedding_type: 'learned_pos_temp' - kwargs: - learn_pos_emb_num: 16 - learn_temp_emb_num: 16 - over_boxes: False + pos: + mode: "learned" + emb_num: 16 + over_boxes: false + temp: + mode: "learned" + emb_num: 16 return_embedding: False decoder_self_attn: False diff --git a/biogtr/training/configs/params.yaml b/biogtr/training/configs/params.yaml index e6b16946..41f25613 100644 --- a/biogtr/training/configs/params.yaml +++ b/biogtr/training/configs/params.yaml @@ -2,11 +2,12 @@ model: num_encoder_layers: 2 num_decoder_layers: 2 embedding_meta: - embedding_type: 'learned_pos' - kwargs: - learn_pos_emb_num: 16 + pos: + mode: learned + emb_num: 16 over_boxes: True - + temp: + mode: "off" dataset: train_dataset: slp_files: ['190612_110405_wt_18159111_rig2.2@11730.slp'] diff --git a/tests/configs/base.yaml b/tests/configs/base.yaml index f8cc8429..a29945c7 100644 --- a/tests/configs/base.yaml +++ b/tests/configs/base.yaml @@ -6,20 +6,21 @@ model: nhead: 8 num_encoder_layers: 1 num_decoder_layers: 1 - dim_feedforward: 512 dropout: 0.1 activation: "relu" return_intermediate_dec: False - feature_dim_attn_head: 512 norm: False num_layers_attn_head: 2 dropout_attn_head: 0.1 embedding_meta: - embedding_type: 'learned_pos_temp' - kwargs: - learn_pos_emb_num: 16 - learn_temp_emb_num: 16 + pos: + mode: 'learned' + emb_num: 16 over_boxes: False + temp: + mode: 'learned' + emb_num: 16 + return_embedding: False decoder_self_attn: False diff --git a/tests/configs/params.yaml b/tests/configs/params.yaml index cdf62ce2..1f2fbfdd 100644 --- a/tests/configs/params.yaml +++ b/tests/configs/params.yaml @@ -2,10 +2,13 @@ model: num_encoder_layers: 2 num_decoder_layers: 2 embedding_meta: - embedding_type: 'learned_pos' - kwargs: - learn_pos_num: 16, - over_boxes: True + embedding_type: + pos: + mode: 'learned' + emb_num: 16 + over_boxes: True + temp: + mode: "off" dataset: train_dataset: @@ -27,3 +30,4 @@ trainer: limit_test_batches: 1 limit_val_batches: 1 max_epochs: 1 + enable_checkpointing: true diff --git a/tests/test_models.py b/tests/test_models.py index ceae0bc5..0e8c3c29 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -61,9 +61,50 @@ def test_encoder(): assert output.shape == (b, features) +def test_embedding_validity(): + """Test embedding usage.""" + + # this would throw assertion since embedding should be "pos" + with pytest.raises(Exception): + _ = Embedding(emb_type="position", mode="learned", features=128) + with pytest.raises(Exception): + _ = Embedding(emb_type="position", mode="fixed", features=128) + + with pytest.raises(Exception): + _ = Embedding(emb_type="temporal", mode="learned", features=128) + with pytest.raises(Exception): + _ = Embedding(emb_type="position", mode="fixed", features=128) + + with pytest.raises(Exception): + _ = Embedding(emb_type="pos", mode="learn", features=128) + with pytest.raises(Exception): + _ = Embedding(emb_type="temp", mode="learn", features=128) + + with pytest.raises(Exception): + _ = Embedding(emb_type="pos", mode="fix", features=128) + with pytest.raises(Exception): + _ = Embedding(emb_type="temp", mode="fix", features=128) + + with 
pytest.raises(Exception): + _ = Embedding(emb_type="position", mode="learn", features=128) + with pytest.raises(Exception): + _ = Embedding(emb_type="temporal", mode="learn", features=128) + with pytest.raises(Exception): + _ = Embedding(emb_type="position", mode="fix", features=128) + with pytest.raises(Exception): + _ = Embedding(emb_type="temporal", mode="learn", features=128) + + with pytest.raises(Exception): + _ = Embedding(emb_type="temp", mode="fixed", features=128) + + _ = Embedding(emb_type="temp", mode="learned", features=128) + _ = Embedding(emb_type="pos", mode="learned", features=128) + + _ = Embedding(emb_type="pos", mode="learned", features=128) + + def test_embedding(): """Test embedding logic.""" - emb = Embedding() frames = 32 objects = 10 @@ -74,26 +115,70 @@ def test_embedding(): boxes = torch.rand(size=(N, 4)) times = torch.rand(size=(N,)) - sine_emb = emb._sine_box_embedding( - boxes, features=d_model // 4, temperature=objects, normalize=True, scale=10 + pos_emb = Embedding( + emb_type="pos", + mode="fixed", + features=d_model, + temperature=objects, + normalize=True, + scale=10, ) - learned_pos_emb = emb._learned_pos_embedding( - boxes, features=d_model, learn_pos_emb_num=100 - ) + sine_pos_emb = pos_emb(boxes) - learned_temp_emb = emb._learned_temp_embedding( - times, features=d_model, learn_temp_emb_num=16 - ) + pos_emb = Embedding(emb_type="pos", mode="learned", features=d_model, emb_num=100) + learned_pos_emb = pos_emb(boxes) - assert sine_emb.size() == (N, d_model) + temp_emb = Embedding(emb_type="temp", mode="learned", features=d_model, emb_num=16) + learned_temp_emb = temp_emb(times) + + pos_emb_off = Embedding(emb_type="pos", mode="off", features=d_model) + off_pos_emb = pos_emb_off(boxes) + + temp_emb_off = Embedding(emb_type="temp", mode="off", features=d_model) + off_temp_emb = temp_emb_off(times) + + learned_emb_off = Embedding(emb_type="off", mode="learned", features=d_model) + off_learned_emb_boxes = learned_emb_off(boxes) + off_learned_emb_times = learned_emb_off(times) + + fixed_emb_off = Embedding(emb_type="off", mode="fixed", features=d_model) + off_fixed_emb_boxes = fixed_emb_off(boxes) + off_fixed_emb_times = fixed_emb_off(times) + + off_emb = Embedding(emb_type="off", mode="off", features=d_model) + off_emb_boxes = off_emb(boxes) + off_emb_times = off_emb(times) + + assert sine_pos_emb.size() == (N, d_model) assert learned_pos_emb.size() == (N, d_model) assert learned_temp_emb.size() == (N, d_model) + assert not torch.equal(sine_pos_emb, learned_pos_emb) + assert not torch.equal(sine_pos_emb, learned_temp_emb) + assert not torch.equal(learned_pos_emb, learned_temp_emb) + + assert off_pos_emb.size() == (N, d_model) + assert off_temp_emb.size() == (N, d_model) + assert off_learned_emb_boxes.size() == (N, d_model) + assert off_learned_emb_times.size() == (N, d_model) + assert off_fixed_emb_boxes.size() == (N, d_model) + assert off_fixed_emb_times.size() == (N, d_model) + assert off_emb_boxes.size() == (N, d_model) + assert off_emb_times.size() == (N, d_model) + + assert not off_pos_emb.any() + assert not off_temp_emb.any() + assert not off_learned_emb_boxes.any() + assert not off_learned_emb_times.any() + assert not off_fixed_emb_boxes.any() + assert not off_fixed_emb_times.any() + assert not off_emb_boxes.any() + assert not off_emb_times.any() + def test_embedding_kwargs(): """Test embedding config logic.""" - emb = Embedding() frames = 32 objects = 10 @@ -105,7 +190,7 @@ def test_embedding_kwargs(): # sine embedding - sine_no_args = 
emb._sine_box_embedding(boxes) + sine_no_args = Embedding("pos", "fixed", 128)(boxes) sine_args = { "temperature": objects, @@ -113,31 +198,27 @@ def test_embedding_kwargs(): "normalize": True, } - sine_with_args = emb._sine_box_embedding(boxes, **sine_args) + sine_with_args = Embedding("pos", "fixed", 128, **sine_args)(boxes) assert not torch.equal(sine_no_args, sine_with_args) # learned pos embedding - lp_no_args = emb._learned_pos_embedding(boxes) - - lp_args = {"learn_pos_emb_num": 100, "over_boxes": False} + lp_no_args = Embedding("pos", "learned", 128) - emb = Embedding() - lp_with_args = emb._learned_pos_embedding(boxes, **lp_args) + lp_args = {"emb_num": 100, "over_boxes": False} - assert not torch.equal(lp_no_args, lp_with_args) + lp_with_args = Embedding("pos", "learned", 128, **lp_args) + assert lp_no_args.lookup.weight.shape != lp_with_args.lookup.weight.shape # learned temp embedding - lt_no_args = emb._learned_temp_embedding(times) + lt_no_args = Embedding("temp", "learned", 128) - lt_args = {"learn_temp_emb_num": 100} + lt_args = {"emb_num": 100} - emb = Embedding() - lt_with_args = emb._learned_temp_embedding(times, **lt_args) - - assert not torch.equal(lt_no_args, lt_with_args) + lt_with_args = Embedding("temp", "learned", 128, **lt_args) + assert lt_no_args.lookup.weight.shape != lt_with_args.lookup.weight.shape def test_transformer_encoder(): @@ -202,13 +283,7 @@ def test_transformer_basic(): num_detected = 10 img_shape = (1, 100, 100) - transformer = Transformer( - d_model=feats, - num_encoder_layers=1, - num_decoder_layers=1, - dim_feedforward=feats, - feature_dim_attn_head=feats, - ) + transformer = Transformer(d_model=feats, num_encoder_layers=1, num_decoder_layers=1) frames = [] @@ -220,51 +295,15 @@ def test_transformer_basic(): bbox=torch.rand(size=(1, 4)), features=torch.rand(size=(1, feats)) ) ) - frames.append(Frame(video_id=0, frame_id=i, instances=instances)) + frames.append( + Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances) + ) asso_preds, _ = transformer(frames) assert asso_preds[0].size() == (num_detected * num_frames,) * 2 -def test_transformer_embedding_validity(): - """Test embedding usage.""" - # use lower feats and single layer for efficiency - feats = 256 - - # this would throw assertion since no "embedding_type" key - with pytest.raises(Exception): - _ = Transformer( - d_model=feats, - num_encoder_layers=1, - num_decoder_layers=1, - dim_feedforward=feats, - feature_dim_attn_head=feats, - embedding_meta={"type": "learned_pos"}, - ) - - # this would throw assertion since "embedding_type" value invalid - with pytest.raises(Exception): - _ = Transformer( - d_model=feats, - num_encoder_layers=1, - num_decoder_layers=1, - dim_feedforward=feats, - feature_dim_attn_head=feats, - embedding_meta={"embedding_type": "foo"}, - ) - - # this would succeed - _ = Transformer( - d_model=feats, - num_encoder_layers=1, - num_decoder_layers=1, - dim_feedforward=feats, - feature_dim_attn_head=feats, - embedding_meta={"embedding_type": "learned_pos"}, - ) - - def test_transformer_embedding(): """Test transformer using embedding.""" feats = 256 @@ -285,20 +324,14 @@ def test_transformer_embedding(): frames.append(Frame(video_id=0, frame_id=i, instances=instances)) embedding_meta = { - "embedding_type": "learned_pos_temp", - "kwargs": { - "learn_pos_emb_num": 16, - "learn_temp_emb_num": 16, - "normalize": True, - }, + "pos": {"mode": "learned", "emb_num": 16, "normalize": True}, + "temp": {"mode": "learned", "emb_num": 16, "normalize": True}, } 
transformer = Transformer( d_model=feats, num_encoder_layers=1, num_decoder_layers=1, - dim_feedforward=feats, - feature_dim_attn_head=feats, embedding_meta=embedding_meta, return_embedding=True, ) @@ -331,8 +364,13 @@ def test_tracking_transformer(): ) embedding_meta = { - "embedding_type": "fixed_pos", - "kwargs": {"temperature": num_detected, "scale": num_frames, "normalize": True}, + "pos": { + "mode": "fixed", + "temperature": num_detected, + "scale": num_frames, + "normalize": True, + }, + "temp": None, } cfg = {"resnet18", "ResNet18_Weights.DEFAULT"} @@ -343,8 +381,6 @@ def test_tracking_transformer(): d_model=feats, num_encoder_layers=1, num_decoder_layers=1, - dim_feedforward=feats, - feature_dim_attn_head=feats, embedding_meta=embedding_meta, return_embedding=True, ) diff --git a/tests/test_training.py b/tests/test_training.py index 9951a593..a4947b16 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -11,7 +11,8 @@ from biogtr.config import Config from biogtr.training.train import main -# todo: add named tensor tests +# TODO: add named tensor tests +# TODO: use temp dir and cleanup after tests (https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html) def test_asso_loss():