diff --git a/tests/kernels/attention/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py
index 2dca720fe330..48aacac8376b 100644
--- a/tests/kernels/attention/test_triton_decode_attention.py
+++ b/tests/kernels/attention/test_triton_decode_attention.py
@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     # o will have the same shape as q
     o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
+    lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
+
     b_seq_len = torch.full((B, ), seq_len, device="cuda")
 
     attn_logits = torch.empty(
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
         k_buffer,
         v_buffer,
         o,
+        lse,
         req_to_token,
         b_seq_len,
         attn_logits,
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
 
     o1 = torch.zeros_like(o)
+    lse1 = torch.zeros_like(lse)
 
     decode_attention_fwd(
         q,
         k_buffer,
         v_buffer,
         o1,
+        lse1,
         req_to_page,
         b_seq_len,
         attn_logits,
diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py
index f82ce5b4d4b6..7f5a678615cf 100644
--- a/vllm/attention/ops/triton_decode_attention.py
+++ b/vllm/attention/ops/triton_decode_attention.py
@@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     o,
+    lse,
     B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
+    stride_lse_bs,
     NUM_KV_SPLITS: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
@@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
         acc / e_sum,
         mask=mask_d,
     )
+    lse_val = e_max + tl.log(e_sum)
+    tl.store(
+        lse + cur_batch * stride_lse_bs + cur_head,
+        lse_val,
+    )
 
 
 def _decode_softmax_reducev_fwd(
     logits,
     q,
     o,
+    lse,
     v_buffer,
     b_seq_len,
     num_kv_splits,
@@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        lse,
         b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
+        lse.stride(0),
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
@@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
         page_size,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                 num_kv_splits)
 
 
@@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
         page_size,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                 num_kv_splits)
 
 
@@ -633,6 +645,7 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -651,6 +664,7 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
+            lse,
             req_to_token,
             b_seq_len,
             attn_logits,
@@ -666,6 +680,7 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
+            lse,
             req_to_token,
             b_seq_len,
             attn_logits,
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 415d36c681d8..cb6bc7132ed0 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -688,7 +688,7 @@ def forward(
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
-            residual = hidden_states
+            residual = hidden_states.clone()
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index d692b00d78b4..dd272fa01925 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -32,6 +32,7 @@ def get_impl_cls() -> type["TritonMLAImpl"]:
 
 
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
         self,
@@ -139,19 +140,20 @@ def _forward_decode(
 
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
+        q_num_heads = q.shape[1]
         o = torch.zeros(B,
-                        self.num_heads,
+                        q_num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
                         device=q.device)
-
+        lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
         num_kv_splits = 4  # TODO: heuristic
 
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
             (
                 B,
-                self.num_heads,
+                q_num_heads,
                 num_kv_splits,
                 # NOTE(lucas) idk why the +1 is here but sglang has it so we
                 # just mirror that
@@ -167,9 +169,9 @@ def _forward_decode(
         PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
 
         # Run MQA
-        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
+        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse,
                              attn_metadata.decode.block_table,
                              attn_metadata.decode.seq_lens, attn_logits,
                              num_kv_splits, self.scale, PAGE_SIZE)
 
-        return o, None
+        return o, lse
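
For context on what the new lse output carries: _fwd_kernel_stage2 stores e_max + tl.log(e_sum) per (batch, head), i.e. the log-sum-exp of the scaled attention scores for that decode query. The PyTorch sketch below reproduces that quantity for a plain MHA layout with one query token per sequence; the tensor names, shapes, and the reference_decode_attention helper are illustrative assumptions, not the kernel's API.

import torch


def reference_decode_attention(q, k, v, sm_scale):
    # q: [B, H, D_qk] (one decode token per sequence), k: [B, S, H, D_qk],
    # v: [B, S, H, D_v]. Returns the attention output and, per (batch, head),
    # the log-sum-exp of the scaled scores -- the same quantity the Triton
    # kernel writes into `lse` as e_max + log(e_sum).
    scores = torch.einsum("bhd,bshd->bhs", q.float(), k.float()) * sm_scale
    lse = torch.logsumexp(scores, dim=-1)  # [B, H]
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhs,bshd->bhd", probs, v.float())
    return out, lse


if __name__ == "__main__":
    B, S, H, D_QK, D_V = 2, 16, 4, 64, 64
    q = torch.randn(B, H, D_QK)
    k = torch.randn(B, S, H, D_QK)
    v = torch.randn(B, S, H, D_V)
    out, lse = reference_decode_attention(q, k, v, sm_scale=D_QK**-0.5)
    print(out.shape, lse.shape)  # torch.Size([2, 4, 64]) torch.Size([2, 4])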