@@ -54,7 +54,8 @@
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import (MixtureOfExperts, SupportsEagle3, SupportsLoRA,
+                         SupportsPP)
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -396,6 +397,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+        self.aux_hidden_state_layers = tuple[int, ...]()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -417,14 +419,25 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
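
For readers unfamiliar with the pattern in the forward() hunk above, here is a minimal, self-contained sketch of what the collection logic does: before running each decoder layer whose index is listed in `aux_hidden_state_layers`, it snapshots the full activation (hidden state plus residual stream), and it returns those snapshots alongside the final hidden state. The `ToyDecoder` class, its linear layers, and the dummy residual handling are illustrative assumptions for this sketch only, not vLLM code.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Illustrative stand-in for the patched model; not a vLLM class."""

    def __init__(self, num_layers: int, hidden_size: int):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        # Empty by default, so nothing extra is collected unless a drafter
        # (e.g. an EAGLE3 head) asks for specific layer indices.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def forward(self, hidden_states: torch.Tensor):
        residual = torch.zeros_like(hidden_states)
        aux_hidden_states = []
        for idx, layer in enumerate(self.layers):
            # Capture the pre-layer activation (hidden state plus residual
            # stream) at each requested index, mirroring the diff above.
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden_states + residual)
            residual, hidden_states = hidden_states, layer(hidden_states)
        if aux_hidden_states:
            # Same return shape as the patched forward(): the final hidden
            # state plus the list of captured intermediate states.
            return hidden_states, aux_hidden_states
        return hidden_states


model = ToyDecoder(num_layers=8, hidden_size=16)
model.aux_hidden_state_layers = (2, 8 // 2, 8 - 3)  # (2, 4, 5)
final, aux = model(torch.randn(1, 16))
assert len(aux) == 3 and final.shape == aux[0].shape
```
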
@@ -568,7 +581,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3,
                           MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -628,6 +641,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_routed_experts = example_layer.n_routed_experts
         self.num_redundant_experts = example_layer.n_redundant_experts
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def set_eplb_state(
         self,
         expert_load_view: torch.Tensor,
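
As a quick sanity check on the defaults in `get_eagle3_aux_hidden_state_layers()`: they pick one early, one middle, and one late decoder layer, and a drafter can still override the selection through `set_aux_hidden_state_layers()`. The 48-layer depth below is a hypothetical example; actual Qwen3-MoE checkpoints vary.

```python
# Hypothetical depth; real Qwen3-MoE checkpoints differ.
num_layers = 48
default_layers = (2, num_layers // 2, num_layers - 3)
print(default_layers)  # (2, 24, 45) -> early, middle, and late decoder layers
```
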