
Commit 91031f4

simplify expand for num_kv_heads
1 parent: bb842e0

2 files changed: +4 additions, −19 deletions

torchtune/modules/attention.py

Lines changed: 4 additions & 7 deletions
@@ -9,11 +9,7 @@
 
 import torch
 from torch import nn
-from torchtune.modules.attention_utils import (
-    _MaskType,
-    _sdpa_or_flex_attention,
-    repeat_interleave,
-)
+from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
 
 logger = logging.getLogger(__name__)

@@ -284,8 +280,9 @@ def forward(
         # as the query tensor by copying values across the relevant dim
         # k,v shape: [b, n_h, s, h_d]
         if self.num_heads != self.num_kv_heads:
-            k = repeat_interleave(k, dim=1, repeat=q_per_kv)
-            v = repeat_interleave(v, dim=1, repeat=q_per_kv)
+            expand_shape = (-1, -1, q_per_kv, -1, -1)
+            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
         # Normalize k
         if self.k_norm is not None:
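
For context, a minimal sketch of the new expansion (the shapes below are made up for illustration; only q_per_kv and the unsqueeze → expand → flatten pattern come from the diff), checking that it reproduces torch.repeat_interleave along the KV-head dim:

import torch

# Hypothetical GQA shapes: 2 batches, 4 query heads, 2 KV heads, seq len 8, head dim 16.
b, num_heads, num_kv_heads, s, h_d = 2, 4, 2, 8, 16
q_per_kv = num_heads // num_kv_heads  # queries per KV head (2 here)

k = torch.randn(b, num_kv_heads, s, h_d)  # [b, n_kv, s, h_d]

# Inline expansion from this commit: insert a repeat axis after the KV-head axis,
# broadcast it q_per_kv times, then fold it back into the head axis.
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)  # [b, n_h, s, h_d]

# Reference: repeating each KV head q_per_kv times along dim 1 gives the
# same [h0, h0, h1, h1] ordering.
k_reference = torch.repeat_interleave(k, repeats=q_per_kv, dim=1)

assert k_expanded.shape == (b, num_heads, s, h_d)
assert torch.equal(k_expanded, k_reference)

The flatten(1, 2) folds the broadcast repeat axis back into the head axis, which is the per-KV-head duplication that grouped-query attention expects before scaled dot-product attention.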

torchtune/modules/attention_utils.py

Lines changed: 0 additions & 12 deletions
@@ -246,15 +246,3 @@ def _attention_call(
             )
 
     return _attention_call
-
-
-def repeat_interleave(x: torch.Tensor, *, dim: int, repeat: int) -> torch.Tensor:
-    if repeat == 1:
-        return x
-
-    dim = dim + x.ndim if dim < 0 else dim
-
-    shape = [-1] * (x.ndim + 1)
-    shape[dim + 1] = repeat
-
-    return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)
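
The deleted helper was a generic version of the same expand trick. A standalone sketch (reconstructed from the deleted lines above, with made-up test shapes) shows it matched torch.repeat_interleave for any dim, so inlining the dim=1 case in attention.py is behaviorally unchanged:

import torch

# Reconstruction of the removed helper so it can be exercised outside torchtune.
def repeat_interleave(x: torch.Tensor, *, dim: int, repeat: int) -> torch.Tensor:
    if repeat == 1:
        return x

    # Normalize negative dims, then broadcast a new repeat axis next to `dim`
    # and fold it back in.
    dim = dim + x.ndim if dim < 0 else dim

    shape = [-1] * (x.ndim + 1)
    shape[dim + 1] = repeat

    return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)

# Sanity check against the built-in for every dim, including negative indexing.
x = torch.randn(2, 3, 5)
for d in (-3, -2, -1, 0, 1, 2):
    assert torch.equal(
        repeat_interleave(x, dim=d, repeat=4),
        torch.repeat_interleave(x, repeats=4, dim=d),
    )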
