
Commit 58ed49c

[SpecDecode] Support EAGLE for Qwen3 MoE
Signed-off-by: seven-mile <i@7li.moe>
1 parent d3d649e

File tree

1 file changed: +23 -3 lines changed


vllm/model_executor/models/qwen3_moe.py

Lines changed: 23 additions & 3 deletions
@@ -54,7 +54,8 @@
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import (MixtureOfExperts, SupportsEagle3, SupportsLoRA,
+                         SupportsPP)
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -396,6 +397,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+        self.aux_hidden_state_layers = tuple[int, ...]()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
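Note that tuple[int, ...]() evaluates to an empty tuple, so aux-hidden-state capture is a no-op until a speculator opts in via set_aux_hidden_state_layers(). A quick check of the idiom (plain Python, my own illustration, not part of the commit):

    # Calling the parameterized alias constructs an ordinary empty tuple.
    layers = tuple[int, ...]()
    assert layers == ()        # nothing is tapped by default
    assert 0 not in layers     # so the membership test added in forward() never fires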
@@ -417,14 +419,25 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
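The capture point is the value entering layer idx: hidden_states + residual recombines the split residual stream before the layer runs, which is what an EAGLE-3 drafter consumes. A minimal, self-contained sketch of the same pattern (the toy layers and shapes below are my own stand-ins, not vLLM code):

    import torch

    # Toy decoder stack: each "layer" maps the combined residual stream forward,
    # loosely mimicking the (hidden_states, residual) contract in the real model.
    num_layers, hidden = 6, 8
    layers = [torch.nn.Linear(hidden, hidden) for _ in range(num_layers)]
    aux_hidden_state_layers = (2, 4)          # layer indices the drafter wants to tap

    hidden_states = torch.randn(1, hidden)
    residual = torch.zeros(1, hidden)

    aux_hidden_states = []
    for idx, layer in enumerate(layers):
        if idx in aux_hidden_state_layers:
            # Snapshot the residual-stream value flowing *into* this layer.
            aux_hidden_states.append(hidden_states + residual)
        residual = hidden_states + residual   # toy residual bookkeeping
        hidden_states = layer(residual)

    assert len(aux_hidden_states) == len(aux_hidden_state_layers)
    assert all(t.shape == (1, hidden) for t in aux_hidden_states)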
@@ -568,7 +581,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3,
                           MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -628,6 +641,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_routed_experts = example_layer.n_routed_experts
         self.num_redundant_experts = example_layer.n_redundant_experts
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def set_eplb_state(
         self,
         expert_load_view: torch.Tensor,
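For a 48-layer Qwen3 MoE stack (e.g. Qwen3-30B-A3B), the default above selects layers (2, 24, 45): one shallow, one middle, one deep tap, the usual EAGLE-3 recipe. A hedged end-to-end sketch of turning this on; the speculative_config keys follow vLLM's EAGLE-3 examples, and the draft checkpoint path is a placeholder, not a real model:

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen3-30B-A3B",
        speculative_config={
            "method": "eagle3",
            "model": "path/to/qwen3-moe-eagle3-draft",  # hypothetical EAGLE-3 head
            "num_speculative_tokens": 4,
        },
    )
    out = llm.generate(["Hello"], SamplingParams(max_tokens=32))
    print(out[0].outputs[0].text)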
