 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -39,11 +40,9 @@ def __init__(
 
         self.hidden_size = vllm_config.model_config.get_hidden_size()
 
-        # TODO: make eagle3 compatible with cudagraph
-        self.use_cuda_graph = self.method != 'eagle3' and \
-            (self.vllm_config.compilation_config.level
-             == CompilationLevel.PIECEWISE and
-             not self.vllm_config.model_config.enforce_eager)
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+                               == CompilationLevel.PIECEWISE and
+                               not self.vllm_config.model_config.enforce_eager)
 
         self.cudagraph_batch_sizes = list(
             reversed(
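With the eagle3 carve-out removed, use_cuda_graph depends only on piecewise compilation and enforce_eager, and cudagraph_batch_sizes now applies to both methods. As a hedged sketch of how such capture sizes are typically consumed (helper name assumed, not the exact vLLM code), the proposer pads its token count up to the nearest captured size so a recorded graph can be replayed:

# Hedged sketch, not vLLM's implementation: pad num_tokens up to the
# smallest captured CUDA graph size; oversized batches fall back to eager.
from bisect import bisect_left

def padded_num_tokens(num_tokens: int, capture_sizes: list[int]) -> int:
    sizes = sorted(capture_sizes)
    i = bisect_left(sizes, num_tokens)  # first captured size >= num_tokens
    return sizes[i] if i < len(sizes) else num_tokens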
@@ -90,6 +89,12 @@ def propose(
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
 
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
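Calling combine_hidden_states up front is what makes the shared buffer path possible: eagle3 consumes auxiliary hidden states gathered from several target-model layers, and projecting them down to hidden_size before any buffer copy lets eagle and eagle3 use the same fixed-size tensors. A minimal sketch of the idea, assuming a concatenated [num_tokens, k * hidden_size] input and a learned projection (module and layer names here are illustrative, not vLLM's):

import torch
import torch.nn as nn

class CombineHiddenStates(nn.Module):
    # Sketch: project concatenated multi-layer target hidden states down
    # to hidden_size so they fit the persistent cudagraph input buffer.
    def __init__(self, num_aux_layers: int, hidden_size: int):
        super().__init__()
        self.fc = nn.Linear(num_aux_layers * hidden_size, hidden_size,
                            bias=False)

    def forward(self, aux_hidden_states: torch.Tensor) -> torch.Tensor:
        # [num_tokens, k * hidden_size] -> [num_tokens, hidden_size]
        return self.fc(aux_hidden_states)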
@@ -126,20 +131,15 @@ def propose(
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
 
-        if self.method == 'eagle':
-            self.hidden_states[:num_tokens] = target_hidden_states
-            hidden_states = self.hidden_states
-        else:
-            # TODO: make eagle3 compatible with cuda graph
-            hidden_states = target_hidden_states
+        self.hidden_states[:num_tokens] = target_hidden_states
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                hidden_states=hidden_states[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
             )
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
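The in-place copy into self.hidden_states, rather than rebinding a local hidden_states as the removed branch did, is the load-bearing change: a captured CUDA graph replays fixed device addresses, so every step must stage its inputs in the same preallocated tensor. A minimal self-contained sketch of that pattern (buffer sizes assumed for illustration):

import torch

MAX_NUM_TOKENS, HIDDEN_SIZE = 8192, 4096  # assumed sizes
# Allocated once; a captured graph records this tensor's address.
hidden_buf = torch.zeros(MAX_NUM_TOKENS, HIDDEN_SIZE, device="cuda")

def stage_hidden_states(new_hidden: torch.Tensor) -> None:
    n = new_hidden.shape[0]
    # In-place copy keeps the recorded address valid; rebinding the name
    # (hidden_buf = new_hidden) would silently break graph replay.
    hidden_buf[:n].copy_(new_hidden)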
@@ -209,10 +209,7 @@ def propose(
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
 
-            if self.method == 'eagle':
-                # TODO: make eagle3 compatible with cudagraph.
-                self.hidden_states[:batch_size] = hidden_states
-                hidden_states = self.hidden_states
+            self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
             with set_forward_context(attn_metadata,
@@ -221,7 +218,7 @@ def propose(
                 last_hidden_states, hidden_states = self.model(
                     input_ids=self.input_ids[:input_batch_size],
                     positions=self.positions[:input_batch_size],
-                    hidden_states=hidden_states[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
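Both hunks above sit inside the multi-step draft loop, which now routes eagle and eagle3 through the same persistent buffers, so one captured graph per padded batch size serves every draft step. A hedged sketch of the loop's shape (argument names assumed, not the exact vLLM code):

import torch

def draft_loop(model, hidden_states, batch_size, input_batch_size,
               input_ids_buf, positions_buf, hidden_buf, num_steps):
    # Sketch: each iteration stages inputs into the same buffers, runs the
    # draft model, and feeds the argmax token back in for the next step.
    draft_token_ids = []
    for _ in range(num_steps):
        hidden_buf[:batch_size].copy_(hidden_states)
        last_hidden, hidden_states = model(
            input_ids=input_ids_buf[:input_batch_size],
            positions=positions_buf[:input_batch_size],
            hidden_states=hidden_buf[:input_batch_size],
        )
        hidden_states = hidden_states[:batch_size]
        logits = model.compute_logits(last_hidden[:batch_size], None)
        next_tokens = logits.argmax(dim=-1)
        input_ids_buf[:batch_size].copy_(next_tokens)
        positions_buf[:batch_size] += 1  # real code also clamps, cf. clamped_positions
        draft_token_ids.append(next_tokens)
    return torch.stack(draft_token_ids, dim=1)  # [batch_size, num_steps]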
@@ -314,12 +311,11 @@ def dummy_run(
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            if self.method == 'eagle':
-                self.model(
-                    input_ids=self.input_ids[:num_tokens],
-                    positions=self.positions[:num_tokens],
-                    hidden_states=self.hidden_states[:num_tokens],
-                )
+            self.model(
+                input_ids=self.input_ids[:num_tokens],
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+            )
 
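dummy_run likewise loses its eagle-only guard: with a single forward path, warmup can record graphs for either method. A hedged usage sketch (the capture loop and driver names are assumed, not vLLM's):

# Hedged sketch: during warmup, run the dummy forward once per captured
# size so piecewise compilation can record a graph for each of them.
for num_tokens in cudagraph_batch_sizes:
    proposer.dummy_run(num_tokens=num_tokens)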
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax