@@ -24,6 +24,12 @@ def from_pretrained(
2424 config_dict , _ = cls .get_config_dict (pretrained_model_name_or_path ,
2525 ** kwargs )
2626
27+ vllm_config = cls .extract_vllm_speculative_config (config_dict )
28+ return cls (** vllm_config )
29+
30+ @classmethod
31+ def extract_vllm_speculative_config (
32+ cls , config_dict : dict [str , Any ]) -> dict [str , Any ]:
2733 speculators_model_type = config_dict .get ("speculators_model_type" )
2834 if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES :
2935 raise ValueError (
@@ -34,11 +40,12 @@ def from_pretrained(
3440 # TODO: @dsikka - use speculators pydantic model to validate
3541 cls .validate_speculators_config (config_dict = config_dict )
3642 # Convert from speculators config -> format that can be ingested by vLLM
37- vllm_config = cls .convert_speculators_to_vllm (config_dict = config_dict )
43+ vllm_config = cls .build_vllm_speculative_config (
44+ config_dict = config_dict )
3845 # Apply anything specific to the supported algorithm
3946 algo_updater = SUPPORTED_SPECULATORS_TYPES [speculators_model_type ]
4047 algo_updater (config_dict = config_dict , vllm_config = vllm_config )
41- return cls ( ** vllm_config )
48+ return vllm_config
4249
4350 @classmethod
4451 def validate_speculators_config (cls , config_dict : dict [str , Any ]) -> None :
@@ -60,32 +67,45 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
6067 "'transformer_layer_config' must be a dictionary if provided" )
6168
6269 @classmethod
63- def convert_speculators_to_vllm (
70+ def build_vllm_speculative_config (
6471 cls , config_dict : dict [str , Any ]) -> dict [str , Any ]:
6572 """
66- Convert speculators config format to vLLM format.
67-
68- This method handles the translation of field names and structure
69- between speculators and vLLM formats.
70-
73+ Build vLLM-compatible speculative configuration from speculators format.
74+
75+ This method extracts and transforms speculative configuration from the
76+ speculators format into the structure expected by vLLM.
77+
78+ Args:
79+ config_dict: Configuration dictionary in speculators format
80+
7181 Returns:
72- Dictionary with vLLM-compatible configuration
82+ Dictionary with vLLM-compatible speculative configuration
7383 """
74- # Currently we only support one proposal method
84+ # Extract speculators configuration
7585 spec_config = config_dict ["speculators_config" ]
76- first_method = spec_config .get ("proposal_methods" )[0 ]
77- num_lookahead_tokens = first_method .get ("speculative_tokens" )
7886
79- if num_lookahead_tokens is None :
87+ # Currently we only support one proposal method
88+ proposal_methods = spec_config .get ("proposal_methods" )
89+ if not proposal_methods :
90+ raise ValueError ("No proposal methods found in speculators config" )
91+
92+ first_method = proposal_methods [0 ]
93+ num_speculative_tokens = first_method .get ("speculative_tokens" )
94+
95+ if num_speculative_tokens is None :
8096 raise ValueError (
8197 "Missing 'speculative_tokens' in proposal method. "
8298 f"Got: { first_method } " )
8399
84- # Build base vLLM config
100+ # Build base vLLM speculative configuration
85101 vllm_config = {
86102 "method" : config_dict .get ("speculators_model_type" ),
87- "num_lookahead_tokens " : num_lookahead_tokens ,
103+ "num_speculative_tokens " : num_speculative_tokens ,
88104 "target_model" : spec_config .get ("verifier" )["name_or_path" ]
89105 }
90- vllm_config .update (config_dict ["transformer_layer_config" ])
106+
107+ # Merge transformer layer configuration if present
108+ transformer_config = config_dict .get ("transformer_layer_config" , {})
109+ vllm_config .update (transformer_config )
110+
91111 return vllm_config
0 commit comments