Optimize TPU Flash Attention (400x speed-up on 32k long context) #845

Draft · wants to merge 1 commit into main

Conversation

@ds-hwang (Contributor) commented on Nov 18, 2024

Optimize TPU Flash Attention (400x speed-up on 32k long context)

Use the splash attention lazy mask instead of a materialized jnp mask, which is O(T^2).

The memory for the jnp mask is O(T^2), which largely negates the benefit of the reduced HBM traffic that flash attention provides; for example, a 32k x 32k boolean mask alone is about 1 GiB. The splash attention lazy mask instead generates the causal mask on the fly, block by block, inside the kernel.
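
As a rough sketch of what constructing such a lazy mask looks like (not the exact code in this PR; the module paths and signatures below follow JAX's `jax.experimental.pallas.ops.tpu.splash_attention` and may differ across JAX versions):

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

num_heads, seq_len, head_dim = 4, 1024, 128

# CausalMask is a lazy mask: blocks are generated inside the kernel on demand,
# so no O(T^2) mask array is ever materialized in HBM.
causal = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
mha_mask = splash_attention_mask.MultiHeadMask(masks=(causal,) * num_heads)

kernel = splash_attention_kernel.make_splash_mha(
    mask=mha_mask,
    head_shards=1,
    q_seq_shards=1,
)

@jax.jit
def attend(q, k, v):
    # Depending on the version, query scaling by 1/sqrt(head_dim) may need to
    # be applied before calling the kernel.
    return kernel(q, k, v)

q = k = v = jnp.ones((num_heads, seq_len, head_dim), jnp.float32)
out = attend(q, k, v)  # [num_heads, seq_len, head_dim]; runs on a TPU backend
```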

In addition, Pallas supports CPU simulation (interpret=True), so the same
Pallas kernel is now used on CPU as well, which makes the code easier to debug.
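
To illustrate the interpret mode (a minimal sketch, not the kernel in this PR; `add_one_kernel` is a hypothetical toy kernel), the same Pallas kernel can be run on CPU by passing interpret=True to pl.pallas_call:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Normally compiled for TPU; in interpret mode the same code is emulated
    # on CPU, so it can be stepped through and unit-tested without hardware.
    o_ref[...] = x_ref[...] + 1.0

def add_one(x, interpret: bool):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,  # True -> simulate the Pallas kernel on CPU
    )(x)

x = jnp.zeros((8, 128), jnp.float32)
print(add_one(x, interpret=jax.default_backend() != "tpu").mean())  # 1.0
```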

  • Benchmark: run on TPU v5p; each case is labeled (model_dim/num_heads/num_kv_heads/seq_len).

NumpyMask (as-is)

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           1.71 ms         1.09 ms          592   (4.43M)
FlashAttentionBenchmark/2048/16/2/1024        4.44 ms         1.21 ms          483  (28.62M)
FlashAttentionBenchmark/4096/16/2/1024        8.61 ms         1.36 ms          302  (53.27M)
FlashAttentionBenchmark/4096/16/2/4096        3264 ms         1537 ms            1 (197.38M)
FlashAttentionBenchmark/4096/16/2/8192        7426 ms         5603 ms            1 (389.54M)
FlashAttentionBenchmark/4096/16/2/32768      94427 ms        92256 ms            1   (1.50G)

CausalMask (this PR): saves both memory and computation. At 32k context, it
gives roughly a 400x speed-up and a 3x HBM saving.

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           1.55 ms         1.01 ms          578   (3.43M)
FlashAttentionBenchmark/2048/16/2/1024        4.21 ms         1.11 ms          490  (13.57M)
FlashAttentionBenchmark/4096/16/2/1024        6.50 ms         1.17 ms          493  (24.22M)
FlashAttentionBenchmark/4096/16/2/4096        16.8 ms         1.38 ms          228  (84.33M)
FlashAttentionBenchmark/4096/16/2/8192        28.8 ms         1.58 ms          217 (164.50M)
FlashAttentionBenchmark/4096/16/2/32768        230 ms         6.36 ms           16 (644.60M)

On Nov 19, 2024, @ds-hwang changed the title from "Optimize TPU Flash Attention" to "Optimize TPU Flash Attention (400x speed-up on 32k long context)".