Commit 05c1948

[Kernel] Support DCP for Triton backend (#25132)
Signed-off-by: Wei Wei <wwei6@meta.com>
1 parent: 52d0cb8

File tree: 4 files changed, +30 −8 lines
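Why the kernel now returns a log-sum-exp (LSE): decode context parallel (DCP) splits a request's KV cache across ranks, so each rank produces only a partial attention output. Those partials can be combined exactly if each one carries its LSE. Below is a minimal sketch of that merge identity; the names are illustrative, not vLLM's actual merge helper:

```python
import torch

# Sketch: attention over a KV sequence split into chunks can be rebuilt
# exactly from per-chunk (output, lse) pairs. This is the identity that
# split-KV / context-parallel reductions rely on.
torch.manual_seed(0)
s = torch.randn(32)          # attention scores for one query/head
v = torch.randn(32, 8)       # values

# Reference: softmax-attention over the full sequence.
ref = torch.softmax(s, dim=0) @ v

# Split into two "ranks", each computing a local attention output + LSE.
outs, lses = [], []
for s_i, v_i in [(s[:16], v[:16]), (s[16:], v[16:])]:
    outs.append(torch.softmax(s_i, dim=0) @ v_i)
    lses.append(torch.logsumexp(s_i, dim=0))

# Merge: weight each partial output by exp(lse_i - lse_total).
lse_total = torch.logsumexp(torch.stack(lses), dim=0)
merged = sum(torch.exp(l - lse_total) * o for l, o in zip(lses, outs))
assert torch.allclose(merged, ref, atol=1e-5)
```

This is why `_forward_decode` in the Triton MLA backend below now returns `(o, lse)` instead of `(o, None)`.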

tests/kernels/attention/test_triton_decode_attention.py (5 additions, 0 deletions)

```diff
@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     # o will have the same shape as q
     o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
+    lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
+
     b_seq_len = torch.full((B, ), seq_len, device="cuda")
 
     attn_logits = torch.empty(
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
         k_buffer,
         v_buffer,
         o,
+        lse,
         req_to_token,
         b_seq_len,
         attn_logits,
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
 
     o1 = torch.zeros_like(o)
+    lse1 = torch.zeros_like(lse)
 
     decode_attention_fwd(
         q,
         k_buffer,
         v_buffer,
         o1,
+        lse1,
         req_to_page,
         b_seq_len,
         attn_logits,
```
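For context, the `lse` the test now allocates should agree with a plain logsumexp over the attention logits. A hedged reference computation with illustrative shapes, assuming the kernel's LSE is taken over the already-scaled logits:

```python
import torch

# Illustrative reference only; these are not the test's actual buffers.
B, H_Q, L, D = 2, 4, 32, 64
scale = D ** -0.5
q = torch.randn(B, H_Q, D)
k = torch.randn(B, L, D)   # one KV head broadcast over H_Q for simplicity

logits = torch.einsum("bhd,bld->bhl", q, k) * scale
lse_ref = torch.logsumexp(logits, dim=-1)   # shape (B, H_Q), matches lse
```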

vllm/attention/ops/triton_decode_attention.py (17 additions, 2 deletions)

```diff
@@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     o,
+    lse,
     B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
+    stride_lse_bs,
     NUM_KV_SPLITS: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
@@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
         acc / e_sum,
         mask=mask_d,
     )
+    lse_val = e_max + tl.log(e_sum)
+    tl.store(
+        lse + cur_batch * stride_lse_bs + cur_head,
+        lse_val,
+    )
 
 
 def _decode_softmax_reducev_fwd(
     logits,
     q,
     o,
+    lse,
     v_buffer,
     b_seq_len,
     num_kv_splits,
@@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        lse,
         b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
+        lse.stride(0),
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
@@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
         page_size,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                 num_kv_splits)
@@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
         page_size,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                 num_kv_splits)
@@ -633,6 +645,7 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -651,6 +664,7 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
+            lse,
             req_to_token,
             b_seq_len,
             attn_logits,
@@ -666,6 +680,7 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
+            lse,
             req_to_token,
             b_seq_len,
             attn_logits,
```
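The new store in `_fwd_kernel_stage2` uses the standard numerically stable form of logsumexp: with a running maximum `e_max` and the rescaled sum `e_sum = Σ exp(x − e_max)`, we have `logsumexp(x) = e_max + log(e_sum)`, which avoids overflow in `exp()` for large logits. A quick sketch of the identity in plain PyTorch:

```python
import torch

# Naive exp(x) would overflow float32 here; the max-shifted form does not.
x = torch.tensor([1000.0, 1001.0, 999.0])
e_max = x.max()
e_sum = torch.exp(x - e_max).sum()
lse_val = e_max + torch.log(e_sum)   # mirrors the kernel's computation
assert torch.allclose(lse_val, torch.logsumexp(x, dim=0))
```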

vllm/model_executor/models/deepseek_v2.py (1 addition, 1 deletion)

```diff
@@ -685,7 +685,7 @@ def forward(
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
-            residual = hidden_states
+            residual = hidden_states.clone()
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
```
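The `.clone()` matters because downstream ops may update `hidden_states` in place; if `residual` merely aliases the same storage, the saved residual is silently corrupted. A minimal sketch of the hazard (the in-place op here is illustrative, not the actual vLLM layernorm):

```python
import torch

hidden_states = torch.ones(4)
residual = hidden_states            # alias: shares storage
hidden_states.mul_(2.0)             # in-place update
print(residual)                     # tensor([2., 2., 2., 2.]) -- corrupted

hidden_states = torch.ones(4)
residual = hidden_states.clone()    # independent copy
hidden_states.mul_(2.0)
print(residual)                     # tensor([1., 1., 1., 1.]) -- preserved
```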

vllm/v1/attention/backends/mla/triton_mla.py (7 additions, 5 deletions)

```diff
@@ -32,6 +32,7 @@ def get_impl_cls() -> type["TritonMLAImpl"]:
 
 
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
         self,
@@ -139,19 +140,20 @@ def _forward_decode(
 
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
+        q_num_heads = q.shape[1]
         o = torch.zeros(B,
-                        self.num_heads,
+                        q_num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
                         device=q.device)
-
+        lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
         num_kv_splits = 4  # TODO: heuristic
 
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
             (
                 B,
-                self.num_heads,
+                q_num_heads,
                 num_kv_splits,
                 # NOTE(lucas) idk why the +1 is here but sglang has it so we
                 # just mirror that
@@ -167,9 +169,9 @@ def _forward_decode(
         PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
 
         # Run MQA
-        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
+        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse,
                              attn_metadata.decode.block_table,
                              attn_metadata.decode.seq_lens, attn_logits,
                              num_kv_splits, self.scale, PAGE_SIZE)
 
-        return o, None
+        return o, lse
```
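Setting `can_return_lse_for_decode = True` advertises to the common MLA path that this backend can hand back the per-head LSE needed for the DCP merge sketched at the top of this commit. Sizing `o`, `lse`, and `attn_logits` from `q.shape[1]` rather than `self.num_heads` keeps the buffers consistent when the number of query heads the kernel actually sees at runtime differs from the statically configured count, as it can under DCP.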
