@@ -330,6 +330,8 @@ def __init__(self,
330330 else :
331331 self .norm = PPMissingLayer ()
332332
333+ self .aux_hidden_state_layers : tuple [int ] = tuple ()
334+
333335 self .make_empty_intermediate_tensors = (
334336 make_empty_intermediate_tensors_factory (
335337 ["hidden_states" , "residual" ], config .hidden_size ))
@@ -355,7 +357,11 @@ def forward(
355357 hidden_states = intermediate_tensors ["hidden_states" ]
356358 residual = intermediate_tensors ["residual" ]
357359
358- for layer in self .layers [self .start_layer :self .end_layer ]:
360+ aux_hidden_states = []
361+ for idx , layer in enumerate (
362+ self .layers [self .start_layer :self .end_layer ]):
363+ if idx in self .aux_hidden_state_layers :
364+ aux_hidden_states .append (hidden_states + residual )
359365 hidden_states , residual = layer (positions , hidden_states , residual )
360366
361367 if not get_pp_group ().is_last_rank :
@@ -365,6 +371,9 @@ def forward(
365371 })
366372
367373 hidden_states , _ = self .norm (hidden_states , residual )
374+
375+ if len (aux_hidden_states ) > 0 :
376+ return hidden_states , aux_hidden_states
368377 return hidden_states
369378
370379 def load_weights (self , weights : Iterable [Tuple [str ,
@@ -517,6 +526,13 @@ def __init__(self,
517526 self .make_empty_intermediate_tensors = (
518527 self .model .make_empty_intermediate_tensors )
519528
529+ def set_aux_hidden_state_layers (self , layers : tuple [int ]) -> None :
530+ self .model .aux_hidden_state_layers = layers
531+
532+ def get_eagle3_aux_hidden_state_layers (self ) -> tuple [int ]:
533+ num_layers = len (self .model .layers )
534+ return (2 , num_layers // 2 , num_layers - 3 )
535+
520536 def _init_model (self ,
521537 vllm_config : VllmConfig ,
522538 prefix : str = "" ,
0 commit comments