    is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
+from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (
    flash_attn_supports_fp8,
)

from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

+    # For GQA DCP
+    max_dcp_context_kv_len: int | None = None
+    dcp_context_kv_lens: torch.Tensor | None = None
+
    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
@@ -216,6 +222,16 @@ def __init__(
        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.aot_schedule = get_flash_attn_version() == 3

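+        # Decode context parallel (DCP) topology, used to shard the context
+        # KV cache across ranks.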
+        try:
+            from vllm.distributed.parallel_state import get_dcp_group
+
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+
        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
@@ -306,7 +322,7 @@ def schedule(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
-                    num_heads_q=self.num_heads_q,
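+                    # Query heads are all-gathered across DCP ranks before the
+                    # context pass, so the schedule must account for all of them.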
+                    num_heads_q=self.num_heads_q * self.dcp_world_size,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
@@ -320,8 +336,35 @@ def schedule(
            return None

        use_cascade = common_prefix_len > 0
+        max_dcp_context_kv_len = 0
+        dcp_context_kv_lens = None
+
+        cu_prefix_query_lens = None
+        prefix_kv_lens = None
+        suffix_kv_lens = None
+        prefix_scheduler_metadata = None
+
+        if self.dcp_world_size > 1:
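+            # Per-request context length (seq_len minus this step's query
+            # tokens), reduced to the number of context KV entries held
+            # locally by this DCP rank.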
+            query_kv_lens_cpu = (
+                common_attn_metadata.query_start_loc_cpu[1:]
+                - common_attn_metadata.query_start_loc_cpu[:-1]
+            )
+            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
+                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+            )
+            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
+            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()

-        if use_cascade:
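+            # AOT schedule for the context pass over this rank's local KV;
+            # non-causal because every query token attends to all context KV.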
+            scheduler_metadata = schedule(
+                batch_size=num_reqs,
+                cu_query_lens=query_start_loc,
+                max_query_len=max_query_len,
+                seqlens=dcp_context_kv_lens,
+                max_seq_len=max_dcp_context_kv_len,
+                causal=False,
+            )
+        elif use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
@@ -348,10 +391,6 @@ def schedule(
                causal=True,
            )
        else:
-            cu_prefix_query_lens = None
-            prefix_kv_lens = None
-            suffix_kv_lens = None
-            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
@@ -379,6 +418,8 @@ def schedule(
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
+            max_dcp_context_kv_len=max_dcp_context_kv_len,
+            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
@@ -396,6 +437,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:


class FlashAttentionImpl(AttentionImpl):
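+    # The softmax LSE can be returned for decode, which DCP needs in order
+    # to merge partial attention outputs across ranks.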
+    can_return_lse_for_decode: bool = True
+
    def __init__(
        self,
        num_heads: int,
@@ -562,30 +605,45 @@ def forward(

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

-            flash_attn_varlen_func(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=attn_metadata.causal,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                scheduler_metadata=scheduler_metadata,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-                num_splits=attn_metadata.max_num_splits,
-                s_aux=self.sinks,
-            )
-            return output
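+            # With DCP the context KV cache is sharded across ranks, so
+            # attention is computed in a dedicated path that merges the
+            # per-rank partial results.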
+            if self.dcp_world_size > 1:
+                self._forward_with_dcp(
+                    query[:num_actual_tokens],
+                    key[:num_actual_tokens],
+                    value[:num_actual_tokens],
+                    key_cache,
+                    value_cache,
+                    output[:num_actual_tokens],
+                    attn_metadata,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return output
+            else:
+                flash_attn_varlen_func(
+                    q=query[:num_actual_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_actual_tokens],
+                    cu_seqlens_q=cu_seqlens_q,
+                    max_seqlen_q=max_seqlen_q,
+                    seqused_k=seqused_k,
+                    max_seqlen_k=max_seqlen_k,
+                    softmax_scale=self.scale,
+                    causal=attn_metadata.causal,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=block_table,
+                    softcap=self.logits_soft_cap,
+                    scheduler_metadata=scheduler_metadata,
+                    fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                    num_splits=attn_metadata.max_num_splits,
+                    s_aux=self.sinks,
+                )
+                return output

        # Cascade attention (rare case).
        cascade_attention(
@@ -615,6 +673,86 @@ def forward(
        )
        return output

+    def _forward_with_dcp(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        q_descale: torch.Tensor | None = None,
+        k_descale: torch.Tensor | None = None,
+        v_descale: torch.Tensor | None = None,
+    ) -> None:
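+        """Forward pass when decode context parallelism (DCP) is enabled.
+
+        Query heads are all-gathered across the DCP group and attended
+        against this rank's shard of the context KV cache; the partial
+        outputs and LSEs are then corrected and redistributed with
+        cp_lse_ag_out_rs. A second pass handles the current query tokens'
+        own keys/values, and merge_attn_states combines the two results
+        into the output tensor.
+        """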
+        cu_seqlens_q = attn_metadata.query_start_loc
+        max_seqlen_q = attn_metadata.max_query_len
+        block_table = attn_metadata.block_table
+
+        query = query.contiguous()
+        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
+        context_attn_out, context_lse = flash_attn_varlen_func(
+            q=query_across_dcp,
+            k=key_cache,
+            v=value_cache,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=attn_metadata.dcp_context_kv_lens,
+            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
+            softmax_scale=self.scale,
+            causal=False,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            scheduler_metadata=attn_metadata.scheduler_metadata,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        # FA returns LSE in shape [H, B] but cp_lse_ag_out_rs wants [B, H]
+        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
+            context_attn_out,
+            context_lse.transpose(0, 1),
+            get_dcp_group(),
+            return_lse=True,
+        )
+        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
+
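+        # Attention of the query tokens over this step's own keys/values
+        # (no KV cache), computed locally on this rank.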
+        query_attn_out, query_lse = flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_k=max_seqlen_q,
+            softmax_scale=self.scale,
+            causal=attn_metadata.causal,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        assert context_attn_out_cor.shape == query_attn_out.shape
+        assert context_lse_cor.shape == query_lse.shape
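+        # Merge the context-shard partial result with the local query-token
+        # result into the final output.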
+        merge_attn_states(
+            output,
+            context_attn_out_cor,
+            context_lse_cor,
+            query_attn_out,
+            query_lse,
+        )
+
    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
@@ -684,6 +822,7 @@ def use_cascade_attention(
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
+    dcp_world_size: int,
) -> bool:
    """Decide whether to use cascade attention.

@@ -705,6 +844,9 @@ def use_cascade_attention(
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False
+    # disable cascade attention for DCP
+    if dcp_world_size > 1:
+        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention