diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4831cb5348c7..2f92e02f01cd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -50,6 +50,8 @@ # yapf: enable +logger = init_logger(__name__) + if TYPE_CHECKING: from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization import QuantizationMethods @@ -1079,7 +1081,7 @@ def create_speculative_config( target_parallel_config: ParallelConfig, enable_chunked_prefill: bool, disable_log_stats: bool, - ) -> Optional["SpeculativeConfig"]: + ) -> tuple[ModelConfig, Optional["SpeculativeConfig"]]: """Initializes and returns a SpeculativeConfig object based on `speculative_config`. @@ -1087,12 +1089,21 @@ def create_speculative_config( SpeculativeConfig object. The `speculative_config` can either be provided as a JSON string input via CLI arguments or directly as a dictionary from the engine. + + Returns: + A tuple of (possibly updated model_config, speculative_config). + If a speculators model is detected, model_config is updated to + point to the target model and speculative_config is configured + with the draft model. """ + from dataclasses import replace from vllm.transformers_utils.config import get_config from vllm.transformers_utils.configs.speculators.base import ( SpeculatorsConfig) + updated_model_config = target_model_config + if self.speculative_config is None: hf_config = get_config( self.hf_config_path or target_model_config.model, @@ -1103,25 +1114,64 @@ def create_speculative_config( # details from the config directly # no user input required / expected if isinstance(hf_config, SpeculatorsConfig): - # We create one since we don't create one - self.speculative_config = {} - self.speculative_config[ - "num_speculative_tokens"] = hf_config.num_lookahead_tokens - self.speculative_config["model"] = target_model_config.model - self.speculative_config["method"] = hf_config.method + # Get the complete vLLM config with algorithm-specific fields + try: + config_dict, _ = SpeculatorsConfig.get_config_dict( + target_model_config.model) + vllm_config = SpeculatorsConfig.get_vllm_config( + config_dict) + except Exception as e: + raise ValueError( + f"Failed to process speculators model " + f"'{target_model_config.model}': {e}") from e + + # Update model config to point to actual target model + updated_model_config = replace( + target_model_config, model=vllm_config["target_model"]) + + # Set up speculative config with original speculators model + # as draft + self.speculative_config = { + "model": + target_model_config.model, # Original speculators model + "num_speculative_tokens": + vllm_config["num_lookahead_tokens"], + "method": vllm_config["method"] + } + + # Add all algorithm-specific fields + for key, value in vllm_config.items(): + if key not in [ + "target_model", "num_lookahead_tokens", "method" + ]: + self.speculative_config[key] = value + + logger.info( + "Detected speculators model. Using target model: %s", + vllm_config['target_model']) + logger.info( + "Speculative config: %s", { + k: v + for k, v in self.speculative_config.items() if k not in + ["target_model_config", "target_parallel_config"] + }) else: - return None + return updated_model_config, None + + if self.speculative_config is None: + return updated_model_config, None # Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine # config. self.speculative_config.update({ - "target_model_config": target_model_config, + "target_model_config": updated_model_config, "target_parallel_config": target_parallel_config, "enable_chunked_prefill": enable_chunked_prefill, "disable_log_stats": disable_log_stats, }) - return SpeculativeConfig(**self.speculative_config) + return updated_model_config, SpeculativeConfig( + **self.speculative_config) def create_engine_config( self, @@ -1363,7 +1413,7 @@ def create_engine_config( decode_context_parallel_size=self.decode_context_parallel_size, ) - speculative_config = self.create_speculative_config( + model_config, speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, enable_chunked_prefill=self.enable_chunked_prefill, diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py index d7c16e180c70..6ff761405f93 100644 --- a/vllm/transformers_utils/configs/speculators/base.py +++ b/vllm/transformers_utils/configs/speculators/base.py @@ -24,20 +24,7 @@ def from_pretrained( config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - speculators_model_type = config_dict.get("speculators_model_type") - if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: - raise ValueError( - f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. " - "Please ensure you're loading a speculators-format model.") - - # validate fields - # TODO: @dsikka - use speculators pydantic model to validate - cls.validate_speculators_config(config_dict=config_dict) - # Convert from speculators config -> format that can be ingested by vLLM - vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict) - # Apply anything specific to the supported algorithm - algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] - algo_updater(config_dict=config_dict, vllm_config=vllm_config) + vllm_config = cls.get_vllm_config(config_dict=config_dict) return cls(**vllm_config) @classmethod @@ -59,15 +46,37 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: raise TypeError( "'transformer_layer_config' must be a dictionary if provided") + @classmethod + def get_vllm_config(cls, config_dict: dict[str, Any]) -> dict[str, Any]: + """ + Validate and convert speculators config dict to vLLM format. + + This method includes algorithm-specific processing and validation. + """ + speculators_model_type = config_dict.get("speculators_model_type") + if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: + raise ValueError( + f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. " + "Please ensure you're loading a speculators-format model.") + + # validate fields + cls.validate_speculators_config(config_dict=config_dict) + # Convert from speculators config -> format that can be ingested by vLLM + vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict) + # Apply anything specific to the supported algorithm + algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] + algo_updater(config_dict=config_dict, vllm_config=vllm_config) + return vllm_config + @classmethod def convert_speculators_to_vllm( cls, config_dict: dict[str, Any]) -> dict[str, Any]: """ Convert speculators config format to vLLM format. - + This method handles the translation of field names and structure between speculators and vLLM formats. - + Returns: Dictionary with vLLM-compatible configuration """