@@ -270,6 +270,15 @@ def strided_load_kv(ref, start, step):
270
270
b = jnp .left_shift (b , bw * (packing - 1 ))
271
271
return pltpu .bitcast (b , jnp .float32 ).astype (jnp .bfloat16 )
272
272
273
def fold_on_2nd_minor(vec):
  """Flatten every leading dimension of `vec` into its second-minor one.

  The result is 2-D with shape `(-1, vec.shape[-1])`; the minor (last)
  dimension is left untouched.

  Only bfloat16 and float32 inputs with rank >= 2 are accepted. When the
  second-minor extent is not a multiple of the packing factor reported by
  `get_dtype_packing` for the input dtype, the value is converted to float32
  before the reshape, so the returned dtype may differ from the input dtype.
  NOTE(review): presumably the upcast avoids folding across partially
  filled packed rows -- confirm against the TPU layout rules.
  """
  assert vec.dtype in (jnp.bfloat16, jnp.float32)
  assert len(vec.shape) >= 2
  *_, second_minor, minor = vec.shape
  # The packing factor is looked up for the *incoming* dtype, before any cast.
  if second_minor % get_dtype_packing(vec.dtype):
    vec = vec.astype(jnp.float32)
  return vec.reshape(-1, minor)
281
+
273
282
@pl .when (heads_blk_idx + q_blk_idx == 0 )
274
283
def prefetch_first_kv_blk ():
275
284
async_copy_k , async_copy_v = create_kv_async_copy_descriptors (
@@ -495,9 +504,9 @@ def prefetch_next_kv_blk():
495
504
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
496
505
# TODO(jevinjiang): extra handling for packed type that can start at
497
506
# unaligned position!
498
- q = q_ref [
499
- :, q_head_idx : q_head_idx + num_q_heads_per_kv_head , :
500
- ]. reshape ( - 1 , head_dim )
507
+ q = fold_on_2nd_minor (
508
+ q_ref [ :, q_head_idx : q_head_idx + num_q_heads_per_kv_head , :]
509
+ )
501
510
k = strided_load_kv (k_ref , kv_head_idx , num_kv_heads_per_blk )
502
511
v = strided_load_kv (v_ref , kv_head_idx , num_kv_heads_per_blk )
503
512
flash_attention (
0 commit comments