Refactor embeddings (#29)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
aaprasad and coderabbitai[bot] committed Apr 26, 2024
1 parent 16add88 commit 041d0a4
Showing 9 changed files with 352 additions and 346 deletions.
274 changes: 154 additions & 120 deletions biogtr/models/embedding.py
@@ -1,6 +1,6 @@
 """Module containing different position and temporal embeddings."""

-from typing import Tuple
+from typing import Tuple, Optional
 import math
 import torch

@@ -13,12 +13,116 @@ class Embedding(torch.nn.Module):
     Used for both learned and fixed embeddings.
     """

-    def __init__(self):
-        """Initialize embeddings."""
+    EMB_TYPES = {
+        "temp": {},
+        "pos": {"over_boxes"},
+        "off": {},
+        None: {},
+    }  # dict of valid args:keyword params
+    EMB_MODES = {
+        "fixed": {"temperature", "scale", "normalize"},
+        "learned": {"emb_num"},
+        "off": {},
+    }  # dict of valid args:keyword params
+
+    def __init__(
+        self,
+        emb_type: str,
+        mode: str,
+        features: int,
+        emb_num: Optional[int] = 16,
+        over_boxes: Optional[bool] = True,
+        temperature: Optional[int] = 10000,
+        normalize: Optional[bool] = False,
+        scale: Optional[float] = None,
+    ):
+        """Initialize embeddings.
+
+        Args:
+            emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`.
+            mode: The mode or function used to map positions to vector embeddings.
+                Must be one of `{"fixed", "learned", "off"}`.
+            features: The embedding dimensions. Must match the dimension of the
+                input vectors for the transformer model.
+            emb_num: the number of embeddings in the `self.lookup` table (only used in learned embeddings).
+            over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh).
+            temperature: the temperature constant to be used when computing the sinusoidal position embedding.
+            normalize: whether or not to normalize the positions (only used in fixed embeddings).
+            scale: factor by which to scale the positions after normalizing (only used in fixed embeddings).
+        """
+        self._check_init_args(emb_type, mode)
+
         super().__init__()
+        # empty init for flexibility
+        self.pos_lookup = None
+        self.temp_lookup = None
+
+        self.emb_type = emb_type
+        self.mode = mode
+        self.features = features
+        self.emb_num = emb_num
+        self.over_boxes = over_boxes
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = scale
+        if self.normalize and self.scale is None:
+            self.scale = 2 * math.pi
+
+        self._emb_func = lambda tensor: torch.zeros(
+            (tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device
+        )  # turn off embedding by returning zeros
+
+        self.lookup = None
+
+        if self.mode == "learned":
+            if self.emb_type == "pos":
+                self.lookup = torch.nn.Embedding(self.emb_num * 4, self.features // 4)
+                self._emb_func = self._learned_pos_embedding
+            elif self.emb_type == "temp":
+                self.lookup = torch.nn.Embedding(self.emb_num, self.features)
+                self._emb_func = self._learned_temp_embedding
+
+        elif self.mode == "fixed":
+            if self.emb_type == "pos":
+                self._emb_func = self._sine_box_embedding
+            elif self.emb_type == "temp":
+                pass  # TODO: Implement fixed sine temporal embedding

+    def _check_init_args(self, emb_type: str, mode: str):
+        """Check whether the correct arguments were passed to initialization.
+
+        Args:
+            emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`.
+            mode: The mode or function used to map positions to vector embeddings.
+                Must be one of `{"fixed", "learned", "off"}`.
+
+        Raises:
+            ValueError: if an incorrect `emb_type` or `mode` string is passed.
+            NotImplementedError: if `emb_type` is `temp` and `mode` is `fixed`.
+        """
+        if emb_type.lower() not in self.EMB_TYPES:
+            raise ValueError(
+                f"Embedding `emb_type` must be one of {self.EMB_TYPES} not {emb_type}"
+            )
+
+        if mode.lower() not in self.EMB_MODES:
+            raise ValueError(
+                f"Embedding `mode` must be one of {self.EMB_MODES} not {mode}"
+            )
+
+        if mode == "fixed" and emb_type == "temp":
+            raise NotImplementedError("TODO: Implement Fixed Sinusoidal Temp Embedding")
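As a quick illustration of the validation above, a minimal usage sketch (hypothetical argument values; assumes `Embedding` is importable from `biogtr.models.embedding` as this diff suggests):

from biogtr.models.embedding import Embedding

try:
    Embedding(emb_type="pos", mode="bilinear", features=1024)  # "bilinear" is not a valid mode
except ValueError as err:
    print(err)  # Embedding `mode` must be one of {...} not bilinear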

+    def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
+        """Get the sequence positional embeddings.
+
+        Args:
+            seq_positions:
+                * An `N` x 1 tensor where seq_positions[i] represents the temporal position of instance_i in the sequence.
+                * An `N` x 4 tensor where seq_positions[i] represents the [y1, x1, y2, x2] spatial locations of instance_i in the sequence.
+
+        Returns:
+            An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
+        """
+        return self._emb_func(seq_positions)
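Taken together, `__init__` now selects the embedding function once and `forward` simply applies it. A minimal sketch of the resulting API (toy tensors invented; inputs are assumed normalized to [0, 1], since `_compute_weights` scales by `emb_num` and clamps):

import torch
from biogtr.models.embedding import Embedding

# Learned position embedding over (assumed normalized) box coordinates.
pos_emb = Embedding(emb_type="pos", mode="learned", features=1024, emb_num=16)
boxes = torch.rand(8, 4)  # N x 4 toy boxes as [y1, x1, y2, x2]
print(pos_emb(boxes).shape)  # torch.Size([8, 1024])

# Learned temporal embedding over per-instance frame positions.
temp_emb = Embedding(emb_type="temp", mode="learned", features=1024, emb_num=16)
times = torch.tensor([0.0, 0.0, 0.5, 0.5, 1.0])  # toy normalized frame indices
print(temp_emb(times).shape)  # torch.Size([5, 1024])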

     def _torch_int_div(
         self, tensor1: torch.Tensor, tensor2: torch.Tensor
@@ -34,56 +138,31 @@ def _torch_int_div(
         """
         return torch.div(tensor1, tensor2, rounding_mode="floor")

-    def _sine_box_embedding(
-        self,
-        boxes,
-        features: int = 512,
-        temperature: int = 10000,
-        scale: float = None,
-        normalize: bool = False,
-        **kwargs,
-    ) -> torch.Tensor:
+    def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
         """Compute sine positional embeddings for boxes using given parameters.

         Args:
-            boxes: the input boxes.
-            features: number of position features to use.
-            temperature: frequency factor to control spread of pos embed values.
-                A higher temp (e.g 10000) gives a larger spread of values
-            scale: A scale factor to use if normalizing
-            normalize: Whether to normalize the input before computing embedding
+            boxes: the input boxes of shape N x 4 or B x N x 4
+                where the last dimension is the bbox coords in [y1, x1, y2, x2].
+                (Note: currently `B = batch_size = 1`.)

         Returns:
             torch.Tensor, the sine positional embeddings.
         """
-        # update default parameters with kwargs if available
-        params = {
-            "features": features,
-            "temperature": temperature,
-            "scale": scale,
-            "normalize": normalize,
-            **kwargs,
-        }
-
-        self.features = params["features"]
-        self.temperature = params["temperature"]
-        self.scale = params["scale"]
-        self.normalize = params["normalize"]
-
-        if self.scale is not None and self.normalize is False:
-            raise ValueError("normalize should be True if scale is passed")
-        if self.scale is None:
-            self.scale = 2 * math.pi

         if len(boxes.size()) == 2:
             boxes = boxes.unsqueeze(0)

         if self.normalize:
             boxes = boxes / (boxes[:, -1:] + 1e-6) * self.scale

-        dim_t = torch.arange(self.features, dtype=torch.float32)
+        dim_t = torch.arange(self.features // 4, dtype=torch.float32)

-        dim_t = self.temperature ** (2 * self._torch_int_div(dim_t, 2) / self.features)
+        dim_t = self.temperature ** (
+            2 * self._torch_int_div(dim_t, 2) / (self.features // 4)
+        )

         # (b, n_t, 4, D//4)
         pos_emb = boxes[:, :, :, None] / dim_t.to(boxes.device)
@@ -97,41 +176,18 @@ def _sine_box_embedding(

         return pos_emb
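For reference, the fixed embedding follows the standard DETR-style sinusoidal construction, now with `features // 4` channels per box coordinate so the four coordinates concatenate back to `features`. A standalone sketch of that scheme (the sin/cos interleaving shown here is the usual construction; the corresponding lines are collapsed in this diff, so treat the details as an assumption):

import torch

def toy_sine_box_embedding(boxes: torch.Tensor, features: int = 1024, temperature: float = 10000.0) -> torch.Tensor:
    d = features // 4  # channels per box coordinate
    dim_t = torch.arange(d, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / d)
    emb = boxes[:, :, None] / dim_t  # N x 4 x d
    emb = torch.stack((emb[..., 0::2].sin(), emb[..., 1::2].cos()), dim=-1).flatten(-2)
    return emb.flatten(1)  # N x features

print(toy_sine_box_embedding(torch.rand(8, 4)).shape)  # torch.Size([8, 1024])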

-    def _learned_pos_embedding(
-        self,
-        boxes: torch.Tensor,
-        features: int = 1024,
-        learn_pos_emb_num: int = 16,
-        over_boxes: bool = True,
-        **kwargs,
-    ) -> torch.Tensor:
+    def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
         """Compute learned positional embeddings for boxes using given parameters.

         Args:
-            boxes: the input boxes.
-            features: Number of features in attention head.
-            learn_pos_emb_num: Size of the dictionary of embeddings.
-            over_boxes: If True, use box dimensions, rather than box offset and shape.
+            boxes: the input boxes of shape N x 4 or B x N x 4
+                where the last dimension is the bbox coords in [y1, x1, y2, x2].
+                (Note: currently `B = batch_size = 1`.)

         Returns:
             torch.Tensor, the learned positional embeddings.
         """
-        params = {
-            "features": features,
-            "learn_pos_emb_num": learn_pos_emb_num,
-            "over_boxes": over_boxes,
-            **kwargs,
-        }
-
-        self.features = params["features"]
-        self.learn_pos_emb_num = params["learn_pos_emb_num"]
-        self.over_boxes = params["over_boxes"]
-
-        if self.pos_lookup is None:
-            self.pos_lookup = torch.nn.Embedding(
-                self.learn_pos_emb_num * 4, self.features // 4
-            )
-        pos_lookup = self.pos_lookup
+        pos_lookup = self.lookup

         N = boxes.shape[0]
         boxes = boxes.view(N, 4)
@@ -144,92 +200,70 @@ def _learned_pos_embedding(
             dim=1,
         )

-        l, r, lw, rw = self._compute_weights(xywh, self.learn_pos_emb_num)
+        left_ind, right_ind, left_weight, right_weight = self._compute_weights(xywh)

-        f = pos_lookup.weight.shape[1]
+        f = pos_lookup.weight.shape[1]  # self.features // 4

-        pos_emb_table = pos_lookup.weight.view(
-            self.learn_pos_emb_num, 4, f
-        )  # T x 4 x (D * 4)
+        pos_emb_table = pos_lookup.weight.view(self.emb_num, 4, f)  # T x 4 x (D * 4)

-        pos_le = pos_emb_table.gather(
-            0, l[:, :, None].to(pos_emb_table.device).expand(N, 4, f)
+        left_emb = pos_emb_table.gather(
+            0, left_ind[:, :, None].to(pos_emb_table.device).expand(N, 4, f)
         )  # N x 4 x d
-        pos_re = pos_emb_table.gather(
-            0, r[:, :, None].to(pos_emb_table.device).expand(N, 4, f)
+        right_emb = pos_emb_table.gather(
+            0, right_ind[:, :, None].to(pos_emb_table.device).expand(N, 4, f)
         )  # N x 4 x d
-        pos_emb = lw[:, :, None] * pos_re.to(lw.device) + rw[:, :, None] * pos_le.to(
-            rw.device
-        )
+        pos_emb = left_weight[:, :, None] * right_emb.to(
+            left_weight.device
+        ) + right_weight[:, :, None] * left_emb.to(right_weight.device)

-        pos_emb = pos_emb.view(N, 4 * f)
+        pos_emb = pos_emb.view(N, self.features)

         return pos_emb
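A design note on the lookup built in `__init__`: the single `emb_num * 4` embedding table behaves as four independent per-coordinate tables of width `features // 4`, which is exactly what the `view`/`gather` above exploits. A toy sketch with invented sizes:

import torch

emb_num, features, N = 16, 1024, 8
lookup = torch.nn.Embedding(emb_num * 4, features // 4)
table = lookup.weight.view(emb_num, 4, features // 4)  # one sub-table per box coordinate
ind = torch.randint(0, emb_num, (N, 4))  # toy per-coordinate indices
emb = table.gather(0, ind[:, :, None].expand(N, 4, features // 4))  # N x 4 x (features // 4)
print(emb.reshape(N, features).shape)  # torch.Size([8, 1024])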

-    def _learned_temp_embedding(
-        self,
-        times: torch.Tensor,
-        features: int = 1024,
-        learn_temp_emb_num: int = 16,
-        **kwargs,
-    ) -> torch.Tensor:
+    def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
         """Compute learned temporal embeddings for times using given parameters.

         Args:
-            times: the input times.
-            features: Number of features in attention head.
-            learn_temp_emb_num: Size of the dictionary of embeddings.
+            times: the input times of shape (N,) or (N, 1), where N = sum(instances_per_frame).
+                Each entry is the frame index of the corresponding instance within the batch
+                (e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2, ..., B, B, ..., B])`).

         Returns:
             torch.Tensor, the learned temporal embeddings.
         """
-        params = {
-            "features": features,
-            "learn_temp_emb_num": learn_temp_emb_num,
-            **kwargs,
-        }
-
-        self.features = params["features"]
-        self.learn_temp_emb_num = params["learn_temp_emb_num"]
-
-        if self.temp_lookup is None:
-            self.temp_lookup = torch.nn.Embedding(
-                self.learn_temp_emb_num, self.features
-            )
-
-        temp_lookup = self.temp_lookup
+        temp_lookup = self.lookup
         N = times.shape[0]

-        l, r, lw, rw = self._compute_weights(times, self.learn_temp_emb_num)
+        left_ind, right_ind, left_weight, right_weight = self._compute_weights(times)

-        le = temp_lookup.weight[l.to(temp_lookup.weight.device)]  # T x D --> N x D
-        re = temp_lookup.weight[r.to(temp_lookup.weight.device)]
+        left_emb = temp_lookup.weight[
+            left_ind.to(temp_lookup.weight.device)
+        ]  # T x D --> N x D
+        right_emb = temp_lookup.weight[right_ind.to(temp_lookup.weight.device)]

-        temp_emb = lw[:, None] * re.to(lw.device) + rw[:, None] * le.to(rw.device)
+        temp_emb = left_weight[:, None] * right_emb.to(
+            left_weight.device
+        ) + right_weight[:, None] * left_emb.to(right_weight.device)

         return temp_emb.view(N, self.features)

-    def _compute_weights(
-        self, data: torch.Tensor, learn_emb_num: int = 16
-    ) -> Tuple[torch.Tensor, ...]:
+    def _compute_weights(self, data: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """Compute left and right learned embedding weights.

         Args:
             data: the input data (e.g. boxes or times).
-            learn_temp_emb_num: Size of the dictionary of embeddings.

         Returns:
             A torch.Tensor for each of the left/right indices and weights, respectively.
         """
-        data = data * learn_emb_num
+        data = data * self.emb_num

-        left_index = data.clamp(min=0, max=learn_emb_num - 1).long()  # N x 4
-        right_index = (
-            (left_index + 1).clamp(min=0, max=learn_emb_num - 1).long()
-        )  # N x 4
+        left_ind = data.clamp(min=0, max=self.emb_num - 1).long()  # N x 4
+        right_ind = (left_ind + 1).clamp(min=0, max=self.emb_num - 1).long()  # N x 4

-        left_weight = data - left_index.float()  # N x 4
+        left_weight = data - left_ind.float()  # N x 4

         right_weight = 1.0 - left_weight

-        return left_index, right_index, left_weight, right_weight
+        return left_ind, right_ind, left_weight, right_weight
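Worked through with invented numbers: a normalized input of 0.3 with `self.emb_num = 16` gives `data = 4.8`, hence `left_ind = 4`, `right_ind = 5`, `left_weight = 0.8`, `right_weight = 0.2`. The callers above then blend `0.8 * emb[5] + 0.2 * emb[4]`, i.e. linear interpolation between the two nearest rows of the lookup table:

import torch

emb_num = 16
data = torch.tensor([0.3]) * emb_num                       # tensor([4.8000])
left_ind = data.clamp(min=0, max=emb_num - 1).long()       # tensor([4])
right_ind = (left_ind + 1).clamp(min=0, max=emb_num - 1)   # tensor([5])
left_weight = data - left_ind.float()                      # tensor([0.8000])
right_weight = 1.0 - left_weight                           # tensor([0.2000])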