99import torch
1010import torch .nn as nn
1111
12- from vllm .config import CompilationLevel , VllmConfig , get_layers_from_vllm_config
12+ from vllm .config import (
13+ CompilationLevel ,
14+ CUDAGraphMode ,
15+ VllmConfig ,
16+ get_layers_from_vllm_config ,
17+ )
1318from vllm .distributed .parallel_state import get_pp_group
1419from vllm .forward_context import set_forward_context
1520from vllm .logger import init_logger
@@ -80,12 +85,25 @@ def __init__(
8085 self .attn_layer_names : list [str ] = []
8186 self .indexer_layer_names : list [str ] = []
8287
83- self .use_cuda_graph = (
84- not current_platform .is_xpu ()
85- and self .vllm_config .compilation_config .level == CompilationLevel .PIECEWISE
86- and not self .vllm_config .model_config .enforce_eager
87- and not self .speculative_config .enforce_eager
88- )
88+ self .use_cuda_graph = False
89+
90+ compilation_config = self .vllm_config .compilation_config
91+ if compilation_config .level == CompilationLevel .PIECEWISE :
92+ cudagraph_mode = compilation_config .cudagraph_mode
93+ if cudagraph_mode != CUDAGraphMode .NONE and not cudagraph_mode .has_mode (
94+ CUDAGraphMode .PIECEWISE
95+ ):
96+ logger .warning (
97+ "Currently the eagle proposer only supports cudagraph_mode "
98+ "PIECEWISE, if you want the drafter to use cuda graphs, "
99+ "please set compilation_config.cudagraph_mode to PIECEWISE "
100+ "or FULL_AND_PIECEWISE"
101+ )
102+ self .use_cuda_graph = (
103+ cudagraph_mode .has_mode (CUDAGraphMode .PIECEWISE )
104+ and not self .speculative_config .enforce_eager
105+ )
106+
89107 self .cudagraph_batch_sizes = (
90108 list (reversed (self .vllm_config .compilation_config .cudagraph_capture_sizes ))
91109 if self .use_cuda_graph
@@ -239,12 +257,15 @@ def propose(
239257 per_layer_attn_metadata = {}
240258 for layer_name in self .attn_layer_names :
241259 per_layer_attn_metadata [layer_name ] = attn_metadata
260+
242261 for layer_name in self .indexer_layer_names :
243262 assert draft_indexer_metadata is not None
244263 per_layer_attn_metadata [layer_name ] = draft_indexer_metadata
245264
265+ cudagraph_runtime_mode = CUDAGraphMode .NONE
246266 if self .use_cuda_graph and num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
247267 num_input_tokens = self .vllm_config .pad_for_cudagraph (num_tokens )
268+ cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
248269 else :
249270 num_input_tokens = num_tokens
250271 # copy inputs to buffer for cudagraph
@@ -267,7 +288,10 @@ def propose(
267288 inputs_embeds = None
268289
269290 with set_forward_context (
270- per_layer_attn_metadata , self .vllm_config , num_tokens = num_input_tokens
291+ per_layer_attn_metadata ,
292+ self .vllm_config ,
293+ num_tokens = num_input_tokens ,
294+ cudagraph_runtime_mode = cudagraph_runtime_mode ,
271295 ):
272296 ret_hidden_states = self .model (
273297 input_ids = input_ids ,
@@ -326,8 +350,10 @@ def propose(
326350
327351 if self .use_cuda_graph and batch_size <= self .cudagraph_batch_sizes [- 1 ]:
328352 input_batch_size = self .vllm_config .pad_for_cudagraph (batch_size )
353+ cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
329354 else :
330355 input_batch_size = batch_size
356+ cudagraph_runtime_mode = CUDAGraphMode .NONE
331357
332358 common_attn_metadata .num_actual_tokens = batch_size
333359 common_attn_metadata .max_query_len = 1
@@ -424,7 +450,10 @@ def propose(
424450
425451 # Run the model.
426452 with set_forward_context (
427- per_layer_attn_metadata , self .vllm_config , num_tokens = input_batch_size
453+ per_layer_attn_metadata ,
454+ self .vllm_config ,
455+ num_tokens = input_batch_size ,
456+ cudagraph_runtime_mode = cudagraph_runtime_mode ,
428457 ):
429458 ret_hidden_states = self .model (
430459 input_ids = input_ids ,
@@ -731,11 +760,16 @@ def propose_tree(
731760
732761 if self .use_cuda_graph and num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
733762 num_input_tokens = self .vllm_config .pad_for_cudagraph (num_tokens )
763+ cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
734764 else :
735765 num_input_tokens = num_tokens
766+ cudagraph_runtime_mode = CUDAGraphMode .NONE
736767 # Run the model.
737768 with set_forward_context (
738- per_layer_attn_metadata , self .vllm_config , num_tokens = num_input_tokens
769+ per_layer_attn_metadata ,
770+ self .vllm_config ,
771+ num_tokens = num_input_tokens ,
772+ cudagraph_runtime_mode = cudagraph_runtime_mode ,
739773 ):
740774 last_hidden_states , hidden_states = self .model (
741775 input_ids = self .input_ids [:num_input_tokens ],
@@ -1015,8 +1049,19 @@ def load_model(self, target_model: nn.Module) -> None:
10151049 def dummy_run (
10161050 self ,
10171051 num_tokens : int ,
1052+ use_cudagraphs = True ,
10181053 ) -> None :
1019- with set_forward_context (None , self .vllm_config , num_tokens = num_tokens ):
1054+ if use_cudagraphs and num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
1055+ num_tokens = self .vllm_config .pad_for_cudagraph (num_tokens )
1056+
1057+ with set_forward_context (
1058+ None ,
1059+ self .vllm_config ,
1060+ num_tokens = num_tokens ,
1061+ cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
1062+ if use_cudagraphs
1063+ else CUDAGraphMode .NONE ,
1064+ ):
10201065 if self .supports_mm_inputs :
10211066 input_ids = None
10221067 inputs_embeds = self .inputs_embeds [:num_tokens ]