File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -149,7 +149,7 @@ def setup_cache(
149149 self .kv_cache = KVCache (
150150 batch_size = batch_size ,
151151 max_seq_len = max_seq_len ,
152- num_heads = self .num_heads ,
152+ num_kv_heads = self .num_heads ,
153153 head_dim = self .head_dim ,
154154 dtype = dtype ,
155155 )
@@ -211,9 +211,9 @@ def forward(
211211 - h_d: head dim
212212 """
213213 # until flex attention implementation exists, we do not accept block masks
214- if ( mask is not None ) and (type (mask ) != torch .Tensor ( )):
214+ if mask is not None and (not isinstance (mask , torch .Tensor )):
215215 raise NotImplementedError (
216- "Block masks are not implemeted yet, use packed=False"
216+ "Block masks are not implemeted yet, use packed=False. "
217217 )
218218
219219 # x has shape [b, s_x, d]
You can’t perform that action at this time.
0 commit comments