@@ -340,28 +340,32 @@ def schedule(
         prefix_scheduler_metadata = None
 
         if self.dcp_world_size > 1:
-            query_kv_lens_cpu = common_attn_metadata.query_start_loc_cpu[1:] \
+            query_kv_lens_cpu = (
+                common_attn_metadata.query_start_loc_cpu[1:]
                 - common_attn_metadata.query_start_loc_cpu[:-1]
+            )
             dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
-            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu \
-                // self.dcp_world_size + (self.dcp_rank \
-                <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size)
+            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
+                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+            )
             dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
             max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
 
-            scheduler_metadata = schedule(batch_size=num_reqs,
-                                          cu_query_lens=query_start_loc,
-                                          max_query_len=max_query_len,
-                                          seqlens=dcp_context_kv_lens,
-                                          max_seq_len=max_dcp_context_kv_len,
-                                          causal=False)
+            scheduler_metadata = schedule(
+                batch_size=num_reqs,
+                cu_query_lens=query_start_loc,
+                max_query_len=max_query_len,
+                seqlens=dcp_context_kv_lens,
+                max_seq_len=max_dcp_context_kv_len,
+                causal=False,
+            )
         elif use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
-                                                dtype=torch.int32,
-                                                device=self.device)
-            prefix_kv_lens = torch.tensor([common_prefix_len],
-                                          dtype=torch.int32,
-                                          device=self.device)
+            cu_prefix_query_lens = torch.tensor(
+                [0, num_actual_tokens], dtype=torch.int32, device=self.device
+            )
+            prefix_kv_lens = torch.tensor(
+                [common_prefix_len], dtype=torch.int32, device=self.device
+            )
             suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                 self.device, non_blocking=True
             )
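
The reformatted arithmetic in the DCP branch decides how many context KV tokens each DCP rank holds: `n // world_size`, plus one more on the low ranks that pick up the remainder of the token-interleaved layout. A toy evaluation of exactly that expression, with a made-up world size and context lengths (illustrative only, not vLLM code):

```python
import torch

# Illustrative only: evaluate the per-rank context-KV-length expression
# from the hunk above for an invented batch and DCP world size.
dcp_world_size = 4
dcp_context_kv_lens_cpu = torch.tensor([10, 7, 13])  # context tokens per request

for dcp_rank in range(dcp_world_size):
    per_rank = dcp_context_kv_lens_cpu // dcp_world_size + (
        dcp_rank <= (dcp_context_kv_lens_cpu - 1) % dcp_world_size
    )
    print(f"rank {dcp_rank}: {per_rank.tolist()}")
# rank 0: [3, 2, 4]
# rank 1: [3, 2, 3]
# rank 2: [2, 2, 3]
# rank 3: [2, 1, 3]
```

For these lengths each column sums back to the original context length, i.e. every context token is counted on exactly one rank.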
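In the cascade branch, the three tensors encode a two-phase schedule: one prefix pass in which every scheduled token queries the shared prefix as a single flattened "sequence", then per-request suffix passes over the KV beyond the prefix. A small numeric illustration with invented sizes:

```python
import torch

# Invented sizes, for illustration only.
num_reqs, common_prefix_len = 3, 8
num_actual_tokens = 6                     # scheduled query tokens across all requests
seq_lens_cpu = torch.tensor([12, 9, 15])

# Prefix phase: the whole batch acts as one 6-token query "sequence"
# attending to a single shared 8-token prefix.
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32)
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)

# Suffix phase: each request attends to its own KV past the prefix.
suffix_kv_lens = seq_lens_cpu[:num_reqs] - common_prefix_len
print(suffix_kv_lens.tolist())  # [4, 1, 7]
```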
@@ -683,60 +687,57 @@ def _forward_with_dcp(
 
         query = query.contiguous()
         query_across_dcp = get_dcp_group().all_gather(query, dim=1)
-        context_attn_out, context_lse = \
-            flash_attn_varlen_func(
-                q=query_across_dcp,
-                k=key_cache,
-                v=value_cache,
-                out=None,
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=attn_metadata.dcp_context_kv_lens,
-                max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
-                softmax_scale=self.scale,
-                causal=False,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                return_softmax_lse=True,
-                scheduler_metadata=attn_metadata.scheduler_metadata,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=q_descale,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
+        context_attn_out, context_lse = flash_attn_varlen_func(
+            q=query_across_dcp,
+            k=key_cache,
+            v=value_cache,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=attn_metadata.dcp_context_kv_lens,
+            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
+            softmax_scale=self.scale,
+            causal=False,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            scheduler_metadata=attn_metadata.scheduler_metadata,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
         # FA returns LSE in shape [H, B] but cp_lse_ag_out_rs wants [B, H]
-        context_attn_out_cor, context_lse_cor = \
-            cp_lse_ag_out_rs(
-                context_attn_out,
-                context_lse.transpose(0, 1),
-                get_dcp_group(),
-                return_lse=True
-            )
+        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
+            context_attn_out,
+            context_lse.transpose(0, 1),
+            get_dcp_group(),
+            return_lse=True,
+        )
         context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
 
-        query_attn_out, query_lse = \
-            flash_attn_varlen_func(
-                q=query,
-                k=key,
-                v=value,
-                out=None,
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                cu_seqlens_k=cu_seqlens_q,
-                max_seqlen_k=max_seqlen_q,
-                softmax_scale=self.scale,
-                causal=attn_metadata.causal,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                softcap=self.logits_soft_cap,
-                return_softmax_lse=True,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=q_descale,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
+        query_attn_out, query_lse = flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_k=max_seqlen_q,
+            softmax_scale=self.scale,
+            causal=attn_metadata.causal,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
         assert context_attn_out_cor.shape == query_attn_out.shape
         assert context_lse_cor.shape == query_lse.shape
         merge_attn_states(
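
`merge_attn_states` then combines the two partial results: the context output (computed over each rank's KV shard with `causal=False` and corrected across ranks by `cp_lse_ag_out_rs`) and the local query-to-query output cover disjoint key sets, so they can be merged exactly using their log-sum-exps. A minimal reference version of that merge, assuming `[tokens, heads, dim]` outputs with `[tokens, heads]` LSEs (a sketch for clarity, not vLLM's kernel):

```python
import torch

# Conceptual merge of two partial attention results over disjoint KV sets.
# out_a/out_b: [tokens, heads, dim]; lse_a/lse_b: [tokens, heads].
def merge_partial_attn(out_a, lse_a, out_b, lse_b):
    # Combined log-partition over the union of both KV sets.
    lse = torch.logsumexp(torch.stack([lse_a, lse_b]), dim=0)
    # Each part is weighted by the softmax mass it contributed.
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b, lse
```

Because the weights are the softmax masses each part contributed, the merged output equals single-pass attention over the union of the two KV sets, which is what lets the context and query passes above run independently.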