
Commit 4b49c03

bythew3i authored and Google-ML-Automation committed
Open source TPU-friendly ragged paged attention kernel.
Key features: * ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.) * ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly. * ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***! * ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode. * ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine. PiperOrigin-RevId: 734269519
1 parent 5d64b3d commit 4b49c03

3 files changed: +1035 −0 lines

