Optimize TPU Flash Attention (400x speed-up on 32k long context)

Use the splash attention lazy mask instead of a jnp mask, whose cost is O(T^2).

The jnp mask costs O(T^2) memory, which largely negates the benefit of
reducing HBM communication with flash attention. Switch to the splash attention
lazy mask, which generates the causal mask blocks lazily inside the kernel.
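
To make the contrast concrete, a small illustrative sketch (not code from this
commit); it uses the CausalMask class from
jax.experimental.pallas.ops.tpu.splash_attention, and the sequence length is
arbitrary:

# Sketch only: dense jnp mask vs. lazy splash-attention mask description.
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib

seq_len = 32 * 1024

# Dense mask: O(T^2) elements must live in HBM alongside Q/K/V.
dense_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))

# Lazy mask: only a description of the mask is stored; mask blocks are
# generated on the fly inside the kernel, and fully-masked blocks can be skipped.
lazy_mask = mask_lib.CausalMask(shape=(seq_len, seq_len))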

In addition, Pallas supports CPU simulation (interpret=True), so the same
Pallas kernel is used on CPU as well, which makes the code easier to debug.
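
A minimal toy example of interpret=True (an add kernel, not this repository's
attention kernel), showing the same kernel body running in the Pallas
interpreter on CPU:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Element-wise add; refs are read/written in full.
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # emulate the kernel on CPU instead of compiling for TPU
)(x, y)
assert (out == x + y).all()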

* Benchmark: on TPU v5p; rows are (model_dim/heads/kv_heads/seq_len).

NumpyMask (as-is)
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           1.71 ms         1.09 ms          592   (4.43M)
FlashAttentionBenchmark/2048/16/2/1024        4.44 ms         1.21 ms          483  (28.62M)
FlashAttentionBenchmark/4096/16/2/1024        8.61 ms         1.36 ms          302  (53.27M)
FlashAttentionBenchmark/4096/16/2/4096        3264 ms         1537 ms            1 (197.38M)
FlashAttentionBenchmark/4096/16/2/8192        7426 ms         5603 ms            1 (389.54M)
FlashAttentionBenchmark/4096/16/2/32768      94427 ms        92256 ms            1   (1.50G)

CausalMask (this PR): saves both memory and computation. At 32k context,
it delivers a ~400x speed-up and a ~3x HBM saving.
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           1.55 ms         1.01 ms          578   (3.43M)
FlashAttentionBenchmark/2048/16/2/1024        4.21 ms         1.11 ms          490  (13.57M)
FlashAttentionBenchmark/4096/16/2/1024        6.50 ms         1.17 ms          493  (24.22M)
FlashAttentionBenchmark/4096/16/2/4096        16.8 ms         1.38 ms          228  (84.33M)
FlashAttentionBenchmark/4096/16/2/8192        28.8 ms         1.58 ms          217 (164.50M)
FlashAttentionBenchmark/4096/16/2/32768        230 ms         6.36 ms           16 (644.60M)
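
Back-of-the-envelope arithmetic (mine, not a number reported by the benchmark)
for the dense-mask cost at the 32k setting:

# Cost of materializing a dense bool causal mask at T = 32768.
T = 32_768
mask_bytes = T * T           # jnp.bool_ stores one byte per element
print(mask_bytes / 2**30)    # -> 1.0 (GiB), paid before any attention math runs
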
ds-hwang committed Nov 19, 2024
1 parent 594313d commit 649fee3
Showing 6 changed files with 225 additions and 76 deletions.
24 changes: 19 additions & 5 deletions axlearn/common/flash_attention/layer.py
@@ -20,6 +20,7 @@
     make_segment_mask,
 )
 from axlearn.common.config import config_class
+from axlearn.common.flash_attention import tpu_attention
 from axlearn.common.flash_attention.utils import (
     MultiHeadAttentionImpl,
     flash_attention_implementation,
@@ -169,10 +170,6 @@ def _compute_attention(
         cfg = self.config
         backend = self._backend()
 
-        # Repeats key/value heads dim if necessary.
-        k_proj = self._repeat_kv_heads(k_proj)
-        v_proj = self._repeat_kv_heads(v_proj)
-
         batch, target_len, num_heads, _ = q_proj.shape
         _, source_len, _, _ = k_proj.shape
 
@@ -228,7 +225,18 @@
                     f"{k_proj.shape[1]} for correctly supported GPU flash attention usage."
                 )
 
-        if backend == "tpu":
+        if backend == "cpu" and not tpu_attention.check_tpu_splash_attention(
+            query=q_proj,
+            key=k_proj,
+            has_mask=bool(cfg.mask),
+            segment_ids=segment_ids,
+            has_bias=(attention_logit_biases is not None),
+        ):
+            backend = "xla"
+
+        if backend in ("tpu", "cpu"):
+            # Splash attention needs to know sliding_window_size.
+            mask_fn = cfg.mask
             assert q_proj.shape[1] % cfg.tpu_block_size == 0, (
                 f"Target seq len {q_proj.shape[1]} must be "
                 f"divisible by block size {cfg.tpu_block_size}."
@@ -263,6 +271,12 @@ def _compute_attention(
         q_proj = self.scale_query(q_proj)
         k_proj = self.scale_key(k_proj)
 
+        # TODO(dhwang2): splash attention supports GQA natively, so don't repeat with proper shard.
+        # https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934
+        # Repeats key/value heads dim if necessary.
+        k_proj = self._repeat_kv_heads(k_proj)
+        v_proj = self._repeat_kv_heads(v_proj)
+
         # Constrain input to conform to partitioned MHA expectations.
         q_proj = with_sharding_constraint(q_proj, cfg.mha_dim_to_partition_spec["btnh"])
         k_proj = with_sharding_constraint(k_proj, cfg.mha_dim_to_partition_spec["bsnh"])
10 changes: 9 additions & 1 deletion axlearn/common/flash_attention/layer_test.py
@@ -16,7 +16,7 @@
 import jax
 import jax.numpy as jnp
 import pytest
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.experimental import mesh_utils
 from jax.sharding import Mesh
 
@@ -91,6 +91,7 @@ def _prepare_layers(
     sliding_window_size,
     inference=False,
     set_layer_bias_recursively=False,
+    tpu_block_size=512,
 ):
     hidden_dim = num_heads * per_head_dim
     kwargs = dict(
@@ -110,6 +111,7 @@
         .set(
             mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names),
             output_dim_to_partition_spec=default_output_dim_to_partition_spec(mesh_axis_names),
+            tpu_block_size=tpu_block_size,
         )
     )
     if inference:
@@ -378,7 +380,9 @@ def test_forward(
                 mesh_axis_names=mesh_axis_names,
                 causal=causal,
                 sliding_window_size=sliding_window_size,
+                tpu_block_size=128,
             )
+
             # pylint: disable-next=protected-access
             if test_layer._backend() == "gpu" and query_len_multiplier != 1:
                 pytest.skip(
@@ -734,3 +738,7 @@ def test_extend_step(
                 atol=2e-2,
             )
         jax.clear_backends()
+
+
+if __name__ == "__main__":
+    absltest.main()
