
Commit

Merge pull request #30 from talmolab/aadi/implement-fixed-temp-embedding
Implement fixed temp embedding
talmo authored Apr 30, 2024
2 parents 041d0a4 + e3f577b commit a47c48b
Showing 5 changed files with 58 additions and 21 deletions.
2 changes: 1 addition & 1 deletion biogtr/data_structures.py
@@ -21,7 +21,7 @@ def __init__(
         point_scores: ArrayLike = None,
         instance_score: float = -1.0,
         skeleton: sio.Skeleton = None,
-        pose: dict[str, ArrayLike] = np.array([]),
+        pose: dict[str, ArrayLike] = None,
         device: str = None,
     ):
         """Initialize Instance.
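Note on the change above: the `pose` default moves from `np.array([])` to `None`. The constructor body is not part of this diff, so the following is only a minimal sketch of the usual pattern for normalizing a `None` default inside `__init__`; the class name and body here are hypothetical, not the repo's exact code.

```python
# Sketch only: illustrates why a None default is preferred over a mutable/array default.
from __future__ import annotations

from typing import Optional

import numpy as np
from numpy.typing import ArrayLike


class InstanceSketch:
    def __init__(self, pose: Optional[dict[str, ArrayLike]] = None):
        # A None default avoids sharing one default object across calls;
        # each instance gets its own empty dict when no pose is provided.
        self.pose = pose if pose is not None else {}


print(InstanceSketch().pose)  # {}
print(InstanceSketch({"nose": np.array([1.0, 2.0])}).pose)
```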
2 changes: 1 addition & 1 deletion biogtr/datasets/sleap_dataset.py
@@ -290,7 +290,7 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
                     crop=crop,
                     bbox=bbox,
                     skeleton=skeleton,
-                    pose=np.array(list(poses[j].values())),
+                    pose=poses[j],
                     point_scores=point_scores[j],
                     instance_score=instance_score[j],
                 )
52 changes: 43 additions & 9 deletions biogtr/models/embedding.py
@@ -83,7 +83,7 @@ def __init__(
         if self.emb_type == "pos":
             self._emb_func = self._sine_box_embedding
         elif self.emb_type == "temp":
-            pass # TODO Implement fixed sine temporal embedding
+            self._emb_func = self._sine_temp_embedding

     def _check_init_args(self, emb_type: str, mode: str):
         """Check whether the correct arguments were passed to initialization.
@@ -108,9 +108,6 @@ def _check_init_args(self, emb_type: str, mode: str):
                 f"Embedding `mode` must be one of {self.EMB_MODES} not {mode}"
             )

-        if mode == "fixed" and emb_type == "temp":
-            raise NotImplementedError("TODO: Implement Fixed Sinusoidal Temp Embedding")
-
     def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
         """Get the sequence positional embeddings.
@@ -141,13 +138,18 @@ def _torch_int_div(
     def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
         """Compute sine positional embeddings for boxes using given parameters.

         Args:
             boxes: the input boxes of shape N x 4 or B x N x 4
                 where the last dimension is the bbox coords in [y1, x1, y2, x2].
                 (Note currently `B=batch_size=1`).

         Returns:
-            torch.Tensor, the sine positional embeddings.
+            torch.Tensor, the sine positional embeddings
+                (embedding[:, 4i] = sin(x)
+                 embedding[:, 4i+1] = cos(x)
+                 embedding[:, 4i+2] = sin(y)
+                 embedding[:, 4i+3] = cos(y)
+                )
         """
         if self.scale is not None and self.normalize is False:
             raise ValueError("normalize should be True if scale is passed")
@@ -176,6 +178,38 @@ def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:

         return pos_emb

+    def _sine_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
+        """Compute fixed sine temporal embeddings.
+
+        Args:
+            times: the input times of shape (N,) or (N, 1) where N = sum(instances_per_frame),
+                which is the frame index of the instance relative to the batch size
+                (e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2, ..., B, B, ..., B])`).
+
+        Returns:
+            an n_instances x D embedding representing the temporal embedding.
+        """
+        T = times.int().max().item() + 1
+        d = self.features
+        n = self.temperature
+
+        positions = torch.arange(0, T).unsqueeze(1)
+        temp_lookup = torch.zeros(T, d, device=times.device)
+
+        denominators = torch.pow(
+            n, 2 * torch.arange(0, d // 2) / d
+        )  # 10000^(2i/d_model), i is the index of embedding
+        temp_lookup[:, 0::2] = torch.sin(
+            positions / denominators
+        )  # sin(pos/10000^(2i/d_model))
+        temp_lookup[:, 1::2] = torch.cos(
+            positions / denominators
+        )  # cos(pos/10000^(2i/d_model))
+
+        temp_emb = temp_lookup[times.int()]
+        return temp_emb  # .view(len(times), self.features)
+
     def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
         """Compute learned positional embeddings for boxes using given parameters.
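The added `_sine_temp_embedding` builds a `(T, d)` lookup table of interleaved sine/cosine values and indexes it with each instance's frame index. Below is a minimal, self-contained sketch of the same computation outside the class, for illustration only; the standalone function name and default arguments are assumptions, not part of the commit.

```python
import torch


def sine_temp_embedding(times: torch.Tensor, features: int = 128, temperature: float = 10000.0) -> torch.Tensor:
    """Standalone sketch of a fixed sinusoidal temporal embedding (illustrative only)."""
    T = int(times.int().max().item()) + 1
    positions = torch.arange(0, T).unsqueeze(1)                                            # (T, 1)
    denominators = torch.pow(temperature, 2 * torch.arange(0, features // 2) / features)   # (features/2,)
    lookup = torch.zeros(T, features, device=times.device)
    lookup[:, 0::2] = torch.sin(positions / denominators)   # even dims: sin(pos / n^(2i/d))
    lookup[:, 1::2] = torch.cos(positions / denominators)   # odd dims:  cos(pos / n^(2i/d))
    return lookup[times.int()]                               # (N, features)


# One instance in frame 0, two in frame 1, one in frame 2.
times = torch.tensor([0, 1, 1, 2])
emb = sine_temp_embedding(times, features=128)
print(emb.shape)  # torch.Size([4, 128])
# Instances from the same frame share the same temporal embedding.
assert torch.equal(emb[1], emb[2])
```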
4 changes: 2 additions & 2 deletions biogtr/models/model_utils.py
@@ -14,8 +14,8 @@ def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]:
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors containing the
-            bounding boxes and corresponding frame
-            indices, respectively.
+            bounding boxes normalized by the height and width of the image
+            and corresponding frame indices, respectively.
     """
     boxes, times = [], []
     _, h, w = frames[0].img_shape.flatten()
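The docstring update clarifies that `get_boxes_times` returns boxes normalized by image height and width. The rest of the function body is outside this diff, so the snippet below only illustrates that normalization on a single pixel-space box, assuming the `[y1, x1, y2, x2]` layout described in `embedding.py`.

```python
import torch

# Illustrative only: normalizing pixel-space [y1, x1, y2, x2] boxes by image
# height and width, as the updated get_boxes_times docstring describes.
h, w = 512, 1024
boxes_px = torch.tensor([[10.0, 20.0, 110.0, 220.0]])   # one box in pixels
boxes_norm = boxes_px / torch.tensor([h, w, h, w])       # values now in [0, 1]
print(boxes_norm)  # ≈ tensor([[0.0195, 0.0195, 0.2148, 0.2148]])
```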
19 changes: 11 additions & 8 deletions tests/test_models.py
@@ -94,9 +94,6 @@ def test_embedding_validity():
     with pytest.raises(Exception):
         _ = Embedding(emb_type="temporal", mode="learn", features=128)

-    with pytest.raises(Exception):
-        _ = Embedding(emb_type="temp", mode="fixed", features=128)
-
     _ = Embedding(emb_type="temp", mode="learned", features=128)
     _ = Embedding(emb_type="pos", mode="learned", features=128)
@@ -185,20 +182,23 @@ def test_embedding_kwargs():

     N = frames * objects

-    boxes = torch.rand(size=(N, 4))
-    times = torch.rand(size=(N,))
+    boxes = torch.rand(N, 2) * (1024 - 128)
+    boxes = torch.concat([boxes / 1024, (boxes + 128) / 1024], axis=-1)

     # sine embedding

-    sine_no_args = Embedding("pos", "fixed", 128)(boxes)
-
     sine_args = {
         "temperature": objects,
         "scale": frames,
         "normalize": True,
     }
+    sine_no_args = Embedding("pos", "fixed", 128)
+    sine_with_args = Embedding("pos", "fixed", 128, **sine_args)
+
+    assert sine_no_args.temperature != sine_with_args.temperature

-    sine_with_args = Embedding("pos", "fixed", 128, **sine_args)(boxes)
+    sine_no_args = sine_no_args(boxes)
+    sine_with_args = sine_with_args(boxes)

     assert not torch.equal(sine_no_args, sine_with_args)
@@ -336,6 +336,9 @@ def test_transformer_embedding():
         return_embedding=True,
     )

+    assert transformer.pos_emb.mode == "learned"
+    assert transformer.temp_emb.mode == "learned"
+
     asso_preds, embedding = transformer(frames)

     assert asso_preds[0].size() == (num_detected * num_frames,) * 2
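The updated `test_embedding_kwargs` now constructs normalized boxes and checks that passing `temperature`, `scale`, and `normalize` changes both the stored attributes and the resulting fixed embedding. As a rough illustration of why the `temperature` kwarg matters (not code from this repo), a different temperature changes the frequency denominators used by the sine/cosine terms:

```python
import torch

# Illustrative only: the temperature sets the base of the frequency denominators,
# so two embeddings built with different temperatures cannot be identical.
d = 128
i = torch.arange(0, d // 2)
denoms_default = torch.pow(10000.0, 2 * i / d)
denoms_custom = torch.pow(10.0, 2 * i / d)   # e.g. temperature=10
assert not torch.equal(denoms_default, denoms_custom)
```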
