Skip to content

Commit 041f575

Browse files
bythew3i and Google-ML-Automation
authored and committed
Support MHA in ragged paged attention for packed type
PiperOrigin-RevId: 734695213
1 parent 6095af0 commit 041f575

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

jax/experimental/pallas/ops/tpu/ragged_paged_attention.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,15 @@ def strided_load_kv(ref, start, step):
270270
b = jnp.left_shift(b, bw * (packing - 1))
271271
return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16)
272272

273+
def fold_on_2nd_minor(vec):
  """Flatten all leading axes of `vec` into one, keeping the minor axis.

  The result is 2D with shape `(-1, vec.shape[-1])`. Only bfloat16 and
  float32 inputs are accepted, and the input must have rank >= 2.

  If the second-minor dimension is not a multiple of the packing factor that
  `get_dtype_packing` reports for the input dtype, the value is converted to
  float32 before the reshape; otherwise the dtype is left untouched.
  NOTE(review): presumably the upcast avoids folding a packed dtype across a
  partially filled packed row -- confirm against the kernel's layout rules.
  """
  assert vec.dtype in (jnp.bfloat16, jnp.float32)
  shape = vec.shape
  assert len(shape) >= 2
  minor = shape[-1]
  second_minor = shape[-2]
  # Only convert when the second-minor extent does not divide evenly by the
  # packing factor of the *incoming* dtype.
  misaligned = second_minor % get_dtype_packing(vec.dtype) != 0
  foldable = vec.astype(jnp.float32) if misaligned else vec
  return foldable.reshape(-1, minor)
281+
273282
@pl.when(heads_blk_idx + q_blk_idx == 0)
274283
def prefetch_first_kv_blk():
275284
async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
@@ -495,9 +504,9 @@ def prefetch_next_kv_blk():
495504
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
496505
# TODO(jevinjiang): extra handling for packed type that can start at
497506
# unaligned position!
498-
q = q_ref[
499-
:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :
500-
].reshape(-1, head_dim)
507+
q = fold_on_2nd_minor(
508+
q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :]
509+
)
501510
k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk)
502511
v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk)
503512
flash_attention(

tests/pallas/tpu_ragged_paged_attention_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def test_ragged_paged_attention_mixed(self, dtype):
266266
@parameterized.product(
267267
num_seqs=[1, 5, 16],
268268
# TODO(jevinjiang): Support more num_heads!
269-
num_heads=[(32, 8), (32, 16), (12, 2)],
269+
num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)],
270270
dtype=[jnp.float32, jnp.bfloat16],
271271
num_kv_pages_per_block=[4, 8],
272272
num_queries_per_block=[32, 64],

0 commit comments

Comments
 (0)