[Executorch][llm] Add ring buffer based kv cache and mask calculation to MHA #10609

Open
wants to merge 6 commits into base: gh/kimishpatel/186/base
93 changes: 77 additions & 16 deletions examples/models/llama/attention.py
@@ -123,14 +123,12 @@ def __init__(
head_dim: int,
n_rep: int,
max_context_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
self.max_context_len = max_context_len
self.enable_dynamic_shape = enable_dynamic_shape

def forward(
self,
@@ -142,21 +140,12 @@ def forward(
seqlen,
mask: torch.Tensor,
) -> torch.Tensor:
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = mask.narrow(0, start_pos, seq_length)
else:
attn_mask = mask[None, None, input_pos]

# TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
# can natively support GQA now. But needs enable_gqa=True
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -236,21 +225,79 @@ def __init__(
enable_dynamic_shape: bool,
dtype=torch.float32,
):
self.window_size = max_context_length
"""
Reason why we want the kv cache size to be twice the context length:
Sliding window attention without ringbuffer
pos 0 1 2 3 4 5 6 7 8 9 10
0 x 0 0 0 0 0 0 0 0 0 0
1 x x 0 0 0 0 0 0 0 0 0
2 x x x 0 0 0 0 0 0 0 0
3 x x x x 0 0 0 0 0 0 0
4 0 x x x x 0 0 0 0 0 0
5 0 0 x x x x 0 0 0 0 0
6 0 0 0 x x x x 0 0 0 0
7 0 0 0 0 x x x x 0 0 0
8 0 0 0 0 0 x x x x 0 0
9 0 0 0 0 0 0 x x x x 0
10 0 0 0 0 0 0 0 x x x x

So when doing attention for pos = 5 and seq_len = 4 our attention
mask would be
5 0 0 x x x x 0 0 0 0 0
6 0 0 0 x x x x 0 0 0 0
7 0 0 0 0 x x x x 0 0 0
8 0 0 0 0 0 x x x x 0 0
Thus tok at pos = 5 is able to attend to tokens at pos 2, 3 and 4.
This is how training is done.

Now lets consider ring kv cache of size 4. When we are at pos = 5
before updating the kv cache, state of the kv cache would be
[4 1 2 3]. That is we evicted token at pos = 0 out. Now during
attention calculation at pos = 5 seq len = 4, we will update cache and
new pos in the cache would be [8 5 6 7]. So note that 5 can now only attend
to itself. Not 2, 3 and 4 as you would have during training.
So not having kept 2, 3 and 4 in cache means we will have divergent behavior.
Worst case of this would have been when update it equal to the length of
the cache. like in our case pos = 5 seq len = 4.
Thus we need to have a cache that is larger. How much larger, as much as
the sliding window size. So twice the max_context_length.
How would that have helped. Lets see. At pos = 5 our cache would have
[0, 1, 2, 3, 4, NA, NA, NA] After cache update we would have
[8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the
current step still has access to [pos - sliding_window_size, pos] tokens.

To make sure we dont over attend, i.e. we dont have pos = 5
to attend to pos = 1, mask calculaton has to account for the sliding window
size.
"""
super().__init__(
max_batch_size,
max_context_length,
max_context_length * 2,
n_heads,
head_dim,
enable_dynamic_shape,
dtype,
)
self.cache_positions_manager = CachePositionsManager(max_context_length)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
cache_positions = self.cache_positions_manager.cache_positions
delta = pos_q - cache_positions
attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712
return attn_mask

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D]
seq_len = k_val.size(2)
assert seq_len <= self.k_cache.size(
2
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
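To make the sizing and masking argument in the docstring above concrete, here is a small standalone sketch (not the ExecuTorch implementation; window, cache_size, update_positions and ring_buffer_mask are illustrative names). It tracks which original token position each ring-buffer slot holds and applies the same rule as create_causal_mask_for_ring_buffer: a query at pos_q may attend to a slot holding pos_k iff 0 <= pos_q - pos_k < window.

import torch

window = 4                  # sliding window size (max_context_length)
cache_size = 2 * window     # ring buffer holds twice the window

# cache_positions[i] = original token position stored in slot i (-1 = empty)
cache_positions = torch.full((cache_size,), -1, dtype=torch.long)

def update_positions(input_pos: int, seq_len: int) -> None:
    # Record which original positions land in which ring-buffer slots.
    pos = input_pos + torch.arange(seq_len, dtype=torch.long)
    slots = pos % cache_size
    cache_positions[slots] = pos

def ring_buffer_mask(start_pos: int, seq_len: int) -> torch.Tensor:
    # 0 where attention is allowed, -inf where it is blocked.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    valid = (cache_positions >= 0) & (delta >= 0) & (delta < window)
    return torch.where(valid, 0.0, float("-inf"))

# Walk the docstring's scenario: prefill positions 0..4, then a 4-token
# update starting at pos = 5.
update_positions(0, 5)        # slots hold [0, 1, 2, 3, 4, -1, -1, -1]
update_positions(5, 4)        # slots hold [8, 1, 2, 3, 4, 5, 6, 7]
print(cache_positions.tolist())
print(ring_buffer_mask(5, 4))

Running this reproduces the docstring's example: after the update at pos = 5 the slots hold [8, 1, 2, 3, 4, 5, 6, 7], and the mask row for pos = 5 keeps only the slots holding positions 2, 3, 4 and 5.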
@@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.attention_qkv_bias = args.attention_qkv_bias
self.use_qk_norm = args.use_qk_norm
self.qk_norm_before_rope = args.qk_norm_before_rope
self.enable_dynamic_shape = args.enable_dynamic_shape

if self.use_qk_norm:
q_norm_dim = self.head_dim
@@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
head_dim=self.head_dim,
n_rep=self.n_rep,
max_context_len=self.max_context_len,
enable_dynamic_shape=args.enable_dynamic_shape,
)

def forward(
@@ -368,8 +415,22 @@ def forward(

if self.use_kv_cache:
assert input_pos is not None
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = self.mask.narrow(0, start_pos, seq_length)
else:
# mask is always 2D
attn_mask = self.mask[input_pos]
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
if getattr(self.kv_cache, "is_ring_buffer", False):
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
input_pos[0].item(), seqlen
)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
return self.wo(output), None

# grouped multiquery attention: expand out keys and values
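For context on the mask handling above, here is a minimal sketch, assuming the usual triu-based causal mask construction (full_mask is an illustrative name, not taken from the PR). It shows that a 2D mask of shape [max_context_len, max_context_len] yields the same [seq_len, max_context_len] slice whether it is taken with plain row indexing (static-shape path) or with narrow (dynamic-shape path, which stays friendly to export-time symbolic shapes).

import torch

max_context_len = 8

# Full causal mask: 0 where attention is allowed, -inf strictly above the diagonal.
full_mask = torch.full((max_context_len, max_context_len), float("-inf"))
full_mask = torch.triu(full_mask, diagonal=1)

seq_len = 3
start_pos = 2
input_pos = torch.arange(start_pos, start_pos + seq_len)

# Static-shape path: row indexing keeps the mask 2D, [seq_len, max_context_len].
static_slice = full_mask[input_pos]

# Dynamic-shape path: narrow() over dim 0 selects the same rows.
dynamic_slice = full_mask.narrow(0, start_pos, seq_len)

assert torch.equal(static_slice, dynamic_slice)
print(static_slice)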
81 changes: 45 additions & 36 deletions examples/models/llama/source_transformation/sdpa.py
@@ -22,15 +22,11 @@ class SDPACustom(torch.nn.Module):
def __init__(
self,
dim: int,
max_context_len,
enable_dynamic_shape,
use_attention_mask: bool = False,
):
super().__init__()
self.dim = dim
self.max_context_len = max_context_len
self.use_attention_mask = use_attention_mask
self.enable_dynamic_shape = enable_dynamic_shape

def forward(
self,
@@ -42,16 +38,6 @@ def forward(
seqlen,
mask,
):
if self.use_attention_mask:
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
mask = mask.narrow(0, start_pos, seq_length)
else:
mask = mask[input_pos]

q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
@@ -96,8 +82,6 @@ def _replace_sdpa_with_custom_op(
name,
SDPACustom(
child.dim,
child.max_context_len,
child.enable_dynamic_shape,
use_attention_mask=use_attention_mask,
),
)
@@ -133,12 +117,15 @@ class QuantizedSDPA(torch.nn.Module):
zero points, we need to pass kv_cache to SDPA.
"""

def __init__(self, dim: int, kv_cache: QuantizedKVCache):
def __init__(
self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False
):
super().__init__()
self.dim = dim
self.quantized_dtype = torch.int8
self.float_dtype = torch.float32
self.kv_cache = kv_cache
self.use_attention_mask = use_attention_mask

def forward(
self,
@@ -176,22 +163,40 @@ def forward(
v_scale_fp32 = self.kv_cache.v_cache_scales

start_pos = input_pos[0].item()
output = torch.ops.llama.custom_quantized_sdpa(
q_quantized,
k_quantized,
v_quantized,
start_pos,
None,
0,
True,
None,
q_zero_point_int8,
q_scale_fp32,
k_zero_point_int8,
k_scale_fp32,
v_zero_point_int8,
v_scale_fp32,
)
if self.use_attention_mask:
output = torch.ops.llama.custom_quantized_sdpa(
q_quantized,
k_quantized,
v_quantized,
start_pos,
mask,
0,
False,
None,
q_zero_point_int8,
q_scale_fp32,
k_zero_point_int8,
k_scale_fp32,
v_zero_point_int8,
v_scale_fp32,
)
else:
output = torch.ops.llama.custom_quantized_sdpa(
q_quantized,
k_quantized,
v_quantized,
start_pos,
None,
0,
True,
None,
q_zero_point_int8,
q_scale_fp32,
k_zero_point_int8,
k_scale_fp32,
v_zero_point_int8,
v_scale_fp32,
)

return output.view(bsz, seqlen, self.dim)

@@ -201,6 +206,7 @@ def _update_attention_module_with_quantized_sdpa(
):
sdpa = getattr(module, "SDPA", None)
assert sdpa is not None
# TODO: add support for SDPA with attention mask
# pyre-ignore
setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010

@@ -254,7 +260,8 @@ def forward(
seqlen,
mask,
):
attn_mask = mask[None, None, input_pos]
# Input mask is already sliced, but it is 2D
attn_mask = mask[None, None]

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
@@ -310,7 +317,8 @@ def forward(
"""
k = repeat_kv(k, self.n_rep)
v = repeat_kv(v, self.n_rep)
attn_mask = mask[input_pos]
# Mask is already sliced as needed
attn_mask = mask

scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = q @ k.transpose(-2, -1) * scale_factor
@@ -391,7 +399,8 @@ def forward(
seqlen,
mask,
):
attn_mask = mask[None, None, input_pos]
# Input mask is already sliced, but it is 2D
attn_mask = mask[None, None]

if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=1)
@@ -31,7 +31,7 @@ def __init__(
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len, enable_dynamic_shape)
self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len)
self.kv_cache = None

def forward(self, x, freqs_cos, freqs_sin, **kwargs):
@@ -159,15 +159,9 @@ def test_forward_functionality(self):
k_quantized, v_quantized = model.attention.kv_cache.update(input_pos, k, v)

# Run the forward pass with the quantized SDPA
try:
output = model.attention.SDPA(
input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
)
output = model.attention.SDPA(
input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
)

# Verify the output shape
self.assertEqual(output.shape, (bsz, seqlen, self.dim))
except Exception:
# If the forward pass fails, it might be due to missing custom ops
self.skipTest(
"Custom ops not available, skipping forward functionality test"
)
# Verify the output shape
self.assertEqual(output.shape, (bsz, seqlen, self.dim))
@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
self.seq_len = 3
self._init_cache()
q, k_val, v_val = self._init_kv()
self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
self.float_sdpa = SDPACustom(self.dim)
self.quantized_sdpa = SDPACustom(self.dim)
k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)
11 changes: 11 additions & 0 deletions examples/models/llama/tests/TARGETS
@@ -49,3 +49,14 @@ python_unittest(
"//executorch/examples/models/llama:llama_transformer",
],
)

python_unittest(
name = "test_ring_attention",
srcs = [
"test_ring_attention.py",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:llama_transformer",
],
)