5 changes: 5 additions & 0 deletions tests/kernels/attention/test_triton_decode_attention.py
@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")

b_seq_len = torch.full((B, ), seq_len, device="cuda")

attn_logits = torch.empty(
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

o1 = torch.zeros_like(o)
lse1 = torch.zeros_like(lse)

decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
lse1,
req_to_page,
b_seq_len,
attn_logits,
19 changes: 17 additions & 2 deletions vllm/attention/ops/triton_decode_attention.py
@@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
def _fwd_kernel_stage2(
Mid_O,
o,
lse,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
stride_lse_bs,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
@@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
acc / e_sum,
mask=mask_d,
)
lse_val = e_max + tl.log(e_sum)
tl.store(
lse + cur_batch * stride_lse_bs + cur_head,
lse_val,
)


def _decode_softmax_reducev_fwd(
logits,
q,
o,
lse,
v_buffer,
b_seq_len,
num_kv_splits,
@@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
_fwd_kernel_stage2[grid](
logits,
o,
lse,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
lse.stride(0),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
@@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
page_size,
logit_cap,
)
- _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+ _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
num_kv_splits)


@@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
page_size,
logit_cap,
)
- _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+ _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
num_kv_splits)


@@ -633,6 +645,7 @@ def decode_attention_fwd(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -651,6 +664,7 @@ def decode_attention_fwd(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -666,6 +680,7 @@ def decode_attention_fwd(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
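For reference, the value written into lse by the stage-2 kernel is the per-(batch, head) log-sum-exp of that query's attention logits, in the usual numerically stable max-shifted form (e_max + tl.log(e_sum)). A minimal PyTorch sketch of the same quantity for a single 1-D row of logits x (an illustration, not code from this PR):

import torch

def logsumexp_stable(x: torch.Tensor) -> torch.Tensor:
    # Same stable formulation as the kernel: max + log(sum(exp(x - max))).
    m = x.max()
    return m + torch.log(torch.exp(x - m).sum())

# For a single row this matches torch.logsumexp(x, dim=-1).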
2 changes: 1 addition & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -688,7 +688,7 @@ def forward(
) -> torch.Tensor:
# Self Attention
if residual is None:
- residual = hidden_states
+ residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
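The .clone() above most likely guards against aliasing: with residual = hidden_states, both names share the same storage, so any later in-place update of hidden_states would silently change residual as well. A small illustration of that hazard (an assumption about the motivation, not code from this PR):

import torch

hidden_states = torch.ones(4)
residual = hidden_states          # alias: shares storage with hidden_states
hidden_states.mul_(2.0)           # in-place update
print(residual)                   # tensor([2., 2., 2., 2.]) -- residual changed too

residual = hidden_states.clone()  # independent copy
hidden_states.mul_(2.0)
print(residual)                   # still tensor([2., 2., 2., 2.])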
12 changes: 7 additions & 5 deletions vllm/v1/attention/backends/mla/triton_mla.py
@@ -32,6 +32,7 @@ def get_impl_cls() -> type["TritonMLAImpl"]:


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True

def __init__(
self,
@@ -139,19 +140,20 @@ def _forward_decode(

assert isinstance(q, torch.Tensor)
B = q.shape[0]
q_num_heads = q.shape[1]
o = torch.zeros(B,
- self.num_heads,
+ q_num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)

lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
num_kv_splits = 4 # TODO: heuristic

# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
- self.num_heads,
+ q_num_heads,
num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
Expand All @@ -167,9 +169,9 @@ def _forward_decode(
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

# Run MQA
- decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
+ decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse,
attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)

- return o, None
+ return o, lse
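Returning lse alongside o lets a caller merge partial attention results that were computed over disjoint key sets (for example, combining a decode result with attention over another chunk of context). A hedged sketch of that merge, assuming o1/o2 have shape (B, H, D) and lse1/lse2 have shape (B, H); the helper name is illustrative, not part of vLLM:

import torch

def merge_attn_outputs(o1, lse1, o2, lse2):
    # Combined log normalizer over both key sets.
    lse = torch.logaddexp(lse1, lse2)
    # Each partial output is reweighted by its share of the total softmax mass.
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse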