99import  torch 
1010import  torch .nn  as  nn 
1111
12- from  vllm .config  import  CompilationLevel , VllmConfig , get_layers_from_vllm_config 
12+ from  vllm .config  import  (
13+     CompilationLevel ,
14+     CUDAGraphMode ,
15+     VllmConfig ,
16+     get_layers_from_vllm_config ,
17+ )
1318from  vllm .distributed .parallel_state  import  get_pp_group 
1419from  vllm .forward_context  import  set_forward_context 
1520from  vllm .logger  import  init_logger 
@@ -80,12 +85,25 @@ def __init__(
8085        self .attn_layer_names : list [str ] =  []
8186        self .indexer_layer_names : list [str ] =  []
8287
83-         self .use_cuda_graph  =  (
84-             not  current_platform .is_xpu ()
85-             and  self .vllm_config .compilation_config .level  ==  CompilationLevel .PIECEWISE 
86-             and  not  self .vllm_config .model_config .enforce_eager 
87-             and  not  self .speculative_config .enforce_eager 
88-         )
88+         self .use_cuda_graph  =  False 
89+ 
90+         compilation_config  =  self .vllm_config .compilation_config 
91+         if  compilation_config .level  ==  CompilationLevel .PIECEWISE :
92+             cudagraph_mode  =  compilation_config .cudagraph_mode 
93+             if  cudagraph_mode  !=  CUDAGraphMode .NONE  and  not  cudagraph_mode .has_mode (
94+                 CUDAGraphMode .PIECEWISE 
95+             ):
96+                 logger .warning (
97+                     "Currently the eagle proposer only supports cudagraph_mode " 
98+                     "PIECEWISE, if you want the drafter to use cuda graphs, " 
99+                     "please set compilation_config.cudagraph_mode to PIECEWISE " 
100+                     "or FULL_AND_PIECEWISE" 
101+                 )
102+             self .use_cuda_graph  =  (
103+                 cudagraph_mode .has_mode (CUDAGraphMode .PIECEWISE )
104+                 and  not  self .speculative_config .enforce_eager 
105+             )
106+ 
89107        self .cudagraph_batch_sizes  =  (
90108            list (reversed (self .vllm_config .compilation_config .cudagraph_capture_sizes ))
91109            if  self .use_cuda_graph 
@@ -239,12 +257,15 @@ def propose(
239257        per_layer_attn_metadata  =  {}
240258        for  layer_name  in  self .attn_layer_names :
241259            per_layer_attn_metadata [layer_name ] =  attn_metadata 
260+ 
242261        for  layer_name  in  self .indexer_layer_names :
243262            assert  draft_indexer_metadata  is  not   None 
244263            per_layer_attn_metadata [layer_name ] =  draft_indexer_metadata 
245264
265+         cudagraph_runtime_mode  =  CUDAGraphMode .NONE 
246266        if  self .use_cuda_graph  and  num_tokens  <=  self .cudagraph_batch_sizes [- 1 ]:
247267            num_input_tokens  =  self .vllm_config .pad_for_cudagraph (num_tokens )
268+             cudagraph_runtime_mode  =  CUDAGraphMode .PIECEWISE 
248269        else :
249270            num_input_tokens  =  num_tokens 
250271        # copy inputs to buffer for cudagraph 
@@ -267,7 +288,10 @@ def propose(
267288            inputs_embeds  =  None 
268289
269290        with  set_forward_context (
270-             per_layer_attn_metadata , self .vllm_config , num_tokens = num_input_tokens 
291+             per_layer_attn_metadata ,
292+             self .vllm_config ,
293+             num_tokens = num_input_tokens ,
294+             cudagraph_runtime_mode = cudagraph_runtime_mode ,
271295        ):
272296            ret_hidden_states  =  self .model (
273297                input_ids = input_ids ,
@@ -326,8 +350,10 @@ def propose(
326350
327351        if  self .use_cuda_graph  and  batch_size  <=  self .cudagraph_batch_sizes [- 1 ]:
328352            input_batch_size  =  self .vllm_config .pad_for_cudagraph (batch_size )
353+             cudagraph_runtime_mode  =  CUDAGraphMode .PIECEWISE 
329354        else :
330355            input_batch_size  =  batch_size 
356+             cudagraph_runtime_mode  =  CUDAGraphMode .NONE 
331357
332358        common_attn_metadata .num_actual_tokens  =  batch_size 
333359        common_attn_metadata .max_query_len  =  1 
@@ -424,7 +450,10 @@ def propose(
424450
425451            # Run the model. 
426452            with  set_forward_context (
427-                 per_layer_attn_metadata , self .vllm_config , num_tokens = input_batch_size 
453+                 per_layer_attn_metadata ,
454+                 self .vllm_config ,
455+                 num_tokens = input_batch_size ,
456+                 cudagraph_runtime_mode = cudagraph_runtime_mode ,
428457            ):
429458                ret_hidden_states  =  self .model (
430459                    input_ids = input_ids ,
@@ -731,11 +760,16 @@ def propose_tree(
731760
732761            if  self .use_cuda_graph  and  num_tokens  <=  self .cudagraph_batch_sizes [- 1 ]:
733762                num_input_tokens  =  self .vllm_config .pad_for_cudagraph (num_tokens )
763+                 cudagraph_runtime_mode  =  CUDAGraphMode .PIECEWISE 
734764            else :
735765                num_input_tokens  =  num_tokens 
766+                 cudagraph_runtime_mode  =  CUDAGraphMode .NONE 
736767            # Run the model. 
737768            with  set_forward_context (
738-                 per_layer_attn_metadata , self .vllm_config , num_tokens = num_input_tokens 
769+                 per_layer_attn_metadata ,
770+                 self .vllm_config ,
771+                 num_tokens = num_input_tokens ,
772+                 cudagraph_runtime_mode = cudagraph_runtime_mode ,
739773            ):
740774                last_hidden_states , hidden_states  =  self .model (
741775                    input_ids = self .input_ids [:num_input_tokens ],
@@ -1015,8 +1049,19 @@ def load_model(self, target_model: nn.Module) -> None:
10151049    def  dummy_run (
10161050        self ,
10171051        num_tokens : int ,
1052+         use_cudagraphs = True ,
10181053    ) ->  None :
1019-         with  set_forward_context (None , self .vllm_config , num_tokens = num_tokens ):
1054+         if  use_cudagraphs  and  num_tokens  <=  self .cudagraph_batch_sizes [- 1 ]:
1055+             num_tokens  =  self .vllm_config .pad_for_cudagraph (num_tokens )
1056+ 
1057+         with  set_forward_context (
1058+             None ,
1059+             self .vllm_config ,
1060+             num_tokens = num_tokens ,
1061+             cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE 
1062+             if  use_cudagraphs 
1063+             else  CUDAGraphMode .NONE ,
1064+         ):
10201065            if  self .supports_mm_inputs :
10211066                input_ids  =  None 
10221067                inputs_embeds  =  self .inputs_embeds [:num_tokens ]
0 commit comments