                                               AttentionType)
 from vllm.attention.ops.common import cp_lse_ag_out_ar
 from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.logger import init_logger
 from vllm.distributed.parallel_state import get_cp_group
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -239,7 +239,7 @@ class FlashInferMetadata:
     paged_kv_indptr_gpu: Optional[torch.Tensor] = None
 
     # For context parallel
-    cp_kv_recover_idx: Optional[torch.Tensor] = None
+    cp_allgather_restore_idx: Optional[torch.Tensor] = None
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -262,9 +262,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                                      self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
-        # NOTE(qcs): Context Parallel do not support graph mode now
         self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL and self.cp_world_size == 1)
+            decode_mode() == CUDAGraphMode.FULL)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -552,7 +551,7 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            cp_kv_recover_idx=common_attn_metadata.cp_kv_recover_idx,
+            cp_allgather_restore_idx=common_attn_metadata.cp_allgather_restore_idx,
         )
 
         qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
@@ -599,38 +598,30 @@ def build(self,
                 qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                     prefill_start]
                 paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
-                prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[prefill_start:]
+                prefill_num_computed_tokens_cpu = \
+                    num_computed_tokens_cpu[prefill_start:]
                 if not attn_metadata.prefill_use_trtllm:
                     if self.cp_world_size > 1:
-                        # NOTE(qcs): no chunked prefill and prefix caching
+                        assert common_attn_metadata.query_positions is not None
                         kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size
                         # init custom mask for head-tail query order
-                        mask_arr = []
-                        q_pos = common_attn_metadata.query_positions
-                        for i in range(num_prefills):
-                            # |----<C>-----|-<Q0>-|-<Q1>-|
-                            # |---<C+Q*cp_world_size>----|
-                            # cp_world_size = 2
-                            # Q = 2
-                            # C = 8
-                            # cur_q_pos = [0,3]
-                            # context_mask_i.shape = (2, 8)
-                            # upper = [0,1,2,3]
-                            # local_mask_i = [[True, False, False, False],
-                            #                 [True, True, True, True]] # size=(2, 4)
-                            # mask_i.shape = (2, 12)
-                            cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i]:qo_indptr_cpu[i+1]])
-                            Q = len(cur_q_pos)
-                            C = prefill_num_computed_tokens_cpu[i]
-                            if Q <= 0:
-                                mask_arr.append(torch.zeros(0, dtype=torch.bool))
-                                continue
-                            context_mask_i = torch.ones((Q, C), dtype=torch.bool)
-                            upper = torch.arange(Q * self.cp_world_size)
-                            local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1))
-                            mask_i = torch.cat([context_mask_i, local_mask_i], dim=1)
-                            mask_arr.append(mask_i.flatten())
-                        custom_mask = torch.cat(mask_arr, dim=0).to(self.device)
+                        q_pos = torch.from_numpy(
+                            common_attn_metadata.query_positions[
+                                prefill_start:]).long()
+                        kv_lens = prefill_num_computed_tokens_cpu + \
+                            kv_indptr_cpu[1:] - kv_indptr_cpu[:-1]
+                        max_q_lens = int(q_pos.max().item()) + 1
+                        max_kv_lens = int(kv_lens.max().item())
+                        mask = torch.ones(max_q_lens, max_kv_lens,
+                                          dtype=torch.bool).tril()
+                        selected_rows = torch.index_select(mask, 0, q_pos)
+                        col_indices = torch.arange(max_kv_lens).expand(q_pos.size(0), -1)
+                        valid_mask = col_indices < torch.repeat_interleave(
+                                                       kv_lens,
+                                                       qo_indptr_cpu[1:] - \
+                                                           qo_indptr_cpu[:-1]
+                                                   ).unsqueeze(1)
+                        custom_mask = selected_rows[valid_mask].to(self.device)
 
                         attn_metadata.prefill_wrapper.plan(
                             qo_indptr_cpu.to(self.device),
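
To make the vectorized mask construction in the hunk above easier to follow, here is a small self-contained sketch (illustrative only, not code from this PR). It assumes `query_positions` gives each prefill token's position inside its request after the head-tail reordering used for context parallelism (each rank keeps a slice from the front and a slice from the back of the query so causal work stays balanced), and that the flattened boolean layout, one causal row per query token clipped to that request's KV length, is what the FlashInfer `custom_mask` argument expects.

```python
import torch

# Toy setup: two prefill requests.
#   request 0: queries at positions 2 and 3, kv length 4 (2 context + 2 query)
#   request 1: query at position 1,          kv length 2 (1 context + 1 query)
q_pos = torch.tensor([2, 3, 1])
kv_lens = torch.tensor([4, 2])
q_lens = torch.tensor([2, 1])            # queries per request (qo_indptr diffs)

max_q = int(q_pos.max()) + 1             # 4
max_kv = int(kv_lens.max())              # 4
causal = torch.ones(max_q, max_kv, dtype=torch.bool).tril()

# One causal row per query token, then clip each row to its request's kv_len.
rows = torch.index_select(causal, 0, q_pos)
cols = torch.arange(max_kv).expand(q_pos.size(0), -1)
valid = cols < torch.repeat_interleave(kv_lens, q_lens).unsqueeze(1)
custom_mask = rows[valid]
# -> [T, T, T, F,  T, T, T, T,  T, T]
#    request 0: rows for positions 2 and 3; request 1: row for position 1
```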
@@ -874,6 +865,28 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+
+        key_across_cp = get_cp_group().all_gather(
+            key.contiguous(), dim=0)
+        value_across_cp = get_cp_group().all_gather(
+            value.contiguous(), dim=0)
+        if (self.cp_world_size > 1
+                and attn_metadata.cp_allgather_restore_idx is not None):
+            # Reorder kv after the cp all-gather. Note that decode
+            # tokens are duplicated across ranks, but only the first
+            # copy is written to the kv cache.
+            key_across_cp = torch.index_select(
+                key_across_cp, 0,
+                attn_metadata.cp_allgather_restore_idx
+            )
+            value_across_cp = torch.index_select(
+                value_across_cp, 0,
+                attn_metadata.cp_allgather_restore_idx
+            )
+        key = key_across_cp
+        value = value_across_cp
 
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
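
For readers unfamiliar with the gather-then-reorder pattern: `all_gather(dim=0)` returns the ranks' shards back to back (rank-major), so a precomputed permutation is needed to put tokens back into the order the rest of the layer expects. Below is a minimal sketch of such an index for a plain round-robin token split; the real `cp_allgather_restore_idx` is produced in the metadata-building path and additionally accounts for the head-tail prefill layout and the duplicated decode tokens, so treat this purely as an illustration.

```python
import torch

def build_restore_idx(num_tokens: int, cp_world_size: int) -> torch.Tensor:
    """Invert a rank-major all_gather for a round-robin token split.

    Assumes rank r originally held tokens [r, r + W, r + 2W, ...] (W = world
    size) and num_tokens is divisible by W. Illustrative only.
    """
    per_rank = num_tokens // cp_world_size
    gathered_pos = torch.arange(num_tokens)        # row index in gathered tensor
    rank = gathered_pos // per_rank                # shard each row came from
    local = gathered_pos % per_rank                # offset inside that shard
    original_pos = local * cp_world_size + rank    # where that token belongs
    restore_idx = torch.empty(num_tokens, dtype=torch.long)
    restore_idx[original_pos] = gathered_pos       # invert the permutation
    return restore_idx

# Usage mirrors the diff above:
#   restored = torch.index_select(gathered, 0, build_restore_idx(n, world_size))
```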
@@ -883,17 +896,16 @@ def forward(
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
-            if self.cp_world_size == 1:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    kv_cache[:, 0],
-                    kv_cache[:, 1],
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
             # to process the cache when the kv_cache_dtype is fp8
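
The cache write above now runs unconditionally because `key`/`value` are already the gathered and reordered full-length tensors. As the surrounding comment notes, the op writes exactly `slot_mapping.shape[0]` tokens; conceptually it scatters each token into its paged slot, roughly like the pure-PyTorch sketch below (assuming the flash NHD cache layout where `kv_cache[:, 0]` is `[num_blocks, block_size, num_kv_heads, head_dim]`; the real op is a fused CUDA kernel that also applies the fp8 `_k_scale`/`_v_scale`).

```python
import torch

def reshape_and_cache_ref(key: torch.Tensor,          # [num_tokens, heads, dim]
                          value: torch.Tensor,
                          key_cache: torch.Tensor,     # [blocks, block_size, heads, dim]
                          value_cache: torch.Tensor,
                          slot_mapping: torch.Tensor   # [num_tokens], flat slot ids
                          ) -> None:
    """Reference-semantics sketch: scatter each token into its paged KV slot."""
    block_size = key_cache.shape[1]
    num_tokens = slot_mapping.shape[0]     # only this many tokens are written
    block_idx = slot_mapping // block_size
    block_off = slot_mapping % block_size
    key_cache[block_idx, block_off] = key[:num_tokens]
    value_cache[block_idx, block_off] = value[:num_tokens]
```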
@@ -913,9 +925,6 @@ def forward(
             output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
             return output
 
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
@@ -933,34 +942,15 @@ def forward(
                     self.logits_soft_cap or 0.0)
                 assert prefill_wrapper._sm_scale == self.scale
                 if self.cp_world_size > 1:
-                    key_across_cp = get_cp_group().all_gather(
-                        key[num_decode_tokens:].contiguous(), dim=0)
-                    value_across_cp = get_cp_group().all_gather(
-                        value[num_decode_tokens:].contiguous(), dim=0)
-                    key_across_cp = torch.index_select(
-                        key_across_cp, 0,
-                        attn_metadata.cp_kv_recover_idx
-                    )
-                    value_across_cp = torch.index_select(
-                        value_across_cp, 0,
-                        attn_metadata.cp_kv_recover_idx
-                    )
-                    torch.ops._C_cache_ops.reshape_and_cache_flash(
-                        key_across_cp,
-                        value_across_cp,
-                        kv_cache[:, 0],
-                        kv_cache[:, 1],
-                        attn_metadata.slot_mapping[num_decode_tokens:],
-                        self.kv_cache_dtype,
-                        layer._k_scale,
-                        layer._v_scale,
-                    )
-                    # TODO(qcs): handle how the kv cache is fetched and
-                    # concatenated under chunked prefill / prefix caching
+                    # NOTE(qcs): Allgather causes duplicate decode tokens.
+                    prefill_key = key[
+                        num_decode_tokens * self.cp_world_size:]
+                    prefill_value = value[
+                        num_decode_tokens * self.cp_world_size:]
                     prefill_wrapper.run(
                         prefill_query,
-                        key_across_cp,
-                        value_across_cp,
+                        prefill_key,
+                        prefill_value,
                         out=output[num_decode_tokens:],
                     )
                 else:
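
A short note on the slicing above, under the assumption (stated in the NOTE) that the restore index places every decode token W = cp_world_size times at the front of the gathered token dimension while prefill tokens appear once: with D decode tokens per rank, the prefill region starts at row D * W. A toy illustration:

```python
# Toy values, not from the PR: W = cp_world_size, D = num_decode_tokens.
W, D = 2, 3
# Gathered + restored key/value rows along dim 0 (the duplicate ordering is
# whatever cp_allgather_restore_idx produces; sketched here as adjacent):
#   [d0, d0, d1, d1, d2, d2,  p0, p1, p2, ...]
#    |------ D * W = 6 -----|  prefill tokens follow
prefill_start = D * W          # matches key[num_decode_tokens * cp_world_size:]
assert prefill_start == 6
```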
@@ -1047,17 +1037,6 @@ def forward(
                                                           or 0.0)
                 assert decode_wrapper._sm_scale == self.scale
                 if self.cp_world_size > 1:
-                    torch.ops._C_cache_ops.reshape_and_cache_flash(
-                        key[:num_decode_tokens],
-                        value[:num_decode_tokens],
-                        kv_cache[:, 0],
-                        kv_cache[:, 1],
-                        attn_metadata.slot_mapping[:num_decode_tokens],
-                        self.kv_cache_dtype,
-                        layer._k_scale,
-                        layer._v_scale,
-                    )
-                    kv_cache_permute = kv_cache.permute(*stride_order)
                     out, lse = decode_wrapper.run(
                         decode_query,
                         kv_cache_permute,
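
The decode wrapper now returns the per-rank partial output together with its log-sum-exp; judging by the `cp_lse_ag_out_ar` import at the top of the file, these partials are then combined across CP ranks (the call itself is outside the visible hunk). The combination is the standard LSE-weighted merge of attention computed over disjoint KV partitions; below is a minimal single-process sketch of that formula, not the vLLM implementation.

```python
import torch

def merge_attn_partials(outs: list[torch.Tensor],
                        lses: list[torch.Tensor]) -> torch.Tensor:
    """Merge attention outputs computed over disjoint KV partitions.

    outs[i]: [tokens, heads, head_dim]; lses[i]: [tokens, heads], the
    log-sum-exp of that partition's attention logits. Standard formula:
    o = sum_i exp(lse_i - logsumexp(lse)) * o_i.
    """
    lse_stack = torch.stack(lses, dim=0)              # [P, tokens, heads]
    lse_total = torch.logsumexp(lse_stack, dim=0)     # [tokens, heads]
    weights = torch.exp(lse_stack - lse_total)        # per-partition weights
    out_stack = torch.stack(outs, dim=0)              # [P, tokens, heads, dim]
    return (weights.unsqueeze(-1) * out_stack).sum(dim=0)
```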