@@ -2736,14 +2736,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
27362736
27372737 # Try to get auxiliary layers from speculative config,
27382738 # otherwise use model's default layers
2739- aux_layers = (self ._get_eagle3_aux_layers_from_config () or
2740- self .model .get_eagle3_aux_hidden_state_layers ())
2741-
2742- if aux_layers != self .model .get_eagle3_aux_hidden_state_layers (
2743- ):
2739+ aux_layers = self ._get_eagle3_aux_layers_from_config ()
2740+ if aux_layers :
27442741 logger .info (
27452742 "Using auxiliary layers from speculative config: %s" ,
27462743 aux_layers )
2744+ else :
2745+ aux_layers = self .model .get_eagle3_aux_hidden_state_layers (
2746+ )
27472747
27482748 self .model .set_aux_hidden_state_layers (aux_layers )
27492749 time_after_load = time .perf_counter ()
@@ -2797,7 +2797,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
27972797 CUDAGraphMode .NONE , self .device )
27982798
27992799 def _get_eagle3_aux_layers_from_config (self ) -> Optional [tuple [int , ...]]:
2800- """Extract Eagle3 auxiliary layer IDs from speculative config.
2800+ """Extract Eagle3 auxiliary layer indices from speculative config.
2801+
2802+ These indices specify which hidden states from the base model should
2803+ be used as auxiliary inputs for the Eagle3 drafter model during
2804+ speculative decoding.
28012805
28022806 Returns:
28032807 Tuple of layer indices if found in draft model config,
@@ -2807,18 +2811,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
28072811 and self .speculative_config .draft_model_config ):
28082812 return None
28092813
2810- try :
2811- hf_config = self .speculative_config .draft_model_config .hf_config
2812- if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2813- return None
2814-
2815- layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2816- if layer_ids and isinstance (layer_ids , (list , tuple )):
2817- return tuple (layer_ids )
2818- except Exception as e :
2819- logger .warning (
2820- "Failed to read auxiliary layers from speculative config: %s" ,
2821- e )
2814+ hf_config = self .speculative_config .draft_model_config .hf_config
2815+ if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2816+ return None
2817+
2818+ layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2819+ if layer_ids and isinstance (layer_ids , (list , tuple )):
2820+ return tuple (layer_ids )
28222821
28232822 return None
28242823
0 commit comments