Commit 5f0d2b7

Update gemma-2 kvcache constructor and fix mask type check.
1 parent 39b9801 · commit 5f0d2b7

File tree: 1 file changed (+3, −3 lines)


torchtune/models/gemma2/_attention.py

Lines changed: 3 additions & 3 deletions
@@ -149,7 +149,7 @@ def setup_cache(
         self.kv_cache = KVCache(
             batch_size=batch_size,
             max_seq_len=max_seq_len,
-            num_heads=self.num_heads,
+            num_kv_heads=self.num_heads,
             head_dim=self.head_dim,
             dtype=dtype,
         )
@@ -211,9 +211,9 @@ def forward(
             - h_d: head dim
         """
         # until flex attention implementation exists, we do not accept block masks
-        if (mask is not None) and (type(mask) != torch.Tensor()):
+        if mask is not None and (not isinstance(mask, torch.Tensor)):
             raise NotImplementedError(
-                "Block masks are not implemeted yet, use packed=False"
+                "Block masks are not implemeted yet, use packed=False."
             )

         # x has shape [b, s_x, d]
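A brief note on the two changes (an illustration, not part of the commit): the first hunk renames the keyword argument to num_kv_heads, presumably to match the updated KVCache signature in torchtune. The second hunk fixes the type check: type(mask) != torch.Tensor() compares the class of mask against a freshly constructed empty tensor, which never answers "is mask a plain tensor?", whereas isinstance(mask, torch.Tensor) does. A minimal sketch of the corrected guard, using a hypothetical helper name:

import torch

def _reject_block_masks(mask):
    # Sketch of the updated guard in forward: only dense tensor masks are
    # accepted until a flex-attention (block mask) path exists.
    if mask is not None and not isinstance(mask, torch.Tensor):
        raise NotImplementedError(
            "Block masks are not implemented yet, use packed=False."
        )

_reject_block_masks(None)                                    # fine: no mask supplied
_reject_block_masks(torch.ones(2, 4, 4, dtype=torch.bool))   # fine: dense boolean mask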
