Optimize TPU Flash Attention (400x speed-up on 32k long context)
Use the splash attention lazy mask instead of a jnp mask. A jnp mask takes O(T^2) memory, which almost negates the benefit flash attention gets from reducing HBM communication. The splash attention lazy mask instead generates the causal mask lazily inside the kernel. In addition, Pallas supports CPU simulation (interpret=True), so the same Pallas kernel can be used on CPU, which makes the code easier to debug. A sketch of the lazy-mask construction is included after the benchmark tables below.

* Benchmark: TPUv5p; the benchmark name encodes model_dim/heads/kv_heads/seq_len.

NumpyMask (as-is):
----------------------------------------------------------------------------------------
Benchmark                                      Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512         1.71 ms         1.09 ms          592   (4.43M)
FlashAttentionBenchmark/2048/16/2/1024      4.44 ms         1.21 ms          483  (28.62M)
FlashAttentionBenchmark/4096/16/2/1024      8.61 ms         1.36 ms          302  (53.27M)
FlashAttentionBenchmark/4096/16/2/4096      3264 ms          1537 ms            1 (197.38M)
FlashAttentionBenchmark/4096/16/2/8192      7426 ms          5603 ms            1 (389.54M)
FlashAttentionBenchmark/4096/16/2/32768    94427 ms         92256 ms            1   (1.50G)

CausalMask (this PR): saves both memory and computation; at 32k context, roughly a 400x speed-up and 3x HBM saving.
----------------------------------------------------------------------------------------
Benchmark                                      Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512         1.55 ms         1.01 ms          578   (3.43M)
FlashAttentionBenchmark/2048/16/2/1024      4.21 ms         1.11 ms          490  (13.57M)
FlashAttentionBenchmark/4096/16/2/1024      6.50 ms         1.17 ms          493  (24.22M)
FlashAttentionBenchmark/4096/16/2/4096      16.8 ms         1.38 ms          228  (84.33M)
FlashAttentionBenchmark/4096/16/2/8192      28.8 ms         1.58 ms          217 (164.50M)
FlashAttentionBenchmark/4096/16/2/32768      230 ms         6.36 ms           16 (644.60M)
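For reference, a minimal sketch of the lazy-mask construction, assuming the splash_attention module under jax.experimental.pallas.ops.tpu; the shapes, dtypes, and single-shard settings are illustrative and not taken from this change:

    # Minimal sketch: build a lazy CausalMask and a splash attention kernel.
    import jax
    import jax.numpy as jnp
    from jax.experimental.pallas.ops.tpu.splash_attention import (
        splash_attention_kernel,
        splash_attention_mask,
    )

    batch, num_heads, seq_len, head_dim = 1, 16, 4096, 256  # illustrative sizes

    # Lazy causal mask: mask blocks are generated inside the kernel,
    # so no O(T^2) boolean array is materialized in HBM.
    causal = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
    mha_mask = splash_attention_mask.MultiHeadMask(
        masks=[causal for _ in range(num_heads)])

    kernel = splash_attention_kernel.make_splash_mha(
        mask=mha_mask,
        head_shards=1,   # illustrative: no head sharding
        q_seq_shards=1,  # illustrative: no sequence sharding
    )

    q = jnp.zeros((batch, num_heads, seq_len, head_dim), jnp.float32)
    k = jnp.zeros((batch, num_heads, seq_len, head_dim), jnp.float32)
    v = jnp.zeros((batch, num_heads, seq_len, head_dim), jnp.float32)

    # The kernel operates on a single example; vmap over the batch dimension.
    out = jax.vmap(kernel)(q, k, v)

On CPU, the same Pallas kernel can be exercised in interpret mode for debugging; whether an interpret flag is exposed on the splash attention wrapper depends on the JAX version, so treat that as an assumption rather than part of this sketch.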