|
9 | 9 | from vllm.logger import init_logger |
10 | 10 | from vllm.model_executor.model_loader.loader import get_model_loader |
11 | 11 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype |
| 12 | +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM |
12 | 13 | from vllm.model_executor.models import ModelRegistry |
13 | 14 | from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata |
14 | 15 | from vllm.v1.sample.metadata import SamplingMetadata |
@@ -39,11 +40,10 @@ def __init__( |
39 | 40 |
|
40 | 41 | self.hidden_size = vllm_config.model_config.get_hidden_size() |
41 | 42 |
|
42 | | - # TODO: make eagle3 compatible with cudagraph |
43 | | - self.use_cuda_graph = self.method != 'eagle3' and \ |
44 | | - (self.vllm_config.compilation_config.level |
45 | | - == CompilationLevel.PIECEWISE and |
46 | | - not self.vllm_config.model_config.enforce_eager) |
| 43 | + self.use_cuda_graph = ( |
| 44 | + self.vllm_config.compilation_config.level |
| 45 | + == CompilationLevel.PIECEWISE and |
| 46 | + not self.vllm_config.model_config.enforce_eager) |
47 | 47 |
|
48 | 48 | self.cudagraph_batch_sizes = list( |
49 | 49 | reversed( |
@@ -90,6 +90,12 @@ def propose( |
90 | 90 | batch_size = next_token_ids.shape[0] |
91 | 91 | last_token_indices = cu_num_tokens[1:] - 1 |
92 | 92 |
|
| 93 | + if self.method == "eagle3": |
| 94 | + assert isinstance(self.model, Eagle3LlamaForCausalLM) |
| 95 | + target_hidden_states = self.model.combine_hidden_states( |
| 96 | + target_hidden_states) |
| 97 | + assert target_hidden_states.shape[-1] == self.hidden_size |
| 98 | + |
93 | 99 | # Shift the input ids by one token. |
94 | 100 | # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] |
95 | 101 | self.input_ids[:num_tokens - 1] = target_token_ids[1:] |
@@ -126,12 +132,8 @@ def propose( |
126 | 132 | # copy inputs to buffer for cudagraph |
127 | 133 | self.positions[:num_tokens] = target_positions |
128 | 134 |
|
129 | | - if self.method == 'eagle': |
130 | | - self.hidden_states[:num_tokens] = target_hidden_states |
131 | | - hidden_states = self.hidden_states |
132 | | - else: |
133 | | - # TODO: make eagle3 compatible with cuda graph |
134 | | - hidden_states = target_hidden_states |
| 135 | + self.hidden_states[:num_tokens] = target_hidden_states |
| 136 | + hidden_states = self.hidden_states |
135 | 137 |
|
136 | 138 | with set_forward_context(attn_metadata, |
137 | 139 | self.vllm_config, |
@@ -209,10 +211,8 @@ def propose( |
209 | 211 | self.input_ids[:batch_size] = input_ids |
210 | 212 | self.positions[:batch_size] = clamped_positions |
211 | 213 |
|
212 | | - if self.method == 'eagle': |
213 | | - # TODO: make eagle3 compatible with cudagraph. |
214 | | - self.hidden_states[:batch_size] = hidden_states |
215 | | - hidden_states = self.hidden_states |
| 214 | + self.hidden_states[:batch_size] = hidden_states |
| 215 | + hidden_states = self.hidden_states |
216 | 216 |
|
217 | 217 | # Run the model. |
218 | 218 | with set_forward_context(attn_metadata, |
|
0 commit comments