 from vllm.logger import init_logger
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.utils import (
-    is_pin_memory_available,
-)
+    is_pin_memory_available, )
 import numpy as np
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -79,55 +78,53 @@ def __init__(
         if compilation_config.mode == CompilationMode.VLLM_COMPILE:
             cudagraph_mode = compilation_config.cudagraph_mode
             if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
-                CUDAGraphMode.PIECEWISE
-            ):
+                    CUDAGraphMode.PIECEWISE):
                 logger.warning(
                     "Currently the eagle proposer only supports cudagraph_mode "
                     "PIECEWISE, if you want the drafter to use cuda graphs, "
                     "please set compilation_config.cudagraph_mode to PIECEWISE "
-                    "or FULL_AND_PIECEWISE"
-                )
-            self.use_aclgraph = (
-                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
-                and not self.speculative_config.enforce_eager
-            )
-
-        self.cudagraph_batch_sizes = (
-            list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
-            if self.use_aclgraph
-            else []
-        )
+                    "or FULL_AND_PIECEWISE")
+            self.use_aclgraph = (cudagraph_mode.has_mode(
+                CUDAGraphMode.PIECEWISE)
+                                 and not self.speculative_config.enforce_eager)
+
+        self.cudagraph_batch_sizes = (list(
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+                                      if self.use_aclgraph else [])
 
         # persistent buffers for aclgraph graph
-        self.input_ids = torch.zeros(
-            self.max_num_tokens, dtype=torch.int32, device=device
-        )
+        self.input_ids = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int32,
+                                     device=device)
         self.uses_mrope = self.vllm_config.model_config.uses_mrope
         if self.uses_mrope:
             # M-RoPE need (3, max_num_tokens)
-            self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens), dtype=torch.int64, device=device
-            )
+            self.mrope_positions = torch.zeros((3, self.max_num_tokens),
+                                               dtype=torch.int64,
+                                               device=device)
         else:
             # RoPE need (max_num_tokens,)
-            self.positions = torch.zeros(
-                self.max_num_tokens, dtype=torch.int64, device=device
-            )
+            self.positions = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int64,
+                                         device=device)
         self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
-        )
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
 
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         max_batch_size = vllm_config.scheduler_config.max_num_seqs
         max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
-        self.arange = torch.arange(
-            max_num_slots_for_arange, device=device, dtype=torch.int32
-        )
+        self.arange = torch.arange(max_num_slots_for_arange,
+                                   device=device,
+                                   dtype=torch.int32)
 
         self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
-        )
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
 
         self.backup_next_token_ids = CpuGpuBuffer(
             max_batch_size,
@@ -221,8 +218,8 @@ def generate_token_ids(self,
                            hidden_states: torch.Tensor = None,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None,
-                           common_attn_metadata: AscendCommonAttentionMetadata = None
-                           ):
+                           common_attn_metadata: AscendCommonAttentionMetadata
+                           | None = None):
         if attn_metadata is not None and isinstance(attn_metadata, dict):
             attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
         next_token_ids: list[int] = []
@@ -299,12 +296,12 @@ def generate_token_ids(self,
             # common_attn_metadata
             # )
             if self.speculative_config.disable_padded_drafter_batch:
-                token_indices_to_sample = None
-                common_attn_metadata, token_indices = \
-                    self._prepare_inputs(
-                        common_attn_metadata,
-                        sampled_token_ids,
-                        spec_decode_metadata.num_draft_tokens)
+                token_indices_to_sample = None
+                common_attn_metadata, token_indices = \
+                    self._prepare_inputs(
+                        common_attn_metadata,
+                        sampled_token_ids,
+                        spec_decode_metadata.num_draft_tokens)
             else:
                 common_attn_metadata, token_indices, \
                     token_indices_to_sample = \
@@ -317,15 +314,15 @@ def generate_token_ids(self,
             target_hidden_states = hidden_states[:token_indices]
 
             draft_token_ids = self._propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                next_token_ids=next_token_ids,
-                last_token_indices=token_indices_to_sample,
-                common_attn_metadata=common_attn_metadata,
-                sampling_metadata=sampling_metadata,
-            )
-
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                next_token_ids=next_token_ids,
+                last_token_indices=token_indices_to_sample,
+                common_attn_metadata=common_attn_metadata,
+                sampling_metadata=sampling_metadata,
+            )
+
         return draft_token_ids
 
     def _prepare_inputs(
@@ -360,14 +357,16 @@ def _prepare_inputs(
             n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
             for i, n in enumerate(num_draft_tokens)
         ]
-        num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)
+        num_rejected_tokens = torch.tensor(num_rejected_tokens,
+                                           dtype=torch.int32)
 
         device = common_attn_metadata.query_start_loc.device
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
 
         # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
-        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        new_query_len_per_req = query_start_loc_cpu[
+            1:] - query_start_loc_cpu[:-1]
         # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
         new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
         new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
@@ -388,36 +387,36 @@ def _prepare_inputs(
         # [0, 2, 6, 9] ->
         # [0, 0, 2, 2, 2, 2, 6, 6, 6]
         #  _r1_  ____r2____  ___r3__
-        new_query_start_locs_expanded = np.repeat(
-            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
-        )
+        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
+                                                  new_num_tokens_per_req_np)
         # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
         # [0, 1, 0, 1, 2, 3, 0, 1, 2]
         #  _r1_  ____r2____  ___r3__
-        token_offests = (
-            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
-        )
+        token_offests = (self.token_arange_np[:total_num_tokens] -
+                         new_query_start_locs_expanded)
 
         # Expand starting positions to match token pattern
         # [0, q1, q1 + q2] ->
         # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
         #  _r1_  _____r2_______  ___________r3____________
         old_query_start_locs_expanded = np.repeat(
-            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
-        )
+            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
         # Final token indices are:
         # [0, 1,                                   // req 1
         #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
         #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
         token_indices_np = token_offests + old_query_start_locs_expanded
-        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)
+        token_indices = torch.from_numpy(token_indices_np).to(
+            device, non_blocking=True)
 
         spec_common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
+            query_start_loc=new_query_start_loc_cpu.to(device,
+                                                       non_blocking=True),
             query_start_loc_cpu=new_query_start_loc_cpu,
             seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
             seq_lens_cpu=new_seq_lens_cpu,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            num_computed_tokens_cpu=common_attn_metadata.
+            num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
@@ -432,9 +431,6 @@ def _prepare_inputs(
             decode_token_per_req=self.runner.decode_token_per_req,
         )
         return spec_common_attn_metadata, token_indices
-
-
-
 
     def _propose(
         self,
@@ -460,8 +456,7 @@ def _propose(
         if self.method == "eagle3":
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states
-            )
+                target_hidden_states)
             assert target_hidden_states.shape[-1] == self.hidden_size
 
         # Shift the input ids by one token.
@@ -506,10 +501,9 @@ def _propose(
         if aclgraph_runtime_mode != CUDAGraphMode.NONE:
             # Fallback to piecewise graph, when acl full graph is enabled
             logger.warning(
-                f"Currently the eagle proposer only supports cudagraph_mode "
-                "PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
-                "to CUDAGraphMode.PIECEWISE"
-            )
+                "Currently the eagle proposer only supports cudagraph_mode "
+                "PIECEWISE, and is forced to set graph mode from "
+                f"{aclgraph_runtime_mode} to CUDAGraphMode.PIECEWISE")
             aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
         for step in range(self.num_speculative_tokens):
@@ -692,14 +686,15 @@ def prepare_next_token_ids_cpu(
             # Get the next token id from the request state.
             req_id = req_ids[i]
             req_state = requests[req_id]
-            seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
+            seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
+                req_id]
             next_token_id = req_state.get_token_id(seq_len)
             next_token_ids.append(next_token_id)
-        next_token_ids = torch.tensor(
-            next_token_ids, dtype=torch.int32, device=self.input_ids.device
-        )
+        next_token_ids = torch.tensor(next_token_ids,
+                                      dtype=torch.int32,
+                                      device=self.input_ids.device)
         return next_token_ids
-
+
     def prepare_next_token_ids_padded(
         self,
         common_attn_metadata: CommonAttentionMetadata,
@@ -722,30 +717,24 @@ def prepare_next_token_ids_padded(
 
         # Precompute get_token_id for when there is no valid next token
         num_reqs = gpu_input_batch.num_reqs
-        self.backup_next_token_ids.np[:num_reqs] = np.array(
-            [
-                requests[gpu_input_batch.req_ids[i]].get_token_id(
-                    common_attn_metadata.seq_lens_cpu[i].item()
-                )
-                for i in range(num_reqs)
-            ]
-        )
+        self.backup_next_token_ids.np[:num_reqs] = np.array([
+            requests[gpu_input_batch.req_ids[i]].get_token_id(
+                common_attn_metadata.seq_lens_cpu[i].item())
+            for i in range(num_reqs)
+        ])
         self.backup_next_token_ids.copy_to_gpu(num_reqs)
 
         # Mask out the sampled tokens indices that should not be sampled.
-        discard_sampled_tokens_req_indices = discard_request_indices[
-            :num_discarded_requests
-        ]
+        discard_sampled_tokens_req_indices = discard_request_indices[:
+                                                                      num_discarded_requests]
 
         valid_sampled_token_ids_gpu = sampled_token_ids.clone()
         valid_sampled_token_ids_gpu.index_fill_(
-            0, discard_sampled_tokens_req_indices, -1
-        )
+            0, discard_sampled_tokens_req_indices, -1)
 
         # Generate a mask for all valid tokens within those requests
         valid_mask = (valid_sampled_token_ids_gpu != -1) & (
-            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
-        )
+            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)
 
         # Count the number of valid tokens in each request
         valid_sampled_tokens_count = valid_mask.sum(dim=1)
@@ -757,8 +746,8 @@ def prepare_next_token_ids_padded(
         # Get last valid token from each row
         # (assume undefined state where there is no valid token)
         selected_tokens = torch.gather(
-            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
-        ).squeeze(1)
+            valid_sampled_token_ids_gpu, 1,
+            last_valid_indices_safe.unsqueeze(1)).squeeze(1)
 
         # Use last token if valid, pre-computed backup if not
         batch_size = valid_sampled_token_ids_gpu.shape[0]
@@ -769,7 +758,7 @@ def prepare_next_token_ids_padded(
         )
 
         return next_token_ids, valid_sampled_tokens_count
-
+
     def prepare_inputs_padded(
         self,
         common_attn_metadata: CommonAttentionMetadata,
@@ -784,13 +773,11 @@ def prepare_inputs_padded(
         used as padding and filtered out later by `token_indices_to_sample`.
         No blocking CPU operations should be introduced in this function.
         """
-        num_draft_tokens_gpu = torch.cat(
-            [
-                spec_decode_metadata.cu_num_draft_tokens[0:1],
-                spec_decode_metadata.cu_num_draft_tokens[1:]
-                - spec_decode_metadata.cu_num_draft_tokens[:-1],
-            ]
-        )
+        num_draft_tokens_gpu = torch.cat([
+            spec_decode_metadata.cu_num_draft_tokens[0:1],
+            spec_decode_metadata.cu_num_draft_tokens[1:] -
+            spec_decode_metadata.cu_num_draft_tokens[:-1],
+        ])
 
         num_rejected_tokens_gpu = torch.where(
             num_draft_tokens_gpu > 0,
@@ -800,7 +787,8 @@ def prepare_inputs_padded(
800787
801788 query_start_loc_cpu = common_attn_metadata .query_start_loc_cpu
802789
803- new_query_len_per_req = query_start_loc_cpu [1 :] - query_start_loc_cpu [:- 1 ]
790+ new_query_len_per_req = query_start_loc_cpu [
791+ 1 :] - query_start_loc_cpu [:- 1 ]
804792
805793 total_num_tokens = query_start_loc_cpu [- 1 ].item ()
806794 token_indices = self .arange [:total_num_tokens ]
@@ -821,11 +809,11 @@ def prepare_inputs_padded(
             attn_state=self.runner.attn_state,
             graph_pad_size=self.runner.graph_pad_size,
             decode_token_per_req=self.runner.decode_token_per_req,
-            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+            num_computed_tokens_cpu=common_attn_metadata.
+            num_computed_tokens_cpu,
             seq_lens=common_attn_metadata.seq_lens)
 
-        token_indices_to_sample = (
-            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
-        )
+        token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
+                                   1 - num_rejected_tokens_gpu)
 
-        return spec_common_attn_metadata, token_indices, token_indices_to_sample
+        return spec_common_attn_metadata, token_indices, token_indices_to_sample