Commit 50d2ca6

Enable engine-level arguments with speculators models
This commit implements speculators-model detection in the engine layer, allowing users to apply engine arguments directly with the simplified syntax:

```bash
vllm serve --seed 42 --tensor-parallel-size 4 "speculators-model"
```

instead of the verbose explicit configuration:

```bash
vllm serve --seed 42 --tensor-parallel-size 4 "target-model" \
    --speculative-config '{"model": "speculators-model", "method": "eagle3", ...}'
```

## Key Changes

### Enhanced Engine Layer (`vllm/engine/arg_utils.py`)
- Modified `create_speculative_config()` to return a tuple of (ModelConfig, SpeculativeConfig)
- Added automatic speculators-model detection at model creation time
- Implemented proper model resolution: speculators model → target model
- Engine arguments are now correctly applied to the target model instead of the speculators model

### Complete Algorithm Processing (`vllm/transformers_utils/configs/speculators/base.py`)
- Added a `get_vllm_config()` method with full algorithm-specific processing
- Includes Eagle3 fields such as `draft_vocab_size` and `target_hidden_size`
- Leverages the existing validation and transformation infrastructure

## Benefits
- ✅ Proper architectural layering (the engine layer handles model configuration)
- ✅ Complete algorithm-specific field processing
- ✅ Backward compatibility (existing workflows unchanged)
- ✅ Simplified user experience
- ✅ Single point of truth for speculative-model logic

## Testing
- ✅ Speculators model: auto-detection and target-model resolution
- ✅ Regular model: no regression, normal serving unaffected
- ✅ Engine arguments correctly applied in both cases

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
1 parent 1b962e2 commit 50d2ca6

File tree

2 files changed (+86, -27 lines)


vllm/engine/arg_utils.py

Lines changed: 61 additions & 11 deletions
```diff
@@ -50,6 +50,8 @@
 
 # yapf: enable
 
+logger = init_logger(__name__)
+
 if TYPE_CHECKING:
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization import QuantizationMethods
```
```diff
@@ -1079,20 +1081,29 @@ def create_speculative_config(
         target_parallel_config: ParallelConfig,
         enable_chunked_prefill: bool,
         disable_log_stats: bool,
-    ) -> Optional["SpeculativeConfig"]:
+    ) -> tuple[ModelConfig, Optional["SpeculativeConfig"]]:
         """Initializes and returns a SpeculativeConfig object based on
         `speculative_config`.
 
         This function utilizes `speculative_config` to create a
         SpeculativeConfig object. The `speculative_config` can either be
         provided as a JSON string input via CLI arguments or directly as a
         dictionary from the engine.
+
+        Returns:
+            A tuple of (possibly updated model_config, speculative_config).
+            If a speculators model is detected, model_config is updated to
+            point to the target model and speculative_config is configured
+            with the draft model.
         """
+        from dataclasses import replace
 
         from vllm.transformers_utils.config import get_config
         from vllm.transformers_utils.configs.speculators.base import (
             SpeculatorsConfig)
 
+        updated_model_config = target_model_config
+
         if self.speculative_config is None:
             hf_config = get_config(
                 self.hf_config_path or target_model_config.model,
```
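The newly imported `dataclasses.replace` is what makes the model swap safe: it copies a dataclass instance with only the named fields changed, so engine arguments already baked into the config carry over to the resolved target model. A minimal sketch with a toy stand-in for vLLM's `ModelConfig` (the real class carries many more fields):

```python
from dataclasses import dataclass, replace


# Toy stand-in for vLLM's ModelConfig; the real class carries many more fields.
@dataclass
class ToyModelConfig:
    model: str
    seed: int = 0


cfg = ToyModelConfig(model="speculators-model", seed=42)
updated = replace(cfg, model="target-model")

assert cfg.model == "speculators-model"  # the original is untouched
assert updated.model == "target-model"   # only `model` was swapped
assert updated.seed == 42                # every other setting carries over
```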
```diff
@@ -1103,25 +1114,64 @@ def create_speculative_config(
             # details from the config directly
             # no user input required / expected
             if isinstance(hf_config, SpeculatorsConfig):
-                # We create one since we don't create one
-                self.speculative_config = {}
-                self.speculative_config[
-                    "num_speculative_tokens"] = hf_config.num_lookahead_tokens
-                self.speculative_config["model"] = target_model_config.model
-                self.speculative_config["method"] = hf_config.method
+                # Get the complete vLLM config with algorithm-specific fields
+                try:
+                    config_dict, _ = SpeculatorsConfig.get_config_dict(
+                        target_model_config.model)
+                    vllm_config = SpeculatorsConfig.get_vllm_config(
+                        config_dict)
+                except Exception as e:
+                    raise ValueError(
+                        f"Failed to process speculators model "
+                        f"'{target_model_config.model}': {e}") from e
+
+                # Update model config to point to actual target model
+                updated_model_config = replace(
+                    target_model_config, model=vllm_config["target_model"])
+
+                # Set up speculative config with original speculators model
+                # as draft
+                self.speculative_config = {
+                    "model":
+                    target_model_config.model,  # Original speculators model
+                    "num_speculative_tokens":
+                    vllm_config["num_lookahead_tokens"],
+                    "method": vllm_config["method"]
+                }
+
+                # Add all algorithm-specific fields
+                for key, value in vllm_config.items():
+                    if key not in [
+                            "target_model", "num_lookahead_tokens", "method"
+                    ]:
+                        self.speculative_config[key] = value
+
+                logger.info(
+                    "Detected speculators model. Using target model: %s",
+                    vllm_config['target_model'])
+                logger.info(
+                    "Speculative config: %s", {
+                        k: v
+                        for k, v in self.speculative_config.items() if k not in
+                        ["target_model_config", "target_parallel_config"]
+                    })
             else:
-                return None
+                return updated_model_config, None
+
+        if self.speculative_config is None:
+            return updated_model_config, None
 
         # Note(Shangming): These parameters are not obtained from the cli arg
         # '--speculative-config' and must be passed in when creating the engine
         # config.
         self.speculative_config.update({
-            "target_model_config": target_model_config,
+            "target_model_config": updated_model_config,
             "target_parallel_config": target_parallel_config,
             "enable_chunked_prefill": enable_chunked_prefill,
             "disable_log_stats": disable_log_stats,
         })
-        return SpeculativeConfig(**self.speculative_config)
+        return updated_model_config, SpeculativeConfig(
+            **self.speculative_config)
 
     def create_engine_config(
         self,
```
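To make the outcome of this block concrete, here is a hedged sketch of what `self.speculative_config` might hold for an Eagle3-style checkpoint. The model id and all values are invented; `draft_vocab_size` and `target_hidden_size` are the algorithm-specific fields the commit message calls out:

```python
# Hypothetical contents of self.speculative_config after the block above runs
# on an Eagle3-style speculators checkpoint; every value here is invented.
speculative_config = {
    "model": "example-org/eagle3-speculator",  # the original speculators id
    "num_speculative_tokens": 5,               # from num_lookahead_tokens
    "method": "eagle3",
    # algorithm-specific extras copied over by the loop above
    "draft_vocab_size": 32000,
    "target_hidden_size": 4096,
}
```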
```diff
@@ -1363,7 +1413,7 @@ def create_engine_config(
             decode_context_parallel_size=self.decode_context_parallel_size,
         )
 
-        speculative_config = self.create_speculative_config(
+        model_config, speculative_config = self.create_speculative_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
             enable_chunked_prefill=self.enable_chunked_prefill,
```

vllm/transformers_utils/configs/speculators/base.py

Lines changed: 25 additions & 16 deletions
```diff
@@ -24,20 +24,7 @@ def from_pretrained(
         config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
                                              **kwargs)
 
-        speculators_model_type = config_dict.get("speculators_model_type")
-        if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
-            raise ValueError(
-                f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
-                "Please ensure you're loading a speculators-format model.")
-
-        # validate fields
-        # TODO: @dsikka - use speculators pydantic model to validate
-        cls.validate_speculators_config(config_dict=config_dict)
-        # Convert from speculators config -> format that can be ingested by vLLM
-        vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
-        # Apply anything specific to the supported algorithm
-        algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
-        algo_updater(config_dict=config_dict, vllm_config=vllm_config)
+        vllm_config = cls.get_vllm_config(config_dict=config_dict)
         return cls(**vllm_config)
 
     @classmethod
```
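With the inline logic gone, loading collapses to a single delegation. A sketch of the resulting call path, assuming a hypothetical checkpoint id (this would attempt a real download if executed):

```python
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig

# The checkpoint id below is a placeholder, not a real model; loading resolves
# as: get_config_dict() -> get_vllm_config() -> cls(**vllm_config)
config = SpeculatorsConfig.from_pretrained("example-org/eagle3-speculator")
```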
```diff
@@ -59,15 +46,37 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
             raise TypeError(
                 "'transformer_layer_config' must be a dictionary if provided")
 
+    @classmethod
+    def get_vllm_config(cls, config_dict: dict[str, Any]) -> dict[str, Any]:
+        """
+        Validate and convert speculators config dict to vLLM format.
+
+        This method includes algorithm-specific processing and validation.
+        """
+        speculators_model_type = config_dict.get("speculators_model_type")
+        if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
+            raise ValueError(
+                f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
+                "Please ensure you're loading a speculators-format model.")
+
+        # validate fields
+        cls.validate_speculators_config(config_dict=config_dict)
+        # Convert from speculators config -> format that can be ingested by vLLM
+        vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
+        # Apply anything specific to the supported algorithm
+        algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
+        algo_updater(config_dict=config_dict, vllm_config=vllm_config)
+        return vllm_config
+
     @classmethod
     def convert_speculators_to_vllm(
             cls, config_dict: dict[str, Any]) -> dict[str, Any]:
         """
         Convert speculators config format to vLLM format.
-
+
         This method handles the translation of field names and structure
         between speculators and vLLM formats.
-
+
         Returns:
             Dictionary with vLLM-compatible configuration
         """
```
