@@ -75,19 +75,22 @@ def __init__(
7575 self .use_sparse = hasattr (vllm_config .model_config .hf_config ,
7676 "index_topk" )
7777
78- self .actual_seq_lengths_q = list (
79- range (1 , self .runner .max_num_tokens + 1 , 1 ))
80- self .query_start_loc = torch .zeros (self .runner .max_num_reqs + 1 ,
81- dtype = torch .int32 ,
82- device = self .device )
83- self .query_start_loc_cpu = torch .zeros (self .runner .max_num_reqs + 1 ,
84- dtype = torch .int32 ,
85- device = "cpu" ,
86- pin_memory = True )
78+ # self.actual_seq_lengths_q = list(
79+ # range(1, self.runner.max_num_tokens + 1, 1))
80+ self .query_start_loc = torch .zeros (
81+ self .runner .max_num_reqs * (self .num_speculative_tokens + 1 ) + 1 ,
82+ dtype = torch .int32 ,
83+ device = self .device )
84+ self .query_start_loc_cpu = torch .zeros (
85+ self .runner .max_num_reqs * (self .num_speculative_tokens + 1 ) + 1 ,
86+ dtype = torch .int32 ,
87+ device = "cpu" ,
88+ pin_memory = True )
8789 self .slot_mapping = torch .zeros (self .runner .max_num_tokens ,
8890 dtype = torch .int32 ,
8991 device = self .device )
90- self .seq_lens_cpu = torch .zeros (self .runner .max_num_reqs ,
92+ self .seq_lens_cpu = torch .zeros (self .runner .max_num_reqs *
93+ (self .num_speculative_tokens + 1 ),
9194 dtype = torch .int32 ,
9295 device = "cpu" ,
9396 pin_memory = True )
@@ -175,7 +178,6 @@ def dummy_run(self,
175178 elif aclgraph_runtime_mode == CUDAGraphMode .FULL :
176179 assert with_prefill is False , \
177180 "Full decode graph only supports uniform batch now."
178- num_reqs = num_tokens
179181 max_seq_lens = self .runner .model_config .max_model_len
180182 self .seq_lens_cpu [:num_reqs ] = max_seq_lens
181183 self .seq_lens_cpu [num_reqs :] = 0
@@ -184,7 +186,7 @@ def dummy_run(self,
184186 self .runner .input_batch .
185187 num_computed_tokens_cpu_tensor [:num_reqs ])
186188 query_start_loc = torch .tensor (
187- [0 ] + self .actual_seq_lengths_q [:num_reqs ],
189+ [0 ] + self .runner . actual_seq_lengths_q [:num_reqs ],
188190 device = self .runner .device ,
189191 dtype = torch .int32 )
190192 self .query_start_loc [:num_reqs + 1 ].copy_ (query_start_loc )
@@ -207,7 +209,7 @@ def dummy_run(self,
207209 spec_attn_mask = self .runner .spec_attn_mask ,
208210 attn_state = self .runner .attn_state ,
209211 decode_token_per_req = self .runner .decode_token_per_req ,
210- cos = self .runner .cos , # 考虑mrope,是否可以共用?
212+ cos = self .runner .cos ,
211213 sin = self .runner .sin ,
212214 )
213215
@@ -350,7 +352,8 @@ def generate_token_ids(self,
350352 block_table = attn_metadata .block_tables ,
351353 sampling_metadata = sampling_metadata ,
352354 token_indices = accepted_token_indices ,
353- scheduler_output = scheduler_output )
355+ scheduler_output = scheduler_output ,
356+ num_scheduled_tokens = num_scheduled_tokens )
354357 spec_token_ids = draft_token_ids .tolist ()
355358 return spec_token_ids
356359
@@ -416,12 +419,16 @@ def _prepare_inputs(
416419 batch_size = num_rejected_tokens .shape [0 ]
417420 self .query_start_loc [:batch_size + 1 ].copy_ (cu_num_tokens [:batch_size +
418421 1 ])
422+ self .query_start_loc [batch_size + 1 :].fill_ (0 )
419423 self .query_start_loc_cpu [:batch_size + 1 ].copy_ (
420424 self .query_start_loc [:batch_size + 1 ], non_blocking = True )
425+ self .query_start_loc_cpu [batch_size + 1 :].fill_ (0 )
421426 target_positions_len = target_positions .shape [0 ]
422427 self .positions [:target_positions_len ].copy_ (target_positions )
428+ self .positions [target_positions_len :].fill_ (0 )
423429 target_slot_mapping_len = target_slot_mapping .shape [0 ]
424430 self .slot_mapping [:target_slot_mapping_len ].copy_ (target_slot_mapping )
431+ self .slot_mapping [target_slot_mapping_len :].fill_ (0 )
425432
426433 return cu_num_tokens , token_indices , target_token_ids , target_positions , target_hidden_states , target_slot_mapping
427434
@@ -443,7 +450,8 @@ def _propose(
443450 block_table : torch .Tensor ,
444451 sampling_metadata : SamplingMetadata ,
445452 token_indices = None ,
446- scheduler_output : SchedulerOutput = None ) -> torch .Tensor :
453+ scheduler_output : SchedulerOutput = None ,
454+ num_scheduled_tokens : int = 0 ) -> torch .Tensor :
447455 num_tokens = target_token_ids .shape [0 ]
448456 batch_size = next_token_ids .shape [0 ]
449457 last_token_indices = cu_num_tokens [1 :] - 1
@@ -489,6 +497,30 @@ def _propose(
489497 seq_lens = seq_lens .int ()
490498 seq_lens_len = seq_lens .shape [0 ]
491499 self .seq_lens_cpu [:seq_lens_len ].copy_ (seq_lens , non_blocking = True )
500+ self .seq_lens_cpu [seq_lens_len :].fill_ (0 )
501+
502+ if self .torchair_graph_enabled :
503+ # torchair mode can reuse self.runner.num_tokens_across_dp
504+ num_tokens_across_dp = self .runner .num_tokens_across_dp
505+ with_prefill = self .runner .with_prefill
506+ elif self .vllm_config .compilation_config .cudagraph_mode .has_full_cudagraphs (
507+ ):
508+ (num_input_tokens , num_tokens_across_dp , with_prefill ,
509+ _ ) = self .runner ._sync_metadata_across_dp (
510+ num_scheduled_tokens , self .runner .with_prefill , False )
511+ else :
512+ # torch mode need to update num_tokens_across_dp
513+ # TODO: adapt enable_dbo later
514+ (num_input_tokens , num_tokens_across_dp , with_prefill ,
515+ _ ) = self .runner ._sync_metadata_across_dp (
516+ num_input_tokens , self .runner .with_prefill , False )
517+
518+ if self .vllm_config .compilation_config .cudagraph_mode .has_full_cudagraphs (
519+ ):
520+ graph_pad_size = num_input_tokens
521+ else :
522+ graph_pad_size = self .runner .graph_pad_size
523+
492524 common_attn_metadata = AscendCommonAttentionMetadata (
493525 query_start_loc = self .query_start_loc [:batch_size + 1 ],
494526 query_start_loc_cpu = self .query_start_loc_cpu [:batch_size + 1 ],
@@ -504,7 +536,7 @@ def _propose(
504536 attn_mask = self .runner .attn_mask ,
505537 spec_attn_mask = self .runner .spec_attn_mask ,
506538 attn_state = self .runner .attn_state ,
507- graph_pad_size = self . runner . graph_pad_size ,
539+ graph_pad_size = graph_pad_size ,
508540 decode_token_per_req = self .runner .decode_token_per_req ,
509541 num_computed_tokens_cpu = None ,
510542 seq_lens = None )
@@ -522,20 +554,8 @@ def _propose(
522554 attn_metadata = self .runner .attn_metadata_builder .build (
523555 0 , common_attn_metadata , self .runner .get_model ())
524556
525- self .positions [:num_tokens ] = target_positions
526557 self .hidden_states [:num_tokens ] = target_hidden_states
527558
528- if not self .torchair_graph_enabled :
529- # torch mode need to update num_tokens_across_dp
530- # TODO: adapt enable_dbo later
531- (num_input_tokens , num_tokens_across_dp , with_prefill ,
532- _ ) = self .runner ._sync_metadata_across_dp (
533- num_input_tokens , self .runner .with_prefill , False )
534- else :
535- # torchair mode can reuse self.runner.num_tokens_across_dp
536- num_tokens_across_dp = self .runner .num_tokens_across_dp
537- with_prefill = self .runner .with_prefill
538-
539559 moe_comm_type = self .runner ._select_moe_comm_method (
540560 num_input_tokens , with_prefill )
541561
0 commit comments