@@ -52,7 +52,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import is_interleaved
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -442,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -488,6 +488,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,
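A minimal, self-contained sketch (not vLLM internals) of the pattern this diff adds: a target model that lets a caller mark which decoder layers should emit auxiliary hidden states and that defaults to an early, a middle, and a late layer via `(2, num_layers // 2, num_layers - 3)`. The `MiniTarget` class and its forward loop are hypothetical stand-ins for illustration only; the default layer selection is the one taken from the diff above.

```python
# Illustrative sketch, assuming a toy target model (MiniTarget is hypothetical,
# not vLLM code). Only the (2, num_layers // 2, num_layers - 3) default mirrors
# the Qwen2 change above.
import torch
import torch.nn as nn


class MiniTarget(nn.Module):
    """Toy model exposing the same aux-hidden-state hooks as in the diff."""

    def __init__(self, num_layers: int = 32, hidden: int = 64) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        # Remember which layer indices should have their inputs captured.
        self.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        # Same default as the diff: one early, one middle, one late layer.
        num_layers = len(self.layers)
        return (2, num_layers // 2, num_layers - 3)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        aux: list[torch.Tensor] = []
        for idx, layer in enumerate(self.layers):
            if idx in self.aux_hidden_state_layers:
                aux.append(x)  # capture the hidden state entering the marked layer
            x = layer(x)
        return x, aux


model = MiniTarget()
model.set_aux_hidden_state_layers(model.get_eagle3_aux_hidden_state_layers())
out, aux_states = model(torch.randn(1, 64))
print(len(aux_states))  # 3 auxiliary hidden states collected for a drafter
```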