 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_mtp_graph_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -102,6 +103,7 @@ def __init__(
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank
 
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -258,6 +260,13 @@ def dummy_run(self,
             cos=self.runner.cos,
             sin=self.runner.sin,
         )
+        if self.pcp_size * self.dcp_size > 1:
+            # update long_seq related params and flatten block_table
+            common_attn_metadata.prefill_context_parallel_metadata = \
+                self.runner.long_seq_metadata
+            common_attn_metadata.block_table_tensor = \
+                self.runner.input_batch.block_table[0].get_device_tensor()[
+                    :num_reqs * self.decode_threshold]
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
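A side note on the "flatten block_table" comment in this hunk: below is a minimal standalone sketch (plain PyTorch; the tensor names and sizes are illustrative assumptions, not the runner's actual buffers) of the difference between the per-request block table and the per-token "flattened" view that a later `_propose` hunk folds back.

```python
import torch

# Hypothetical sizes, for illustration only.
num_reqs, decode_threshold, max_blocks = 2, 3, 4
per_req_block_table = torch.arange(num_reqs * max_blocks).reshape(
    num_reqs, max_blocks)

# "Flattened" view: one block-table row per scheduled decode token, i.e. each
# request's row repeated decode_threshold times, truncated to
# num_reqs * decode_threshold rows as in the dummy_run change above.
flattened = per_req_block_table.repeat_interleave(decode_threshold, dim=0)
flattened = flattened[:num_reqs * decode_threshold]

# "Folded" view (what the later _propose hunk restores): take the first row of
# each request's group to get back one row per request.
first_row_of_each_req = torch.arange(num_reqs) * decode_threshold
folded = flattened[first_row_of_each_req]
assert torch.equal(folded, per_req_block_table)
```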
@@ -303,9 +312,13 @@ def dummy_run(self,
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                     not forward_context.capturing:
                 if self.vllm_config.model_config.use_mla:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context, num_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context, num_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context, num_tokens,
+                            self.vllm_config.speculative_config)
                 if self.enable_shared_expert_dp:
                     positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                         positions, True)
@@ -357,7 +370,7 @@ def generate_token_ids(self,
         )
 
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
             input_ids_pcp_full = self.runner.input_ids_pcp_full
             query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
@@ -393,7 +406,6 @@ def generate_token_ids(self,
             common_attn_metadata.query_start_loc = \
                 query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices = \
                 self._prepare_inputs(
@@ -604,28 +616,36 @@ def _propose(
         self.input_ids[last_token_indices] = next_token_ids
 
         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            ori_last_token_indices = last_token_indices.cpu()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
             target_hidden_states_d_padded = \
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
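The reshape-based unpadding removed by this hunk assumed exactly `decode_threshold` tokens per decode request; the new mask instead keeps a variable `query_lens_d[i]` rows out of each request's `query_lens_d[i] * pcp_size` gathered rows. A small self-contained sketch of the same selection (hypothetical sizes, plain PyTorch):

```python
import torch

# After the pcp all-gather, each decode request contributes
# query_len * pcp_size rows, of which only the first query_len are real.
pcp_size = 2
query_lens_d = torch.tensor([1, 3], dtype=torch.int32)  # per-request query lens
hidden = torch.arange(
    (query_lens_d.sum() * pcp_size).item()).unsqueeze(1).float()

# Same indices the loop in the hunk builds, written compactly.
starts = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]
])
mask = torch.cat([
    torch.arange(s, s + l)
    for s, l in zip(starts.tolist(), query_lens_d.tolist())
])
unpadded = hidden[mask]
assert unpadded.shape[0] == query_lens_d.sum().item()
```

The per-request Python loop in the hunk produces the same index list; the vectorized form here is only for illustration.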
@@ -755,10 +775,15 @@ def _propose(
                 forward_context = get_forward_context()
                 if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                     if self.vllm_config.model_config.use_mla:
-                        update_mla_attn_params(
-                            self.update_stream, forward_context,
-                            num_input_tokens,
-                            self.vllm_config.speculative_config)
+                        if self.pcp_size * self.dcp_size > 1:
+                            update_mla_attn_dcp_pcp_params(
+                                self.update_stream, forward_context,
+                                num_input_tokens)
+                        else:
+                            update_mla_attn_params(
+                                self.update_stream, forward_context,
+                                num_input_tokens,
+                                self.vllm_config.speculative_config)
 
                 if self.enable_shared_expert_dp:
                     hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -777,6 +802,8 @@ def _propose(
                     (0, max_num_reqs_across_dp - num_indices))
 
             if self.pcp_size > 1:
+                # remove graph padding before all_gather
+                hidden_states = hidden_states[:num_tokens]
                 hidden_states = get_pcp_group().all_gather(hidden_states, 0)
                 hidden_states = torch.index_select(
                     hidden_states, 0, self.runner.
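The new slice before the all-gather appears to exist so that the gathered tensor has exactly `pcp_size * num_tokens` rows, matching the index tensor used by the `index_select` that follows; a toy-shape sketch of that assumption:

```python
import torch

# Toy shapes; hidden_states is padded beyond num_tokens for graph capture.
num_tokens, padded_len, hidden_dim, pcp_size = 5, 8, 4, 2
hidden_states = torch.randn(padded_len, hidden_dim)

trimmed = hidden_states[:num_tokens]  # drop rows added only for the graph pad
# get_pcp_group().all_gather(trimmed, 0) would then yield
# pcp_size * num_tokens rows for the subsequent index_select.
assert trimmed.shape == (num_tokens, hidden_dim)
```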
@@ -808,6 +835,81 @@ def _propose(
 
             attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
 
+            # TODO: refactor this
+            if self.pcp_size * self.dcp_size > 1:
+                if step == 0:
+                    num_reject_tokens = torch.tensor(
+                        self.runner.cu_num_tokens_pcp_full,
+                        dtype=torch.int32) - ori_last_token_indices - 1
+                    num_accept_tokens = query_lens_d - num_reject_tokens
+                    ori_seq_len = attn_metadata_i.seq_lens
+                    mtp_slot_pad = self.runner.mtp_slot_pad
+                    # ori slot: [-1, -1, 134, -1, -1, -1, 135, -1, | -1, -1, 261, -1, -1, -1, 262, -1]
+                    # mtp slot: [-1, -1, 134, -1, -1, -1, 135, -1, | -1, -1, 136, -1, | -1, -1, 261, -1, -1, -1, 262, -1, | -1, -1, 263, -1]
+                    # per-request stride: scheduled_tokens * pcp_size + (num_speculative_tokens - 1) * pcp_size
+                    # base offset from scheduled tokens
+                    slot_idx_base = torch.cat([
+                        torch.tensor([0], dtype=torch.int32),
+                        torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size
+                    ])
+                    # offset from pre-allocated mtp tokens
+                    slot_idx_base += torch.arange(num_decode_reqs) * (
+                        self.num_speculative_tokens - 1) * self.pcp_size
+                    # offset from accepted tokens
+                    slot_idx_base += (num_accept_tokens - 1) * self.pcp_size
+                    slot_indices = []
+                    for req_id in range(num_decode_reqs):
+                        slot_indices += list(
+                            range(slot_idx_base[req_id],
+                                  slot_idx_base[req_id] + self.pcp_size))
+                    slot_indices = torch.tensor(slot_indices, dtype=torch.int32)
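+                    # NOTE: slot_idx_base points at the pcp_size slot entries
+                    # of each request's last accepted token in mtp_slot_pad;
+                    # the per-step "slot_indices += self.pcp_size" below then
+                    # walks into that request's pre-allocated draft slots.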
+
+                    # fold block_table back to its per-request shape
+                    # (it was flattened to one row per token earlier)
+                    block_indices = torch.cat([
+                        torch.tensor([0], dtype=torch.int32),
+                        torch.cumsum(query_lens_d, dim=0)[:-1]
+                    ])
+                    attn_metadata_i.decode.block_table[:batch_size] = \
+                        attn_metadata_i.decode.block_table[block_indices]
+                    attn_metadata_i.decode.block_table = \
+                        attn_metadata_i.decode.block_table[:batch_size]
+
+                    positions = target_positions[ori_last_token_indices]
+                    hidden_states = hidden_states[last_token_indices]
+                    last_token_indices = self.arange[:batch_size]
+                    if attn_metadata_i.num_decode_tokens != 0:
+                        attn_metadata_i.num_decode_tokens = batch_size
+
+                input_ids = draft_token_ids_list[-1].int()
+                positions += 1
+
+                if self.speculative_config.disable_padded_drafter_batch or \
+                        aclgraph_runtime_mode != CUDAGraphMode.FULL:
+                    attn_metadata_i.decode.cos = builder.cos_cache[
+                        positions[:batch_size]].unsqueeze(1).unsqueeze(2)
+                    attn_metadata_i.decode.sin = builder.sin_cache[
+                        positions[:batch_size]].unsqueeze(1).unsqueeze(2)
+
+                # clamp positions that would exceed max_model_len
+                exceeds_max_model_len = positions[:batch_size] >= \
+                    self.runner.model_config.max_model_len
+                clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                                positions[:batch_size])
+
+                # update local seq_len per (pcp, dcp) rank
+                num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                    ori_seq_len + step + 1,
+                    self.pcp_size,
+                    self.dcp_size,
+                    self.runner.parallel_config.cp_kv_cache_interleave_size,
+                )
+                cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+                                                            self.dcp_rank]
+                batch_seq_mask = (cp_seq_len == 0)
+                builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                    batch_seq_mask, non_blocking=True)
+                batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
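+                # NOTE (assumption): ranks whose local KV length is 0 for a
+                # request are excluded via batch_seq_mask; cp_seq_len is only
+                # clamped to 1 below so the attention kernel still sees a
+                # valid length for those masked entries.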
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+
+                # update slot_mapping: advance into this step's draft slots
+                slot_indices += self.pcp_size
+                slot_mapping = mtp_slot_pad[slot_indices]
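+                # slot_mapping holds pcp_size entries per request,
+                # i.e. batch_size * self.pcp_size entries in total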
+
+                self.input_ids[:batch_size] = input_ids
+                self.positions[:batch_size] = clamped_positions
+                self.hidden_states[:hidden_states.shape[0]] = hidden_states
+                attn_metadata_i.slot_mapping[:batch_size *
+                                             self.pcp_size] = slot_mapping
+
+                continue
+
             if step == 0:
                 positions = target_positions[last_token_indices]
                 hidden_states = hidden_states[last_token_indices]
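For readers unfamiliar with `_get_cp_local_seq_lens`, here is a toy stand-in illustrating the `[num_reqs, pcp_size, dcp_size]` split that the `num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]` indexing above relies on: each sequence's KV tokens are dealt out to the pcp*dcp ranks in blocks of `cp_kv_cache_interleave_size`. The block-to-rank assignment order below is an assumption for illustration only; the runner's real mapping may differ.

```python
import torch

def local_seq_lens_sketch(seq_lens: torch.Tensor, pcp_size: int,
                          dcp_size: int, interleave: int) -> torch.Tensor:
    """Toy stand-in for runner._get_cp_local_seq_lens: distribute each
    sequence's KV tokens over pcp_size * dcp_size ranks round-robin in
    blocks of `interleave` tokens. Returns [num_reqs, pcp_size, dcp_size]."""
    world = pcp_size * dcp_size
    num_reqs = seq_lens.shape[0]
    out = torch.zeros(num_reqs, world, dtype=torch.int64)
    for r, s in enumerate(seq_lens.tolist()):
        full_blocks, rem = divmod(s, interleave)
        base, extra = divmod(full_blocks, world)
        out[r] = base * interleave
        out[r, :extra] += interleave    # ranks that received one extra block
        if rem:
            out[r, extra] += rem        # the final, partially filled block
    return out.reshape(num_reqs, pcp_size, dcp_size)

lens = local_seq_lens_sketch(torch.tensor([7, 20]),
                             pcp_size=2, dcp_size=2, interleave=2)
assert lens.sum().item() == 27  # every token is owned by exactly one rank
```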