
Commit e1caa9f

Update KV Cache to use num_kv_heads instead of num_heads (#1961)
1 parent 08efaed commit e1caa9f

File tree: 6 files changed, +28 -43 lines
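
At a glance, the change moves the cache from query-head sizing to key/value-head sizing. Below is a minimal usage sketch of the new constructor signature; the import path is the usual torchtune one and the GQA head counts are illustrative, not taken from the commit.

import torch
from torchtune.modules import KVCache

# Illustrative GQA setup: 32 query heads sharing 8 key/value heads.
num_heads, num_kv_heads, head_dim = 32, 8, 128

# Before this commit the cache was allocated with num_heads, because keys and
# values were expanded to match the query heads before caching. After this
# commit it is allocated with num_kv_heads, so for GQA each buffer is
# num_heads / num_kv_heads times smaller (4x here).
cache = KVCache(
    batch_size=2,
    max_seq_len=2048,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    dtype=torch.bfloat16,
)
print(cache.k_cache.shape)  # torch.Size([2, 8, 2048, 128])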

tests/torchtune/models/llama2/scripts/compare_fused_attention.py

Lines changed: 2 additions & 3 deletions
@@ -256,7 +256,6 @@ def compare_attn(
     max_seq_len: int,
     use_kv_cache: bool,
 ):
-
     torch.manual_seed(16)
     inputs = torch.randn(4, 2048, 4096)
 
@@ -269,8 +268,9 @@ def compare_attn(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            n_kv_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
+            dtype=inputs.dtype,
         )
     else:
         kv_cache = None
@@ -330,7 +330,6 @@ def compare_attn(
 
 
 if __name__ == "__main__":
-
     # compare mha
     mha = {
         "num_heads": 32,

tests/torchtune/models/llama2/scripts/compare_lora_attention.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,6 @@ def compare_lora_attention(
     lora_rank: int,
     lora_alpha: float,
 ) -> None:
-
     # make sure we have the right seed for generating outputs
     # this should match up the seed value set in the corresponding
     # unit test
@@ -68,8 +67,9 @@ def compare_lora_attention(
         KVCache(
             batch_size=batch_size,
             max_seq_len=max_seq_len,
-            n_kv_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
+            dtype=x.dtype,
         )
         if batch_size is not None
         else None

tests/torchtune/modules/test_attention.py

Lines changed: 3 additions & 5 deletions
@@ -123,7 +123,7 @@ def gqa_kv_cache(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=torch.float32,
         )
@@ -178,7 +178,7 @@ def mha_kv_cache(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=torch.float32,
         )
@@ -233,7 +233,7 @@ def mqa_kv_cache(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=torch.float32,
         )
@@ -267,7 +267,6 @@ def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None
     def test_forward_gqa_kv_cache(
         self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
     ) -> None:
-
         _, _, _, max_seq_len = attn_params_gqa
         _, seq_len, _ = input.shape
 
@@ -293,7 +292,6 @@ def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None
     def test_forward_mha_kv_cache(
         self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
     ) -> None:
-
         _, _, _, max_seq_len = attn_params_mha
         _, seq_len, _ = input.shape
 
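
In the same spirit as these fixtures, a small standalone check that the cache buffers are now keyed by num_kv_heads (the head counts and sequence length here are illustrative, not the values used in the test file):

import torch
from torchtune.modules import KVCache

num_kv_heads, head_dim, max_seq_len = 8, 64, 2048
kv_cache = KVCache(
    batch_size=4,
    max_seq_len=max_seq_len,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    dtype=torch.float32,
)
# Dim 1 of both buffers is now num_kv_heads rather than num_heads.
assert kv_cache.k_cache.shape == (4, num_kv_heads, max_seq_len, head_dim)
assert kv_cache.v_cache.shape == (4, num_kv_heads, max_seq_len, head_dim)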

torchtune/models/gemma2/_attention.py

Lines changed: 3 additions & 3 deletions
@@ -149,7 +149,7 @@ def setup_cache(
             self.kv_cache = KVCache(
                 batch_size=batch_size,
                 max_seq_len=max_seq_len,
-                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
@@ -211,9 +211,9 @@ def forward(
             - h_d: head dim
         """
         # until flex attention implementation exists, we do not accept block masks
-        if (mask is not None) and (type(mask) != torch.Tensor()):
+        if mask is not None and (not isinstance(mask, torch.Tensor)):
             raise NotImplementedError(
-                "Block masks are not implemeted yet, use packed=False"
+                "Block masks are not implemeted yet, use packed=False."
             )
 
         # x has shape [b, s_x, d]
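
One note on the mask guard: the old expression compared a type object against a freshly constructed empty tensor, so it never performed the intended type check; the new code uses isinstance. A small standalone sketch of the corrected guard (the helper name is made up for illustration):

import torch

def _check_mask(mask) -> None:
    # Until a flex attention implementation exists, only dense tensors
    # (or no mask at all) are accepted; block masks are rejected.
    if mask is not None and not isinstance(mask, torch.Tensor):
        raise NotImplementedError("Block masks are not implemented yet, use packed=False.")

_check_mask(None)                                   # ok: no mask provided
_check_mask(torch.ones(1, 8, 8, dtype=torch.bool))  # ok: dense boolean mask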

torchtune/modules/attention.py

Lines changed: 14 additions & 24 deletions
@@ -164,7 +164,7 @@ def setup_cache(
             self.kv_cache = KVCache(
                 batch_size=batch_size,
                 max_seq_len=max_seq_len,
-                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
@@ -258,47 +258,37 @@ def forward(
         else:
             # Update k and v shape, positional embeddings, and normalization
 
-            # k has shape [b, s_y, num_kv_heads * head_dim]
-            # v has shape [b, s_y, num_kv_heads * head_dim]
+            # k,v shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
             v = self.v_proj(y)
 
             # Apply positional embeddings
-            # k: [b, s_y, n_kv, h_d]
+            # k,v shape: [b, s_y, n_kv, h_d]
             k = k.view(b, s_y, -1, self.head_dim)
+            v = v.view(b, s_y, -1, self.head_dim)
             if self.pos_embeddings is not None:
                 k = self.pos_embeddings(k, input_pos=input_pos)
 
-            # View + expand + reshape bring num_kv_heads to num_heads for k and v
-            # to match q.
+            # k,v shape: [b, n_kv, s_y, h_d]
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
 
-            # k: [b, s_y, n_kv, 1, h_d]
-            # v: [b, s_y, n_kv, 1, h_d]
-            k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-            v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
+            # Update key-value cache
+            if self.kv_cache is not None and self.cache_enabled:
+                k, v = self.kv_cache.update(k, v)
 
         # If needed, expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
+        # k,v shape: [b, n_h, s, h_d]
         if self.num_heads != self.num_kv_heads:
-            k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-            v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-
-            # [b, s, n_h, h_d]
-            k = k.reshape(b, s_y, -1, self.head_dim)
-            v = v.reshape(b, s_y, -1, self.head_dim)
-
-            # [b, n_h, s, h_d]
-            k = k.transpose(1, 2)
-            v = v.transpose(1, 2)
+            expand_shape = (-1, -1, q_per_kv, -1, -1)
+            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
         # Normalize k
         if self.k_norm is not None:
             k = self.k_norm(k)
 
-        # Update key-value cache
-        if self.kv_cache is not None and self.cache_enabled:
-            k, v = self.kv_cache.update(k, v)
-
         output = self._attention_call(
             q,
             k,
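
A standalone sketch of the new expansion step may help: keys and values now stay at num_kv_heads through the cache update, and are only broadcast up to num_heads right before the attention call via unsqueeze, expand, and flatten. The shapes below are illustrative.

import torch

b, s, n_kv, q_per_kv, h_d = 2, 16, 8, 4, 128  # n_h = n_kv * q_per_kv = 32

# k as it leaves the cache in the new code path: [b, n_kv, s, h_d]
k = torch.randn(b, n_kv, s, h_d)

# Broadcast each kv head across its q_per_kv query heads, then merge the
# (n_kv, q_per_kv) dims into n_h, mirroring the new attention.py logic.
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

print(k_expanded.shape)  # torch.Size([2, 32, 16, 128])

Because the expansion now happens after the cache update, the cache only ever stores the unexpanded key/value heads, which is where the memory saving for GQA comes from.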

torchtune/modules/kv_cache.py

Lines changed: 4 additions & 6 deletions
@@ -17,9 +17,7 @@ class KVCache(nn.Module):
     Args:
         batch_size (int): batch size model will be run with
         max_seq_len (int): maximum sequence length model will be run with
-        num_heads (int): number of heads. We take num_heads instead of num_kv_heads because
-            the cache is created after we've expanded the key and value tensors to have the
-            same shape as the query tensor. See attention.py for more details
+        num_kv_heads (int): number of key/value heads.
         head_dim (int): per-attention head embedding dimension
         dtype (torch.dtype): dtype for the caches
     """
@@ -28,12 +26,12 @@ def __init__(
         self,
         batch_size: int,
         max_seq_len: int,
-        num_heads: int,
+        num_kv_heads: int,
         head_dim: int,
         dtype: torch.dtype,
     ) -> None:
         super().__init__()
-        cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
+        cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
@@ -66,7 +64,7 @@ def update(
         already been filled, use ``.reset()``, which will reset the cache to the zero-th position.
 
         Example:
-            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16)
+            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
             >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
             >>> cache.update(keys, values)
             >>> # now positions 0 through 7 are filled
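
Following the docstring example above, a hedged usage sketch with the new keyword (the explicit dtype on the key/value tensors is added here for completeness):

import torch
from torchtune.modules import KVCache

cache = KVCache(
    batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16
)

# Keys/values enter the cache in [b, num_kv_heads, s, head_dim] layout,
# i.e. before any expansion to the number of query heads.
keys = torch.ones((2, 4, 8, 32), dtype=torch.bfloat16)
values = torch.ones((2, 4, 8, 32), dtype=torch.bfloat16)
k_out, v_out = cache.update(keys, values)
# Positions 0 through 7 are now filled; the cache buffers are returned.
print(k_out.shape)  # torch.Size([2, 4, 16, 32])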
