linting
shaikh58 committed Aug 19, 2024
1 parent 5a7e86b commit d5993a9
Showing 8 changed files with 210 additions and 163 deletions.
1 change: 0 additions & 1 deletion dreem/inference/tracker.py
@@ -463,7 +463,6 @@ def _run_global_tracker(
         # hungarian matching
         match_i, match_j = linear_sum_assignment((-traj_score))
 
-
         track_ids = instance_ids.new_full((n_query,), -1)
         for i, j in zip(match_i, match_j):
             # The overlap threshold is multiplied by the number of times the unique track j is matched to an
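Aside (not part of the commit): `linear_sum_assignment` comes from `scipy.optimize` and minimizes total assignment cost, which is why the trajectory score matrix is negated in the hunk above. A standalone sketch with a made-up score matrix:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Made-up association scores between 3 query instances and 4 existing tracks
# (higher = better match).
traj_score = np.array(
    [
        [0.9, 0.1, 0.0, 0.2],
        [0.2, 0.8, 0.1, 0.0],
        [0.0, 0.1, 0.7, 0.3],
    ]
)

# linear_sum_assignment minimizes total cost, so negate the scores to get a
# maximum-score matching (Hungarian algorithm).
match_i, match_j = linear_sum_assignment(-traj_score)
print(match_i, match_j)  # [0 1 2] [0 1 2]
```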
4 changes: 3 additions & 1 deletion dreem/io/instance.py
@@ -565,7 +565,9 @@ def add_embedding(self, emb_type: str, embedding: torch.Tensor) -> None:
             emb_type: Key/embedding type to be saved to dictionary
             embedding: The actual torch tensor embedding.
         """
-        if type(embedding) != dict: # for embedding agg method "average", input is array
+        if (
+            type(embedding) != dict
+        ):  # for embedding agg method "average", input is array
             # for method stack and concatenate, input is dict
             embedding = _expand_to_rank(embedding, 2)
         self._embeddings[emb_type] = embedding
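Note (illustration only): the `_expand_to_rank` helper itself is not shown in this diff. The sketch below, with a hypothetical `expand_to_rank` function, only illustrates the kind of rank promotion the comment describes — padding a bare embedding array up to 2-D before it is stored:

```python
import torch


def expand_to_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    """Prepend singleton dimensions until `t` has at least `rank` dims."""
    while t.dim() < rank:
        t = t.unsqueeze(0)
    return t


emb = torch.randn(128)               # a 1-D embedding, e.g. from "average" aggregation
print(expand_to_rank(emb, 2).shape)  # torch.Size([1, 128])
```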
48 changes: 25 additions & 23 deletions dreem/models/attention_head.py
@@ -9,13 +9,7 @@
 class ATTWeightHead(torch.nn.Module):
     """Single attention head."""
 
-    def __init__(
-        self,
-        feature_dim: int,
-        num_layers: int,
-        dropout: float,
-        **kwargs
-    ):
+    def __init__(self, feature_dim: int, num_layers: int, dropout: float, **kwargs):
         """Initialize an instance of ATTWeightHead.
 
         Args:
@@ -25,23 +19,27 @@ def __init__(
             embedding_agg_method: how the embeddings are aggregated; average/stack/concatenate
         """
         super().__init__()
-        if 'embedding_agg_method' in kwargs:
-            self.embedding_agg_method = kwargs['embedding_agg_method']
+        if "embedding_agg_method" in kwargs:
+            self.embedding_agg_method = kwargs["embedding_agg_method"]
         else:
             self.embedding_agg_method = None
 
         # if using stacked embeddings, use 1x1 conv with x,y,t embeddings as channels
         # ensures output represents ref instances by query instances
         if self.embedding_agg_method == "stack":
-            self.q_proj = torch.nn.Conv1d(in_channels=3, out_channels=1,
-                                          kernel_size=1, stride=1, padding=0
-                                          )
-            self.k_proj = torch.nn.Conv1d(in_channels=3, out_channels=1,
-                                          kernel_size=1, stride=1, padding=0
-                                          )
+            self.q_proj = torch.nn.Conv1d(
+                in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0
+            )
+            self.k_proj = torch.nn.Conv1d(
+                in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0
+            )
         else:
-            self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
-            self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
+            self.q_proj = MLP(
+                feature_dim, feature_dim, feature_dim, num_layers, dropout
+            )
+            self.k_proj = MLP(
+                feature_dim, feature_dim, feature_dim, num_layers, dropout
+            )
 
     def forward(
         self,
@@ -63,12 +61,16 @@ def forward(
         # if stacked embeddings, create channels for each x,y,t embedding dimension
         # maps shape (1,192,1024) -> (1,64,3,1024)
         if self.embedding_agg_method == "stack":
-            key = key.view(
-                batch_size, 3, num_window_instances//3, feature_dim
-            ).permute(0, 2, 1, 3).squeeze(0)
-            query = query.view(
-                batch_size, 3, num_query_instances//3, feature_dim
-            ).permute(0, 2, 1, 3).squeeze(0)
+            key = (
+                key.view(batch_size, 3, num_window_instances // 3, feature_dim)
+                .permute(0, 2, 1, 3)
+                .squeeze(0)
+            )
+            query = (
+                query.view(batch_size, 3, num_query_instances // 3, feature_dim)
+                .permute(0, 2, 1, 3)
+                .squeeze(0)
+            )
             # key, query of shape (batch_size, num_instances, 3, feature_dim)
             k = self.k_proj(key).transpose(1, 0)
             q = self.q_proj(query).transpose(1, 0)
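Aside (illustration only, not part of the commit): the reshape-plus-1x1-convolution pattern above can be seen in isolation with made-up shapes that match the comment `(1,192,1024) -> (1,64,3,1024)` — the x, y, and t embeddings become three channels per instance, and a 1x1 `Conv1d` learns how to combine them:

```python
import torch

# Hypothetical shapes mirroring the comments in the diff: 64 instances whose
# x, y, and t embeddings were stacked along the token axis -> 3 * 64 tokens.
batch_size, n_instances, feature_dim = 1, 64, 1024
query = torch.randn(batch_size, 3 * n_instances, feature_dim)

# Split the token axis into (3, n_instances) and move the 3 "embedding type"
# slots into a channel dimension: (1, 192, 1024) -> (1, 64, 3, 1024) -> (64, 3, 1024).
q = (
    query.view(batch_size, 3, n_instances, feature_dim)
    .permute(0, 2, 1, 3)
    .squeeze(0)
)

# A 1x1 Conv1d over the 3 channels learns a weighted combination of the
# x/y/t embeddings per instance: (64, 3, 1024) -> (64, 1, 1024).
q_proj = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
print(q_proj(q).shape)  # torch.Size([64, 1, 1024])
```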
46 changes: 19 additions & 27 deletions dreem/models/embedding.py
@@ -101,7 +101,7 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
             # 100 since it's a fraction of [0,1]*100. temp is from [0, clip_len]; since clip_len
             # not available, we use the last value in the indexing array since this will be the
             # last possible frame that we would need to index since no instances in a frame after that
-            self.build_rope_cache(max(101, input_pos[:, -1].max() + 1)) # registers cache
+            self.build_rope_cache(max(101, input_pos[:, -1].max() + 1))  # registers cache
             self.cache = self.cache.to(input_pos.device)
         # extract the values based on whether input_pos is set or not
         rope_cache = (
@@ -121,9 +121,8 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         return rope_cache
 
 
-
 class Embedding(torch.nn.Module):
-    """Class that wraps around different embedding types.
+    """Class that wraps around different embedding types.
     Creates embedding array and transforms the input data
     Used for both learned and fixed embeddings.
     """
@@ -153,7 +152,7 @@ def __init__(
         normalize: bool = False,
         scale: float | None = None,
         mlp_cfg: dict | None = None,
-        embedding_agg_method: str = "average"
+        embedding_agg_method: str = "average",
     ):
         """Initialize embeddings.
 
Expand Down Expand Up @@ -228,18 +227,17 @@ def __init__(
if self.emb_type == "pos":
if self.embedding_agg_method == "average":
self._emb_func = self._sine_box_embedding
else: # if using stacked/concatenated agg method
else: # if using stacked/concatenated agg method
self._emb_func = self._sine_pos_embedding
elif self.emb_type == "temp":
self._emb_func = self._sine_temp_embedding

elif self.mode == "rope":
# pos/temp embeddings processed the same way with different embedding array inputs
self._emb_func = self._rope_embedding
# create instance so embedding lookup array is created only once
self.rope_instance = RotaryPositionalEmbeddings(self.features)


def _check_init_args(self, emb_type: str, mode: str):
"""Check whether the correct arguments were passed to initialization.
@@ -268,7 +266,6 @@ def _check_init_args(self, emb_type: str, mode: str):
                 f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
             )
 
-
     def _transform(self, x, emb):
         """Routes to the relevant embedding function to transform the input queries
 
@@ -281,15 +278,14 @@ def _transform(self, x, emb):
             return self._apply_rope(x, emb)
         else:
             return self._apply_additive_embeddings(x, emb)
-
-
-    def _apply_rope(self, x, emb):
+
+    def _apply_rope(self, x, emb):
         """Applies Rotary Positional Embedding to input queries
 
         Args:
             x: Input queries of shape (batch_size, n_query, embed_dim)
             emb: Rotation matrix of shape (batch_size, n_query, num_heads, embed_dim // 2, 2)
 
         Returns:
             Tensor of input queries transformed by RoPE
         """
@@ -300,33 +296,29 @@ def _apply_rope(self, x, emb):
         # apply RoPE to each query token
         xout = torch.stack(
             [
-                xout[..., 0] * emb[..., 0]
-                - xout[..., 1] * emb[..., 1],
-                xout[..., 1] * emb[..., 0]
-                + xout[..., 0] * emb[..., 1],
+                xout[..., 0] * emb[..., 0] - xout[..., 1] * emb[..., 1],
+                xout[..., 1] * emb[..., 0] + xout[..., 0] * emb[..., 1],
             ],
             -1,
         )
         # output has shape [batch_size, n_query, num_heads, embed_dim]
         xout = xout.flatten(3).squeeze(2)
 
         return xout
 
-
-
     def _apply_additive_embeddings(self, x, emb):
         """Applies additive embeddings to input queries
 
         Args:
             x: Input tensor of shape (batch_size, N, embed_dim)
             emb: Embedding array of shape (N, embed_dim)
 
         Returns:
             Tensor: Input queries with embeddings added - shape (batch_size, N, embed_dim)
         """
         _emb = emb.unsqueeze(0)
         return x + _emb
 
-
-
     def forward(self, x, seq_positions: torch.Tensor) -> torch.Tensor:
         """Get the sequence positional embeddings.
@@ -341,8 +333,8 @@ def forward(self, x, seq_positions: torch.Tensor) -> torch.Tensor:
             - An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
         """
 
-        # create embedding array; either rotation matrix of shape
-        # (batch_size, n_query, num_heads, embed_dim // 2, 2),
+        # create embedding array; either rotation matrix of shape
+        # (batch_size, n_query, num_heads, embed_dim // 2, 2),
         # or (N, embed_dim) array
         emb = self._emb_func(seq_positions, x.size())
         # transform the input data with the embedding
@@ -364,8 +356,8 @@ def _torch_int_div(
         """
         return torch.div(tensor1, tensor2, rounding_mode="floor")
 
-
-    def _rope_embedding(self, seq_positions: torch.Tensor, input_shape: torch.Size) -> torch.Tensor:
+    def _rope_embedding(
+        self, seq_positions: torch.Tensor, input_shape: torch.Size
+    ) -> torch.Tensor:
         """Computes the rotation matrix to apply RoPE to input queries
 
         Args:
             seq_positions: Pos array of shape (embed_dim,) used to compute rotational embedding
@@ -380,12 +373,11 @@ def _rope_embedding(self, seq_positions: torch.Tensor, input_shape: torch.Size)
         is_pos_emb = 1 if seq_positions.max() <= 1 else 0
         # if it is positional, scale seq_positions since these are fractions
         # in [0,1] and we need int indexes for embedding lookup
-        seq_positions = seq_positions*100 if is_pos_emb else seq_positions
+        seq_positions = seq_positions * 100 if is_pos_emb else seq_positions
         # RoPE module takes in dimension, num_queries as input to calculate rotation matrix
         rot_mat = self.rope_instance(x_rope, seq_positions.unsqueeze(0).int())
 
         return rot_mat
 
-
     def _sine_pos_embedding(self, centroids: torch.Tensor, *args) -> torch.Tensor:
         """Compute fixed sine temporal embeddings per dimension (x,y)
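Aside (illustration only, not part of the commit): the stacked expression in `_apply_rope` is the standard RoPE rotation of consecutive feature pairs by per-position angles. A minimal sketch with assumed shapes, a single head, and placeholder angles:

```python
import torch

# Assumed shapes: x is (batch, n_query, embed_dim); `emb` holds (cos, sin)
# pairs per feature pair, shaped (batch, n_query, 1, embed_dim // 2, 2).
batch, n_query, embed_dim = 1, 4, 8
x = torch.randn(batch, n_query, embed_dim)

angles = torch.randn(batch, n_query, 1, embed_dim // 2)  # placeholder angles
emb = torch.stack([angles.cos(), angles.sin()], dim=-1)  # (..., d/2, 2)

# Pair up consecutive features, then rotate each pair by its angle.
xshaped = x.unsqueeze(2).reshape(batch, n_query, 1, embed_dim // 2, 2)
xout = torch.stack(
    [
        xshaped[..., 0] * emb[..., 0] - xshaped[..., 1] * emb[..., 1],
        xshaped[..., 1] * emb[..., 0] + xshaped[..., 0] * emb[..., 1],
    ],
    -1,
)
# Back to (batch, n_query, embed_dim); each 2-D rotation preserves the pair's norm.
xout = xout.flatten(3).squeeze(2)
print(xout.shape)  # torch.Size([1, 4, 8])
```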
1 change: 0 additions & 1 deletion dreem/models/mlp.py
@@ -37,7 +37,6 @@ def __init__(
                 # list concatenations to ensure layer shape compability
                 for n, k in zip([input_dim] + h, h + [output_dim])
             ]
-
         )
         if self.dropout > 0.0:
             self.dropouts = torch.nn.ModuleList(
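Aside (illustration only): the `zip([input_dim] + h, h + [output_dim])` pattern pairs each layer's input size with the next layer's output size. A small sketch assuming `torch.nn.Linear` layers and made-up dimensions:

```python
import torch

# Shifted copies of the hidden-size list pair consecutive layer dimensions.
input_dim, hidden_dim, output_dim, num_layers = 1024, 512, 1024, 3
h = [hidden_dim] * (num_layers - 1)

layers = torch.nn.ModuleList(
    torch.nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
print([(layer.in_features, layer.out_features) for layer in layers])
# [(1024, 512), (512, 512), (512, 1024)]
```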