Optimize MQA computation. (#837)
The advantage of multi-query attention (MQA) lies in both reducing the size of
the KV cache and making self-attention computation more efficient. The current
implementation only saves on KV cache size.

This PR improves on that: it reduces the computation cost and also saves per-layer KV cache
memory.

This becomes especially critical with very long contexts. For instance, if an LLM processes a
1-million-token context with the Character.ai architecture [1], there may be only about 4 unique
KV cache layers (the remaining layers share those caches). Assume 4 KV heads and 32 total
attention heads, with a dim_per_head of 128. In the current implementation, each layer spends
significant memory on self-attention KV caching (in bfloat16; the sketch below spells out the
arithmetic):
* Current (ASIS): 8GB (128 * 32 * 2 bytes * 1M)
* Optimized (TODO): 1GB (128 * 4 * 2 bytes * 1M)

[1] https://research.character.ai/optimizing-inference/
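Spelling that arithmetic out, a back-of-the-envelope sketch (not part of this commit; the helper
name kv_cache_gb_per_layer is made up for illustration, and the single factor of 2 above is read
as the bfloat16 byte width):

def kv_cache_gb_per_layer(context_len, num_heads, per_head_dim, bytes_per_elem=2):
    # One cached [context_len, num_heads, per_head_dim] tensor in bfloat16 (2 bytes per element).
    return context_len * num_heads * per_head_dim * bytes_per_elem / 1e9

print(kv_cache_gb_per_layer(1_000_000, 32, 128))  # ~8.2 GB: caching all 32 query heads (as-is)
print(kv_cache_gb_per_layer(1_000_000, 4, 128))   # ~1.0 GB: caching only the 4 KV heads (optimized)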

* Benchmark results (tools/attention_benchmark.py on TPUv5p): the change saves both memory and
computation.

ASIS
-----------------------------------------------------------------------------------------
Benchmark                           Time             CPU   Iterations   HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024       1.42 ms        0.247 ms         2347           291.16M
MQABenchmark/4096/16/2/1024       3.60 ms        0.277 ms         1257           322.95M
MQABenchmark/4096/16/2/4096       47.3 ms        0.818 ms          139             4.25G
MQABenchmark/4096/16/2/8192        869 ms        0.932 ms          140            48.00G

This PR
-----------------------------------------------------------------------------------------
Benchmark                           Time             CPU   Iterations   HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024       1.16 ms        0.256 ms         2535           262.35M
MQABenchmark/4096/16/2/1024       3.46 ms        0.294 ms         1114           266.88M
MQABenchmark/4096/16/2/4096       24.8 ms        0.769 ms          137             4.04G
MQABenchmark/4096/16/2/8192        860 ms         1.19 ms          136            48.00G
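For reference, a minimal standalone sketch of the idea behind the change (illustrative only, not
part of this diff; shapes and names are made up, and the size-1 broadcast axis used in
attention.py is dropped for simplicity). Instead of repeating the KV heads up to num_heads before
the attention einsum, the query heads are folded into [group, kv_head] and contracted directly
against the un-repeated keys; the result matches an explicitly expanded-KV baseline laid out with
the same (g k) head ordering.

import einops
import jax.numpy as jnp
import numpy as np

b, t, s, n, kv_heads, h = 2, 4, 6, 8, 2, 16  # batch, target, source, query heads, KV heads, per_head_dim
g = n // kv_heads                            # query heads per KV group
rng = np.random.default_rng(0)
query = jnp.asarray(rng.normal(size=(b, t, n, h)), dtype=jnp.float32)
key = jnp.asarray(rng.normal(size=(b, s, kv_heads, h)), dtype=jnp.float32)

# Baseline: materialize one key copy per query head, laid out so that query head i reads
# KV head i % kv_heads (the same layout the (g k) split below produces).
key_expanded = jnp.tile(key, (1, 1, g, 1))                       # [b, s, n, h]
logits_ref = jnp.einsum("btnh,bsnh->bnts", query, key_expanded)  # [b, n, t, s]

# Optimized: group the query heads and contract against the un-repeated keys.
query_grouped = einops.rearrange(query, "b t (g k) h -> b t g k h", g=g, k=kv_heads)
logits_grouped = jnp.einsum("btgkh,bskh->bgkts", query_grouped, key)
logits_new = einops.rearrange(logits_grouped, "b g k t s -> b (g k) t s")

assert jnp.allclose(logits_ref, logits_new, atol=1e-5)

The value-side context computation is optimized the same way; the benchmark rows above show the
resulting HBM and latency savings.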
ds-hwang authored Nov 14, 2024
1 parent c4bbaa1 commit e080157
Showing 5 changed files with 79 additions and 66 deletions.
104 changes: 65 additions & 39 deletions axlearn/common/attention.py
@@ -57,6 +57,7 @@
from enum import Enum, unique
from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union

import einops
import jax
from jax import numpy as jnp
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
@@ -710,7 +711,7 @@ class Output(NamedTuple):

@property
def num_kv_heads(self):
raise NotImplementedError(type(self))
return self.config.num_heads

def init_states(
self,
@@ -920,10 +921,6 @@ def __init__(self, cfg: Config, *, parent: Module):
proj_cfg.per_head_dim = cfg.per_head_dim
self._add_child(f"{name}_proj", proj_cfg)

@property
def num_kv_heads(self):
return self.config.num_heads

def forward(
self,
query: Tensor,
@@ -994,10 +991,6 @@ def __init__(self, cfg: Config, *, parent: Module):
proj_cfg.per_head_dim = cfg.per_head_dim
self._add_child("q_proj", proj_cfg)

@property
def num_kv_heads(self):
raise NotImplementedError(type(self))

def forward(
self,
query: Tensor,
@@ -1046,10 +1039,6 @@ def __init__(self, cfg: Config, *, parent: Module):
proj_cfg.per_head_dim = cfg.per_head_dim
self._add_child("qkv_proj", proj_cfg)

@property
def num_kv_heads(self):
return self.config.num_heads

def create_parameter_specs_recursively(self) -> NestedParameterSpec:
specs = VDict(**super().create_parameter_specs_recursively())

@@ -1951,7 +1940,7 @@ def _compute_attention(
self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :])
probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases)
probs = self.dropout(probs)
context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype)
context = self._compute_context(probs, v_proj)
context = self._remat_name(context, "context")
return context, probs

@@ -2007,10 +1996,31 @@ def _cap_logits(self, logits: Tensor) -> Tensor:
return cap * jnp.tanh(logits / cap)

def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor:
    """Compute attention logits.

    Args:
        q_proj: query tensor, [batch, target_length, num_heads, per_head_dim].
        k_proj: key tensor, [batch, source_length, num_heads, per_head_dim].

    Returns:
        logits: [batch, num_heads, target_length, source_length].
    """
    q_proj = self.scale_query(q_proj)
    k_proj = self.scale_key(k_proj)
    return jnp.einsum("btnh,bsnh->bnts", q_proj, k_proj)

def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor:
    """Compute attention context.

    Args:
        probs: probs tensor, [batch, num_heads, target_length, source_length].
        v_proj: value tensor, [batch, source_length, num_heads, per_head_dim].

    Returns:
        context: [batch, target_length, num_heads, per_head_dim].
    """
    return jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype)

def init_states(
self,
*,
@@ -2173,31 +2183,47 @@ class GroupedQueryAttention(MultiheadAttention):
def num_kv_heads(self):
return self.i_proj.num_kv_heads

Removed (previous implementation):

def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor:
    """Repeats key or value heads dim to match the query."""
    num_head_repeats = self.config.num_heads // key_or_value.shape[-2]
    if num_head_repeats == 1:
        return key_or_value
    # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim].
    return jnp.repeat(key_or_value, num_head_repeats, axis=-2)

def _compute_attention(
    self,
    *,
    q_proj: Tensor,
    k_proj: Tensor,
    v_proj: Tensor,
    **kwargs,
) -> tuple[Tensor, Tensor]:
    """See `MultiheadAttention._compute_attention` for details."""
    k_proj = self._repeat_kv_heads(k_proj)
    v_proj = self._repeat_kv_heads(v_proj)
    return super()._compute_attention(
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        **kwargs,
    )

Added (this PR):

def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor:
    """Compute attention logits.

    Args:
        q_proj: query tensor, [batch, target_length, num_heads, per_head_dim].
        k_proj: key tensor, [batch, source_length, num_kv_heads, per_head_dim].

    Returns:
        logits: [batch, num_heads, target_length, source_length].
    """
    kv_heads = k_proj.shape[-2]
    num_head_group = self.config.num_heads // kv_heads
    if num_head_group == 1:
        return super()._compute_logits(q_proj=q_proj, k_proj=k_proj)

    q_proj = self.scale_query(q_proj)
    k_proj = self.scale_key(k_proj)
    q_proj = einops.rearrange(q_proj, "b t (g k) h -> b t g k h", g=num_head_group, k=kv_heads)
    k_proj = einops.rearrange(k_proj, "b s k h -> b s 1 k h")
    logits = jnp.einsum("btgkh,bs1kh->bgkts", q_proj, k_proj)
    return einops.rearrange(logits, "b g k t s -> b (g k) t s")

def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor:
    """Compute attention context.

    Args:
        probs: probs tensor, [batch, num_heads, target_length, source_length].
        v_proj: value tensor, [batch, source_length, num_kv_heads, per_head_dim].

    Returns:
        context: [batch, target_length, num_heads, per_head_dim].
    """
    kv_heads = v_proj.shape[-2]
    num_head_group = self.config.num_heads // kv_heads
    if num_head_group == 1:
        return super()._compute_context(probs=probs, v_proj=v_proj)

    probs = einops.rearrange(probs, "b (g k) t s -> b g k t s", g=num_head_group, k=kv_heads)
    v_proj = einops.rearrange(v_proj, "b s k h -> b s 1 k h")
    context = jnp.einsum("bgkts,bs1kh->btgkh", probs, v_proj)
    return einops.rearrange(context, "b t g k h -> b t (g k) h")


class SigmoidAttention(MultiheadAttention):
@@ -2248,7 +2274,7 @@ def _compute_attention(
)
probs = self.dropout(probs)

context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype)
context = self._compute_context(probs, v_proj)
context = self._remat_name(context, "context")
return context, probs

25 changes: 0 additions & 25 deletions axlearn/common/attention_test.py
@@ -2303,31 +2303,6 @@ def test_sliding_window(
# The outputs are equivalent.
self.assertNestedAllClose(outputs[0], outputs[1])

def test_gqa_kv_heads(self):
"""Checks that only the heads dim is repeated."""
batch = source_length = num_heads = 8
per_head_dim = 2
num_kv_heads = 4
dtype = jnp.float32
key_or_value = jnp.zeros((batch, source_length, num_kv_heads, per_head_dim), dtype=dtype)
model_dim = per_head_dim * num_heads
cfg = attention.GroupedQueryAttention.default_config().set(
query_dim=model_dim,
key_dim=model_dim,
value_dim=model_dim,
num_heads=num_heads,
input_linear=attention.FusedGroupedQKVLinear.default_config().set(
num_kv_heads=num_kv_heads
),
dtype=dtype,
)
test_layer = cfg.set(name="test").instantiate(parent=None)
# pylint: disable-next=protected-access
repeated_key_or_value = test_layer._repeat_kv_heads(key_or_value)
self.assertEqual(
repeated_key_or_value.shape, (batch, source_length, num_heads, per_head_dim)
)

@parameterized.product(
dtype=(jnp.float32, jnp.float16, jnp.bfloat16),
per_dim_scale=(None, PerDimScale.default_config()),
11 changes: 11 additions & 0 deletions axlearn/common/flash_attention/layer.py
@@ -146,6 +146,17 @@ def _logit_biases_spec(self, attention_logit_biases: Tensor) -> Tensor:
spec = PartitionSpec(spec[0], None, *spec[2:])
return spec

def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor:
    """Repeats key or value heads dim to match the query.

    TODO(dhwang2): optimize computation like GroupedQueryAttention.
    """
    num_head_repeats = self.config.num_heads // key_or_value.shape[-2]
    if num_head_repeats == 1:
        return key_or_value
    # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim].
    return jnp.repeat(key_or_value, num_head_repeats, axis=-2)

def _compute_attention(
self,
*,
2 changes: 1 addition & 1 deletion axlearn/vision/attention.py
@@ -229,7 +229,7 @@ def forward(
attention_logit_biases = attention_logit_biases[:, None, :, :]
probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases)
probs = self.dropout(probs)
context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype)
context = self._compute_context(probs, v_proj)
context = self._remat_name(context, "context")
self.vlog(3, "atten.prob=%s", probs[0, 0, 0, :])
self.vlog(3, "atten.context=%s", context.sum())
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
core = [
"absl-py==2.1.0",
"chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25.
"einops==0.8.0",
"importlab==0.7", # breaks pytype on 0.8
"jax==0.4.34",
"jaxlib==0.4.34",
@@ -53,10 +54,10 @@ apple-silicon = [
]
# Requirements for testing and development.
dev = [
"axlearn[core]", # core
"axlearn[audio]", # audio tests
"axlearn[orbax]", # checkpointer tests
"black==23.1a1", # formatting
"einops==0.8.0",
"evaluate",
"isort", # formatting
"pika==1.3.2", # used by event queue
