44import triton
55import triton .language as tl
66
7- from vllm .config import VllmConfig , set_current_vllm_config
7+ from vllm .config import CompilationLevel , VllmConfig , set_current_vllm_config
88from vllm .forward_context import set_forward_context
99from vllm .logger import init_logger
1010from vllm .model_executor .model_loader .loader import get_model_loader
@@ -26,10 +26,41 @@ def __init__(
2626 device : torch .device ,
2727 ):
2828 self .vllm_config = vllm_config
29+ self .method = self .vllm_config .speculative_config .method
2930 self .num_speculative_tokens = (
3031 vllm_config .speculative_config .num_speculative_tokens )
3132 self .max_model_len = vllm_config .model_config .max_model_len
3233 self .block_size = vllm_config .cache_config .block_size
34+
35+ self .dtype = vllm_config .model_config .dtype
36+
37+ self .max_num_tokens = vllm_config .scheduler_config \
38+ .max_num_batched_tokens
39+
40+ self .hidden_size = vllm_config .model_config .get_hidden_size ()
41+
42+ # TODO: make eagle3 compatible with cudagraph
43+ self .use_cuda_graph = self .method != 'eagle3' and \
44+ (self .vllm_config .compilation_config .level
45+ == CompilationLevel .PIECEWISE and
46+ not self .vllm_config .model_config .enforce_eager )
47+
48+ self .cudagraph_batch_sizes = list (
49+ reversed (
50+ self .vllm_config .compilation_config .cudagraph_capture_sizes ))
51+
52+ # persistent buffers for cuda graph
53+ self .input_ids = torch .zeros (self .max_num_tokens ,
54+ dtype = torch .int32 ,
55+ device = device )
56+ self .positions = torch .zeros (self .max_num_tokens ,
57+ dtype = torch .int64 ,
58+ device = device )
59+
60+ self .hidden_states = torch .zeros (
61+ (self .max_num_tokens , self .hidden_size ),
62+ dtype = self .dtype ,
63+ device = device )
3364 # We need +1 here because the arange is used to set query_start_loc,
3465 # which has one more element than batch_size.
3566 self .arange = torch .arange (vllm_config .scheduler_config .max_num_seqs +
@@ -59,13 +90,12 @@ def propose(
5990 batch_size = next_token_ids .shape [0 ]
6091 last_token_indices = cu_num_tokens [1 :] - 1
6192
62- input_ids = torch .empty_like (target_token_ids )
6393 # Shift the input ids by one token.
6494 # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
65- input_ids [:- 1 ] = target_token_ids [1 :]
95+ self . input_ids [:num_tokens - 1 ] = target_token_ids [1 :]
6696 # Replace the last token with the next token.
6797 # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
68- input_ids [last_token_indices ] = next_token_ids
98+ self . input_ids [last_token_indices ] = next_token_ids
6999
70100 # FA requires seq_len to have dtype int32.
71101 seq_lens = (target_positions [last_token_indices ] + 1 ).int ()
@@ -88,14 +118,30 @@ def propose(
88118 prefix_kv_lens = None ,
89119 suffix_kv_lens = None ,
90120 )
121+ if self .use_cuda_graph and \
122+ num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
123+ num_input_tokens = self .vllm_config .pad_for_cudagraph (num_tokens )
124+ else :
125+ num_input_tokens = num_tokens
126+ # copy inputs to buffer for cudagraph
127+ self .positions [:num_tokens ] = target_positions
91128
92- with set_forward_context (attn_metadata , self .vllm_config ):
93- hidden_states_logits , hidden_states_fwd = self .model (
94- input_ids = input_ids ,
95- hidden_states = target_hidden_states ,
96- positions = target_positions ,
129+ if self .method == 'eagle' :
130+ self .hidden_states [:num_tokens ] = target_hidden_states
131+ hidden_states = self .hidden_states
132+ else :
133+ # TODO: make eagle3 compatible with cuda graph
134+ hidden_states = target_hidden_states
135+
136+ with set_forward_context (attn_metadata ,
137+ self .vllm_config ,
138+ num_tokens = num_input_tokens ):
139+ last_hidden_states , hidden_states = self .model (
140+ input_ids = self .input_ids [:num_input_tokens ],
141+ positions = self .positions [:num_input_tokens ],
142+ hidden_states = hidden_states [:num_input_tokens ],
97143 )
98- sample_hidden_states = hidden_states_logits [last_token_indices ]
144+ sample_hidden_states = last_hidden_states [last_token_indices ]
99145 logits = self .model .compute_logits (sample_hidden_states , None )
100146 draft_token_ids = logits .argmax (dim = - 1 )
101147
@@ -108,13 +154,20 @@ def propose(
108154 draft_token_ids_list = [draft_token_ids ]
109155
110156 positions = target_positions [last_token_indices ]
111- hidden_states = hidden_states_fwd [last_token_indices ]
157+ hidden_states = hidden_states [last_token_indices ]
158+ if self .use_cuda_graph and \
159+ batch_size <= self .cudagraph_batch_sizes [- 1 ]:
160+ input_batch_size = self .vllm_config .pad_for_cudagraph (batch_size )
161+ else :
162+ input_batch_size = batch_size
112163 attn_metadata .num_actual_tokens = batch_size
113164 attn_metadata .max_query_len = 1
114165 attn_metadata .query_start_loc = self .arange [:batch_size + 1 ]
115166 for _ in range (self .num_speculative_tokens - 1 ):
116167 # Update the inputs.
117- input_ids = draft_token_ids_list [- 1 ]
168+ # cast to int32 is crucial when eagle model is compiled.
169+ # tensor.argmax() returns int64 by default.
170+ input_ids = draft_token_ids_list [- 1 ].int ()
118171 positions += 1
119172
120173 # NOTE(woosuk): We should handle the case where the draft model
@@ -152,14 +205,27 @@ def propose(
152205 attn_metadata .slot_mapping .masked_fill_ (exceeds_max_model_len ,
153206 PADDING_SLOT_ID )
154207
208+ # copy inputs to buffer for cudagraph
209+ self .input_ids [:batch_size ] = input_ids
210+ self .positions [:batch_size ] = clamped_positions
211+
212+ if self .method == 'eagle' :
213+ # TODO: make eagle3 compatible with cudagraph.
214+ self .hidden_states [:batch_size ] = hidden_states
215+ hidden_states = self .hidden_states
216+
155217 # Run the model.
156- with set_forward_context (attn_metadata , self .vllm_config ):
157- hidden_states_logits , hidden_states = self .model (
158- input_ids = input_ids ,
159- hidden_states = hidden_states ,
160- positions = clamped_positions ,
218+ with set_forward_context (attn_metadata ,
219+ self .vllm_config ,
220+ num_tokens = input_batch_size ):
221+ last_hidden_states , hidden_states = self .model (
222+ input_ids = self .input_ids [:input_batch_size ],
223+ positions = self .positions [:input_batch_size ],
224+ hidden_states = hidden_states [:input_batch_size ],
161225 )
162- logits = self .model .compute_logits (hidden_states_logits , None )
226+ hidden_states = hidden_states [:batch_size ]
227+ logits = self .model .compute_logits (last_hidden_states [:batch_size ],
228+ None )
163229 draft_token_ids = logits .argmax (dim = - 1 )
164230 draft_token_ids_list .append (draft_token_ids )
165231
@@ -227,13 +293,11 @@ def load_model(self, target_model: nn.Module) -> None:
227293 draft_model_cls , arch = ModelRegistry .resolve_model_cls (
228294 draft_model_config .architectures )
229295 self .model = draft_model_cls (
230- model_config = draft_model_config ,
296+ vllm_config = self . vllm_config ,
231297 start_layer_id = target_layer_num ).to (target_device )
232298
233299 loaded_weights = self .model .load_weights (
234- loader .get_all_weights (
235- self .vllm_config .speculative_config .draft_model_config ,
236- self .model ))
300+ loader .get_all_weights (draft_model_config , self .model ))
237301 if self .vllm_config .speculative_config .method == "eagle3" :
238302 if "model.embed_tokens.weight" not in loaded_weights :
239303 logger .info (
@@ -243,6 +307,20 @@ def load_model(self, target_model: nn.Module) -> None:
243307 logger .info ("Loading EAGLE LM head weights from the target model." )
244308 self .model .lm_head = target_model .lm_head
245309
310+ @torch .inference_mode ()
311+ def dummy_run (
312+ self ,
313+ num_tokens : int ,
314+ ) -> None :
315+ with set_forward_context (None , self .vllm_config ,
316+ num_tokens = num_tokens ):
317+ if self .method == 'eagle' :
318+ self .model (
319+ input_ids = self .input_ids [:num_tokens ],
320+ positions = self .positions [:num_tokens ],
321+ hidden_states = self .hidden_states [:num_tokens ],
322+ )
323+
246324
247325# NOTE(woosuk): Currently, the below code is not used and we always use argmax
248326# to sample the draft tokens. We will use this after we find a way to manage
0 commit comments