Commit 65f7498

cleanup nits
1 parent c06bd7f commit 65f7498

File tree: 3 files changed, +4 −2 lines changed

torchtune/modules/attention.py

Lines changed: 2 additions & 1 deletion
@@ -270,7 +270,8 @@ def forward(
         k = self.pos_embeddings(k, input_pos=input_pos)
 
         # k,v shape: [b, n_kv, s_y, h_d]
-        k, v = k.transpose(1, 2), v.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
 
         # Update key-value cache
         if self.kv_cache is not None and self.cache_enabled:
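For context, a minimal standalone sketch (dimensions here are hypothetical, chosen only for illustration) confirming that splitting the tuple assignment into two statements is behaviorally identical:

    import torch

    # Illustrative dimensions, not taken from the commit.
    b, s_y, n_kv, h_d = 2, 4, 8, 16
    k = torch.randn(b, s_y, n_kv, h_d)
    v = torch.randn(b, s_y, n_kv, h_d)

    # Same result as the old `k, v = k.transpose(1, 2), v.transpose(1, 2)`.
    k = k.transpose(1, 2)  # -> [b, n_kv, s_y, h_d]
    v = v.transpose(1, 2)

    assert k.shape == (b, n_kv, s_y, h_d)
    assert v.shape == (b, n_kv, s_y, h_d)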

torchtune/modules/attention_utils.py

Lines changed: 1 addition & 0 deletions
@@ -183,6 +183,7 @@ def _attention_call(
     dropout_p: float,
     is_causal: bool,
 ) -> torch.Tensor:
+
     # Flex attention uses the BlockMask
     # (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168)
     # instead of a traditional boolean tensor mask. If this is passed in,
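The comment in this hunk refers to flex attention's BlockMask. As a point of reference, a minimal sketch of that usage, assuming PyTorch 2.5+'s torch.nn.attention.flex_attention API and illustrative shapes (this is not torchtune code; eager CPU execution may require a newer PyTorch, and flex_attention is normally compiled for performance):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # A mask_mod returns True wherever attention is allowed; causal here.
    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # BlockMask encodes the mask at block granularity, unlike the dense
    # boolean tensor accepted by scaled_dot_product_attention.
    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu")

    q = k = v = torch.randn(1, 4, 128, 64)  # [b, n_h, s, h_d], illustrative
    out = flex_attention(q, k, v, block_mask=block_mask)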

torchtune/modules/kv_cache.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def reset(self) -> None:
 
     @property
     def size(self) -> int:
-        return int(self.cache_pos[0].item())
+        return self.cache_pos[0].item()
 
     def update(
         self, k_val: torch.Tensor, v_val: torch.Tensor
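The dropped int(...) cast was redundant: Tensor.item() on an integer-dtype tensor already returns a native Python int, which satisfies the declared -> int return type. A small standalone check (cache_pos below is a stand-in for the KVCache attribute, not its real initialization):

    import torch

    cache_pos = torch.arange(8)   # integer dtype (torch.int64)
    size = cache_pos[0].item()    # already a Python int; no cast needed
    assert isinstance(size, int)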
