@@ -32,27 +32,16 @@ def __init__(self,
                  device: torch.device,
                  runner=None):
         self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
-        self.device = device
         self.vllm_config = vllm_config
-        self.speculative_config = vllm_config.speculative_config
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-
+        self.device = device
         self.runner = runner
-        self.dtype = vllm_config.model_config.dtype
-        self.max_model_len = vllm_config.model_config.max_model_len
-        self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = (
-            self.speculative_config.num_speculative_tokens)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
 
         self.block_size = vllm_config.cache_config.block_size
         # We need to get the hidden size from the draft model config because
         # the draft model's hidden size can be different from the target model's
         # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = self.draft_model_config.get_hidden_size()
+        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
+        )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
@@ -62,15 +51,18 @@ def __init__(self,
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
         # persistent buffers for cuda graph
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=device)
-        self.positions = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int64,
-                                     device=device)
+        self.input_ids = torch.zeros(
+            self.vllm_config.scheduler_config.max_num_batched_tokens,
+            dtype=torch.int32,
+            device=device)
+        self.positions = torch.zeros(
+            self.vllm_config.scheduler_config.max_num_batched_tokens,
+            dtype=torch.int64,
+            device=device)
         self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
+            (self.vllm_config.scheduler_config.max_num_batched_tokens,
+             self.hidden_size),
+            dtype=self.vllm_config.model_config.dtype,
             device=device)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -406,17 +398,14 @@ def _propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        last_token_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         device = cu_num_tokens.device
         cu_num_tokens = cu_num_tokens.cpu()
         block_table = block_table.cpu()
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        if last_token_indices is None:
-            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
+        last_token_indices = cu_num_tokens[1:] - 1
         target_positions = target_positions.cpu()
-
         if self.name == SpecDcodeType.EAGLE3:
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
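For reference, a minimal, self-contained sketch of the indexing used by `last_token_indices = cu_num_tokens[1:] - 1`: it assumes `cu_num_tokens` is a cumulative per-request token count with a leading 0 (one more entry than the batch size, like `query_start_loc`); the tensor values below are illustrative only, not taken from the patch.

```python
import torch

# Illustrative values only: three requests contributing 3, 1, and 4 target
# tokens to the flattened batch. cu_num_tokens is the cumulative sum with a
# leading 0, so it has batch_size + 1 entries.
cu_num_tokens = torch.tensor([0, 3, 4, 8])

# Request i's last token sits at flat index cu_num_tokens[i + 1] - 1, so
# dropping the leading 0 and subtracting 1 gives one index per request.
last_token_indices = cu_num_tokens[1:] - 1
print(last_token_indices)  # tensor([2, 3, 7])
```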