Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@

# yapf: enable

logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QuantizationMethods
Expand Down Expand Up @@ -1079,20 +1081,29 @@ def create_speculative_config(
target_parallel_config: ParallelConfig,
enable_chunked_prefill: bool,
disable_log_stats: bool,
) -> Optional["SpeculativeConfig"]:
) -> tuple[ModelConfig, Optional["SpeculativeConfig"]]:
"""Initializes and returns a SpeculativeConfig object based on
`speculative_config`.

This function utilizes `speculative_config` to create a
SpeculativeConfig object. The `speculative_config` can either be
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine.

Returns:
A tuple of (possibly updated model_config, speculative_config).
If a speculators model is detected, model_config is updated to
point to the target model and speculative_config is configured
with the draft model.
"""
from dataclasses import replace

from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.configs.speculators.base import (
SpeculatorsConfig)

updated_model_config = target_model_config

if self.speculative_config is None:
hf_config = get_config(
self.hf_config_path or target_model_config.model,
Expand All @@ -1103,25 +1114,64 @@ def create_speculative_config(
# details from the config directly
# no user input required / expected
if isinstance(hf_config, SpeculatorsConfig):
# We create one since we don't create one
self.speculative_config = {}
self.speculative_config[
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
self.speculative_config["model"] = target_model_config.model
self.speculative_config["method"] = hf_config.method
# Get the complete vLLM config with algorithm-specific fields
try:
config_dict, _ = SpeculatorsConfig.get_config_dict(
target_model_config.model)
vllm_config = SpeculatorsConfig.get_vllm_config(
config_dict)
except Exception as e:
raise ValueError(
f"Failed to process speculators model "
f"'{target_model_config.model}': {e}") from e

# Update model config to point to actual target model
updated_model_config = replace(
target_model_config, model=vllm_config["target_model"])

# Set up speculative config with original speculators model
# as draft
self.speculative_config = {
"model":
target_model_config.model, # Original speculators model
"num_speculative_tokens":
vllm_config["num_lookahead_tokens"],
"method": vllm_config["method"]
}

# Add all algorithm-specific fields
for key, value in vllm_config.items():
if key not in [
"target_model", "num_lookahead_tokens", "method"
]:
self.speculative_config[key] = value

logger.info(
"Detected speculators model. Using target model: %s",
vllm_config['target_model'])
logger.info(
"Speculative config: %s", {
k: v
for k, v in self.speculative_config.items() if k not in
["target_model_config", "target_parallel_config"]
})
else:
return None
return updated_model_config, None

if self.speculative_config is None:
return updated_model_config, None

# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
# config.
self.speculative_config.update({
"target_model_config": target_model_config,
"target_model_config": updated_model_config,
"target_parallel_config": target_parallel_config,
"enable_chunked_prefill": enable_chunked_prefill,
"disable_log_stats": disable_log_stats,
})
return SpeculativeConfig(**self.speculative_config)
return updated_model_config, SpeculativeConfig(
**self.speculative_config)

def create_engine_config(
self,
Expand Down Expand Up @@ -1363,7 +1413,7 @@ def create_engine_config(
decode_context_parallel_size=self.decode_context_parallel_size,
)

speculative_config = self.create_speculative_config(
model_config, speculative_config = self.create_speculative_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
enable_chunked_prefill=self.enable_chunked_prefill,
Expand Down
41 changes: 25 additions & 16 deletions vllm/transformers_utils/configs/speculators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,7 @@ def from_pretrained(
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
**kwargs)

speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError(
f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
"Please ensure you're loading a speculators-format model.")

# validate fields
# TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
vllm_config = cls.get_vllm_config(config_dict=config_dict)
return cls(**vllm_config)

@classmethod
Expand All @@ -59,15 +46,37 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
raise TypeError(
"'transformer_layer_config' must be a dictionary if provided")

@classmethod
def get_vllm_config(cls, config_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Validate a speculators-format config dict and translate it into the
    dictionary shape that vLLM ingests.

    This runs the shared field validation, the generic speculators->vLLM
    conversion, and finally the algorithm-specific updater registered for
    the model type.
    """
    model_type = config_dict.get("speculators_model_type")
    # Reject anything that is not a recognized speculators algorithm early.
    if model_type not in SUPPORTED_SPECULATORS_TYPES:
        raise ValueError(
            f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
            "Please ensure you're loading a speculators-format model.")

    # Field-level validation happens before any conversion.
    # TODO: @dsikka - use speculators pydantic model to validate
    cls.validate_speculators_config(config_dict=config_dict)

    # Generic translation of speculators field names into vLLM's format,
    # followed by the per-algorithm hook that fills in its own fields.
    converted = cls.convert_speculators_to_vllm(config_dict=config_dict)
    updater = SUPPORTED_SPECULATORS_TYPES[model_type]
    updater(config_dict=config_dict, vllm_config=converted)
    return converted

@classmethod
def convert_speculators_to_vllm(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
"""
Convert speculators config format to vLLM format.

This method handles the translation of field names and structure
between speculators and vLLM formats.

Returns:
Dictionary with vLLM-compatible configuration
"""
Expand Down