@@ -2951,19 +2951,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
29512951
29522952 # Try to get auxiliary layers from speculative config,
29532953 # otherwise use model's default layers
2954- aux_layers = (
2955- self ._get_eagle3_aux_layers_from_config ()
2956- or self .model .get_eagle3_aux_hidden_state_layers ()
2957- )
2958-
2959- if (
2960- aux_layers
2961- != self .model .get_eagle3_aux_hidden_state_layers ()
2962- ):
2954+ aux_layers = self ._get_eagle3_aux_layers_from_config ()
2955+ if aux_layers :
29632956 logger .info (
29642957 "Using auxiliary layers from speculative config: %s" ,
29652958 aux_layers ,
29662959 )
2960+ else :
2961+ aux_layers = self .model .get_eagle3_aux_hidden_state_layers ()
29672962
29682963 self .model .set_aux_hidden_state_layers (aux_layers )
29692964 time_after_load = time .perf_counter ()
@@ -3021,7 +3016,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
30213016 )
30223017
30233018 def _get_eagle3_aux_layers_from_config (self ) -> Optional [tuple [int , ...]]:
3024- """Extract Eagle3 auxiliary layer IDs from speculative config.
3019+ """Extract Eagle3 auxiliary layer indices from speculative config.
3020+
3021+ These indices specify which hidden states from the base model should
3022+ be used as auxiliary inputs for the Eagle3 drafter model during
3023+ speculative decoding.
30253024
30263025 Returns:
30273026 Tuple of layer indices if found in draft model config,
@@ -3031,18 +3030,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
30313030 and self .speculative_config .draft_model_config ):
30323031 return None
30333032
3034- try :
3035- hf_config = self .speculative_config .draft_model_config .hf_config
3036- if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
3037- return None
3038-
3039- layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
3040- if layer_ids and isinstance (layer_ids , (list , tuple )):
3041- return tuple (layer_ids )
3042- except Exception as e :
3043- logger .warning (
3044- "Failed to read auxiliary layers from speculative config: %s" ,
3045- e )
3033+ hf_config = self .speculative_config .draft_model_config .hf_config
3034+ if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
3035+ return None
3036+
3037+ layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
3038+ if layer_ids and isinstance (layer_ids , (list , tuple )):
3039+ return tuple (layer_ids )
30463040
30473041 return None
30483042
0 commit comments