Skip to content

Commit 36e2a87

Browse files
PapaGoose authored and amd-xiaoyu12 committed
[Speculators][Speculative Decoding] Fix Qwen 2 Eagle3 Support (vllm-project#23337)
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
1 parent 6568a8f commit 36e2a87

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/model_executor/models/qwen2.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@
5252
from vllm.sequence import IntermediateTensors
5353
from vllm.transformers_utils.config import is_interleaved
5454

55-
from .interfaces import SupportsLoRA, SupportsPP
55+
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
5656
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
5757
is_pp_missing_parameter,
5858
make_empty_intermediate_tensors_factory, make_layers,
@@ -442,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
442442
return loaded_params
443443

444444

445-
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
445+
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
446446
packed_modules_mapping = {
447447
"qkv_proj": [
448448
"q_proj",
@@ -488,6 +488,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
488488
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
489489
return self.model.get_input_embeddings(input_ids)
490490

491+
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
492+
self.model.aux_hidden_state_layers = layers
493+
494+
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
495+
num_layers = len(self.model.layers)
496+
return (2, num_layers // 2, num_layers - 3)
497+
491498
def forward(
492499
self,
493500
input_ids: torch.Tensor,

0 commit comments

Comments
 (0)