From 030f994efeecbdceedc0ed4b8a4a845bd471bfef Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 6 Mar 2025 21:41:16 +0800 Subject: [PATCH 01/21] Initial commit Signed-off-by: Shangming Cai --- vllm/config.py | 506 +++++++++++-------------- vllm/engine/arg_utils.py | 93 +++-- vllm/spec_decode/spec_decode_worker.py | 8 +- 3 files changed, 300 insertions(+), 307 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a5d8ee9303d0..1c14dc7aaa8d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1768,12 +1768,107 @@ def __init__(self, device: str = "auto") -> None: self.device = torch.device(self.device_type) +@dataclass class SpeculativeConfig: - """Configuration for speculative decoding. - - The configuration is currently specialized to draft-model speculative - decoding with top-1 proposals. """ + Configuration for speculative decoding. + Configurable parameters include: + - Top-level Speculative Decoding Control: + - num_speculative_tokens (Optional[int]): The number of speculative + tokens, if provided. It will default to the number in the draft + model config if present, otherwise, it is required. + - proposer (Optional[str]): The name of the speculative method to use. + Defaults to the model name if not provided. + - acceptance_method (str): The method to use for accepting draft + tokens. This can take two possible values: 'rejection_sampler' and + 'typical_acceptance_sampler' for RejectionSampler and + TypicalAcceptanceSampler respectively. If not specified, it + defaults to 'rejection_sampler'. + - disable_logprobs (bool): If set to True, token log probabilities are + not returned during speculative decoding. If set to False, token + log probabilities are returned according to the log probability + settings in SamplingParams. If not specified, it defaults to True. + + - Model Configuration: + - model (Optional[str]): The name of the speculative model, + if provided. + - quantization (Optional[str]): Quantization method that was used to + quantize the speculative model weights. If None, we assume the + model weights are not quantized. + - max_model_len (Optional[int]): The maximum model length of the + speculative model. Used when testing the ability to skip + speculation for some sequences. + - draft_tensor_parallel_size (Optional[int]): The degree of the tensor + parallelism for the draft model. Can be 1 or match the target + model's tensor parallel size. + + - Advanced Token Control: + - disable_mqa_scorer (bool): Disable the MQA scorer for the speculative + model and fall back to batch expansion for scoring. If not + specified, it defaults to False. + - disable_by_batch_size (Optional[int]): Disable speculative decoding + for new incoming requests when the number of enqueued requests is + larger than this value, if provided. + - ngram_prompt_lookup_max (Optional[int]): Maximum size of ngram token + window when using Ngram proposer, if provided. + - ngram_prompt_lookup_min (Optional[int]): Minimum size of ngram token + window when using Ngram proposer, if provided. + - typical_acceptance_sampler_posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the posterior + probability of a token in the target model for it to be accepted. + This threshold is used only when we use the + TypicalAcceptanceSampler for token acceptance. + - typical_acceptance_sampler_posterior_alpha (Optional[float]): Scaling + factor for entropy-based threshold, applied when using + TypicalAcceptanceSampler. 
+ + Non-configurable internal parameters include: + - Model Configuration: + - target_model_config (ModelConfig): The configuration of the target + model. + - draft_model_config (ModelConfig): The configuration of the draft + model initialized internal. + - Parallelism Configuration: + - target_parallel_config (ParallelConfig): The parallel configuration + for the target model. + - draft_parallel_config (ParallelConfig): The parallel configuration + for the draft model initialized internal. + - Execution Control: + - enable_chunked_prefill (bool): Whether vLLM is configured to use + chunked prefill or not. Used for raising an error since it's not + yet compatible with speculative decode. + - disable_log_stats (bool): Whether to disable the periodic printing of + stage times in speculative decoding. + """ + # speculative configs from cli args + num_speculative_tokens: int + model: Optional[str] = None + proposer: Optional[str] = None + quantization: Optional[str] = None + max_model_len: Optional[int] = None + revision: Optional[str] = None + code_revision: Optional[str] = None + draft_tensor_parallel_size: Optional[int] = None + disable_mqa_scorer: bool = False + disable_by_batch_size: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None + acceptance_method: str = "rejection_sampler" + typical_acceptance_sampler_posterior_threshold: Optional[float] = None + typical_acceptance_sampler_posterior_alpha: Optional[float] = None + disable_logprobs: bool = True + + # required configuration params passed from engine + target_model_config: ModelConfig + target_parallel_config: ParallelConfig + enable_chunked_prefill: bool + disable_log_stats: bool + + # params generated in the post-init stage + draft_model_config: ModelConfig = field(default=None, + init=True) # type: ignore + draft_parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore def compute_hash(self) -> str: """ @@ -1793,6 +1888,11 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str + @classmethod + def from_dict(cls, dict_value: dict) -> "SpeculativeConfig": + """Parse the CLI value for the speculative config.""" + return cls(**dict_value) + @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: if hf_config.model_type == "deepseek_v3": @@ -1805,231 +1905,137 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: }) return hf_config - @staticmethod - def maybe_create_spec_config( - target_model_config: ModelConfig, - target_parallel_config: ParallelConfig, - target_dtype: str, - speculative_model: Optional[str], - speculative_model_quantization: Optional[str], - speculative_draft_tensor_parallel_size: Optional[int], - num_speculative_tokens: Optional[int], - speculative_disable_mqa_scorer: Optional[bool], - speculative_max_model_len: Optional[int], - enable_chunked_prefill: bool, - disable_log_stats: bool, - speculative_disable_by_batch_size: Optional[int], - ngram_prompt_lookup_max: Optional[int], - ngram_prompt_lookup_min: Optional[int], - draft_token_acceptance_method: str, - typical_acceptance_sampler_posterior_threshold: Optional[float], - typical_acceptance_sampler_posterior_alpha: Optional[float], - disable_logprobs: Optional[bool], - ) -> Optional["SpeculativeConfig"]: - """Create a SpeculativeConfig if possible, else return None. - - This function attempts to create a SpeculativeConfig object based on the - provided parameters. 
If the necessary conditions are met, it returns an - instance of SpeculativeConfig. Otherwise, it returns None. - - Args: - target_model_config (ModelConfig): The configuration of the target - model. - target_parallel_config (ParallelConfig): The parallel configuration - for the target model. - target_dtype (str): The data type used for the target model. - speculative_model (Optional[str]): The name of the speculative - model, if provided. - speculative_model_quantization (Optional[str]): Quantization method - that was used to quantize the speculative model weights. If - None, we assume the model weights are not quantized. - speculative_draft_tensor_parallel_size (Optional[int]): The degree - of the tensor parallelism for the draft model. - num_speculative_tokens (Optional[int]): The number of speculative - tokens, if provided. Will default to the number in the draft - model config if present, otherwise is required. - speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA - scorer for the speculative model and fall back to batch - expansion for scoring. - speculative_max_model_len (Optional[int]): The maximum model len of - the speculative model. Used when testing the ability to skip - speculation for some sequences. - enable_chunked_prefill (bool): Whether vLLM is configured to use - chunked prefill or not. Used for raising an error since its not - yet compatible with spec decode. - speculative_disable_by_batch_size (Optional[int]): Disable - speculative decoding for new incoming requests when the number - of enqueue requests is larger than this value, if provided. - ngram_prompt_lookup_max (Optional[int]): Max size of ngram token - window, if provided. - ngram_prompt_lookup_min (Optional[int]): Min size of ngram token - window, if provided. - draft_token_acceptance_method (str): The method to use for - accepting draft tokens. This can take two possible - values 'rejection_sampler' and 'typical_acceptance_sampler' - for RejectionSampler and TypicalAcceptanceSampler - respectively. - typical_acceptance_sampler_posterior_threshold (Optional[float]): - A threshold value that sets a lower bound on the posterior - probability of a token in the target model for it to be - accepted. This threshold is used only when we use the - TypicalAcceptanceSampler for token acceptance. - typical_acceptance_sampler_posterior_alpha (Optional[float]): - A scaling factor for the entropy-based threshold in the - TypicalAcceptanceSampler. - disable_logprobs (Optional[bool]): If set to True, token log - probabilities are not returned during speculative decoding. - If set to False, token log probabilities are returned - according to the log probability settings in SamplingParams. - If not specified, it defaults to True. - - Returns: - Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if - the necessary conditions are met, else None. - """ - if speculative_model is None: - if num_speculative_tokens is not None: - if target_model_config.hf_text_config.model_type \ - == "deepseek_v3": - # use the draft model from the same model: - speculative_model = target_model_config.model - else: - raise ValueError( - "num_speculative_tokens was provided without " - "speculative_model.") + def __post_init__(self): + if self.proposer is None and self.model is not None: + # Note: After next release, the proposer parameter will be used + # to specify the speculative method, and the model parameter will + # be used when the draft model or head is needed. 
If users do not + # specify a proposer, the speculative method will be considered as + # the draft model by default. + self.proposer = self.model + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + # mtp acceleration for more models besides deepseek_v3 + if (self.proposer == "mtp" + or self.target_model_config.hf_text_config.model_type + == "deepseek_v3"): + # use the draft model from the same model: + self.model = self.target_model_config.model else: - return None - - if (speculative_disable_by_batch_size is not None - and speculative_disable_by_batch_size < 2): - raise ValueError("Expect the batch size threshold of disabling " - "speculative decoding is > 1, but got " - f"{speculative_disable_by_batch_size=}") - if (enable_chunked_prefill and speculative_model == "eagle"): - raise ValueError("Chunked prefill and EAGLE are not compatible.") - # TODO: The user should be able to specify revision/max model len - # for the draft model. It is not currently supported. - draft_revision = None - draft_code_revision = None - draft_quantization = speculative_model_quantization - - if speculative_model == "[ngram]": - if ngram_prompt_lookup_min is None: - ngram_prompt_lookup_min = 1 - if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1: - raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0") - if ngram_prompt_lookup_min < 1: - raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") - if ngram_prompt_lookup_min > ngram_prompt_lookup_max: - raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " - f"larger than {ngram_prompt_lookup_max=}") + raise ValueError("num_speculative_tokens was provided without " + "speculative model.") + + if self.proposer in ["ngram", "[ngram]"]: + if self.ngram_prompt_lookup_min is None: + self.ngram_prompt_lookup_min = 1 + if (self.ngram_prompt_lookup_max is None + or self.ngram_prompt_lookup_max < 1): + raise ValueError("ngram_prompt_lookup_max=" + f"{self.ngram_prompt_lookup_max} must be > 0") + if self.ngram_prompt_lookup_min < 1: + raise ValueError("ngram_prompt_lookup_min=" + f"{self.ngram_prompt_lookup_min} must be > 0") + if self.ngram_prompt_lookup_min > self.ngram_prompt_lookup_max: + raise ValueError( + f"ngram_prompt_lookup_min={self.ngram_prompt_lookup_min} " + "cannot be larger than ngram_prompt_lookup_max=" + f"{self.ngram_prompt_lookup_max}") # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set # draft related config as None here. - draft_model_config = target_model_config - draft_parallel_config = target_parallel_config + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config else: - ngram_prompt_lookup_max = 0 - ngram_prompt_lookup_min = 0 - draft_model_config = ModelConfig( - model=speculative_model, + self.ngram_prompt_lookup_max = 0 + self.ngram_prompt_lookup_min = 0 + + self.draft_model_config = ModelConfig( + model=self.model, task="draft", - tokenizer=target_model_config.tokenizer, - tokenizer_mode=target_model_config.tokenizer_mode, - trust_remote_code=target_model_config.trust_remote_code, - allowed_local_media_path=target_model_config. + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config.trust_remote_code, + allowed_local_media_path=self.target_model_config. 
                 allowed_local_media_path,
-            dtype=target_model_config.dtype,
-            seed=target_model_config.seed,
-            revision=draft_revision,
-            code_revision=draft_code_revision,
-            tokenizer_revision=target_model_config.tokenizer_revision,
+                dtype=self.target_model_config.dtype,
+                seed=self.target_model_config.seed,
+                revision=self.revision,
+                code_revision=self.code_revision,
+                tokenizer_revision=self.target_model_config.tokenizer_revision,
                 max_model_len=None,
-            spec_target_max_model_len=target_model_config.max_model_len,
-            quantization=draft_quantization,
-            enforce_eager=target_model_config.enforce_eager,
-            max_seq_len_to_capture=target_model_config.
+                spec_target_max_model_len=self.target_model_config.
+                max_model_len,
+                quantization=self.quantization,
+                enforce_eager=self.target_model_config.enforce_eager,
+                max_seq_len_to_capture=self.target_model_config.
                 max_seq_len_to_capture,
-            max_logprobs=target_model_config.max_logprobs,
+                max_logprobs=self.target_model_config.max_logprobs,
                 hf_overrides=SpeculativeConfig.hf_config_override,
             )

-        draft_hf_config = draft_model_config.hf_config
+            # Detect proposer type or EAGLE prefix to replace hf_config for
+            # EAGLE draft_model
+            if (self.proposer == "eagle"
+                    or "eagle-" in self.draft_model_config.model.lower()):
+                if self.enable_chunked_prefill:
+                    raise ValueError(
+                        "Chunked prefill and EAGLE are not compatible.")

-        # Detect EAGLE prefix to replace hf_config for EAGLE draft_model
-        if "eagle-" in draft_model_config.model.lower():
             from vllm.transformers_utils.configs.eagle import EAGLEConfig
-            if isinstance(draft_model_config.hf_config, EAGLEConfig):
+            if isinstance(self.draft_model_config.hf_config, EAGLEConfig):
                 pass
             else:
-                eagle_config = EAGLEConfig(draft_model_config.hf_config)
-                draft_model_config.hf_config = eagle_config
-
-        if (num_speculative_tokens is not None
-                and hasattr(draft_hf_config, "num_lookahead_tokens")):
-            draft_hf_config.num_lookahead_tokens = num_speculative_tokens
-
-        n_predict = getattr(draft_hf_config, "n_predict", None)
+                eagle_config = EAGLEConfig(
+                    self.draft_model_config.hf_config)
+                self.draft_model_config.hf_config = eagle_config
+
+        if (self.num_speculative_tokens is not None
+                and hasattr(self.draft_model_config.hf_config,
+                            "num_lookahead_tokens")):
+            self.draft_model_config.hf_config.num_lookahead_tokens = \
+                self.num_speculative_tokens
+
+        n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
+                            None)
         if n_predict is not None:
-            if num_speculative_tokens is None:
+            if self.num_speculative_tokens is None:
                 # Default to max value defined in draft model config.
-                num_speculative_tokens = n_predict
-            elif num_speculative_tokens > n_predict:
-                # Verify provided value doesn't exceed the maximum
-                # supported by the draft model.
+                self.num_speculative_tokens = n_predict
+            elif self.num_speculative_tokens > n_predict and \
+                    self.num_speculative_tokens % n_predict != 0:
+                # Ensure divisibility for MTP module reuse.
raise ValueError( - "This speculative model supports a maximum of " - f"num_speculative_tokens={n_predict}, but " - f"{num_speculative_tokens=} was provided.") - - speculative_draft_tensor_parallel_size = \ - SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( - target_parallel_config, - speculative_draft_tensor_parallel_size, - draft_hf_config + f"{self.num_speculative_tokens=} must be divisible by " + f"{n_predict=}") + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config ) - draft_model_config.max_model_len = ( + self.draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( - speculative_max_model_len, - draft_model_config.max_model_len, - target_model_config.max_model_len, + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, )) - draft_parallel_config = ( + self.draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( - target_parallel_config, - speculative_draft_tensor_parallel_size, draft_hf_config)) + self.target_parallel_config, + self.draft_tensor_parallel_size)) - if num_speculative_tokens is None: - raise ValueError( - "num_speculative_tokens must be provided with " - "speculative_model unless the draft model config contains an " - "n_predict parameter.") + if self.acceptance_method == "typical_acceptance_sampler": + if self.typical_acceptance_sampler_posterior_threshold is None: + self.typical_acceptance_sampler_posterior_threshold = 0.09 + if self.typical_acceptance_sampler_posterior_alpha is None: + self.typical_acceptance_sampler_posterior_alpha = 0.3 - if typical_acceptance_sampler_posterior_threshold is None: - typical_acceptance_sampler_posterior_threshold = 0.09 - if typical_acceptance_sampler_posterior_alpha is None: - typical_acceptance_sampler_posterior_alpha = 0.3 - if disable_logprobs is None: - disable_logprobs = True - - return SpeculativeConfig( - draft_model_config, - draft_parallel_config, - num_speculative_tokens, - speculative_disable_mqa_scorer, - speculative_disable_by_batch_size, - ngram_prompt_lookup_max, - ngram_prompt_lookup_min, - draft_token_acceptance_method=draft_token_acceptance_method, - typical_acceptance_sampler_posterior_threshold=\ - typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=\ - typical_acceptance_sampler_posterior_alpha, - disable_logprobs=disable_logprobs, - disable_log_stats=disable_log_stats, - ) + self._verify_args() @staticmethod def _maybe_override_draft_max_model_len( @@ -2067,7 +2073,7 @@ def _maybe_override_draft_max_model_len( ) @staticmethod - def _verify_and_get_draft_model_tensor_parallel_size( + def _verify_and_get_draft_tp( target_parallel_config: ParallelConfig, speculative_draft_tensor_parallel_size: Optional[int], draft_hf_config: PretrainedConfig) -> int: @@ -2099,7 +2105,6 @@ def _verify_and_get_draft_model_tensor_parallel_size( def create_draft_parallel_config( target_parallel_config: ParallelConfig, speculative_draft_tensor_parallel_size: int, - draft_hf_config: PretrainedConfig, ) -> ParallelConfig: """Create a parallel config for use by the draft worker. 
@@ -2123,74 +2128,13 @@ def create_draft_parallel_config( return draft_parallel_config - def __init__( - self, - draft_model_config: ModelConfig, - draft_parallel_config: ParallelConfig, - num_speculative_tokens: int, - speculative_disable_mqa_scorer: Optional[bool], - speculative_disable_by_batch_size: Optional[int], - ngram_prompt_lookup_max: Optional[int], - ngram_prompt_lookup_min: Optional[int], - draft_token_acceptance_method: str, - typical_acceptance_sampler_posterior_threshold: float, - typical_acceptance_sampler_posterior_alpha: float, - disable_logprobs: bool, - disable_log_stats: bool, - ): - """Create a SpeculativeConfig object. - - Args: - draft_model_config: ModelConfig for the draft model. - draft_parallel_config: ParallelConfig for the draft model. - num_speculative_tokens: The number of tokens to sample from the - draft model before scoring with the target model. - speculative_disable_by_batch_size: Disable speculative - decoding for new incoming requests when the number of - enqueue requests is larger than this value. - ngram_prompt_lookup_max: Max size of ngram token window. - ngram_prompt_lookup_min: Min size of ngram token window. - draft_token_acceptance_method (str): The method to use for - accepting draft tokens. This can take two possible - values 'rejection_sampler' and 'typical_acceptance_sampler' - for RejectionSampler and TypicalAcceptanceSampler - respectively. - typical_acceptance_sampler_posterior_threshold (Optional[float]): - A threshold value that sets a lower bound on the posterior - probability of a token in the target model for it to be - accepted. This threshold is used only when we use the - TypicalAcceptanceSampler for token acceptance. - typical_acceptance_sampler_posterior_alpha (Optional[float]): - A scaling factor for the entropy-based threshold in the - TypicalAcceptanceSampler. - disable_logprobs: If set to True, token log probabilities will not - be returned even if requested by sampling parameters. This - reduces latency by skipping logprob calculation in proposal - sampling, target sampling, and after accepted tokens are - determined. If set to False, log probabilities will be - returned. - disable_log_stats: Whether to disable periodic printing of stage - times in speculative decoding. 
- """ - self.draft_model_config = draft_model_config - self.draft_parallel_config = draft_parallel_config - self.num_speculative_tokens = num_speculative_tokens - self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer - self.speculative_disable_by_batch_size = \ - speculative_disable_by_batch_size - self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 - self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 - self.draft_token_acceptance_method = draft_token_acceptance_method - self.typical_acceptance_sampler_posterior_threshold = \ - typical_acceptance_sampler_posterior_threshold - self.typical_acceptance_sampler_posterior_alpha = \ - typical_acceptance_sampler_posterior_alpha - self.disable_logprobs = disable_logprobs - self.disable_log_stats = disable_log_stats - - self._verify_args() - def _verify_args(self) -> None: + if self.num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative model unless the draft model config contains an " + "n_predict parameter.") + if self.num_speculative_tokens <= 0: raise ValueError("Expected num_speculative_tokens to be greater " f"than zero ({self.num_speculative_tokens}).") @@ -2200,18 +2144,17 @@ def _verify_args(self) -> None: self.draft_parallel_config) # Validate and set draft token acceptance related settings. - if (self.draft_token_acceptance_method is None): - raise ValueError("draft_token_acceptance_method is not set. " + if (self.acceptance_method is None): + raise ValueError("acceptance_method is not set. " "Expected values are rejection_sampler or " "typical_acceptance_sampler.") - if (self.draft_token_acceptance_method != 'rejection_sampler' - and self.draft_token_acceptance_method - != 'typical_acceptance_sampler'): + if (self.acceptance_method != 'rejection_sampler' + and self.acceptance_method != 'typical_acceptance_sampler'): raise ValueError( - "Expected draft_token_acceptance_method to be either " + "Expected acceptance_method to be either " "rejection_sampler or typical_acceptance_sampler. 
Instead it " - f"is {self.draft_token_acceptance_method}") + f"is {self.acceptance_method}") if (self.typical_acceptance_sampler_posterior_threshold < 0 or self.typical_acceptance_sampler_posterior_alpha < 0): @@ -2224,6 +2167,12 @@ def _verify_args(self) -> None: f"typical_acceptance_sampler_posterior_alpha = " f"{self.typical_acceptance_sampler_posterior_alpha}") + if (self.disable_by_batch_size is not None + and self.disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}") + @property def num_lookahead_slots(self) -> int: """The number of additional slots the scheduler should allocate per @@ -3206,7 +3155,8 @@ class VllmConfig: init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore lora_config: Optional[LoRAConfig] = None - speculative_config: Optional[SpeculativeConfig] = None + speculative_config: SpeculativeConfig = field(default=None, + init=True) # type: ignore decoding_config: Optional[DecodingConfig] = None observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 26d4a84b841c..252302c16789 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -174,7 +174,10 @@ class EngineArgs: guided_decoding_backend: str = 'xgrammar' logits_processor_pattern: Optional[str] = None - # Speculative decoding configuration. + + speculative_config: Optional[Dict[str, Any]] = None + + # TODO(Shangming): Deprecate these out-of-date params after next release speculative_model: Optional[str] = None speculative_model_quantization: Optional[str] = None speculative_draft_tensor_parallel_size: Optional[int] = None @@ -187,9 +190,9 @@ class EngineArgs: spec_decoding_acceptance_method: str = 'rejection_sampler' typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None - qlora_adapter_name_or_path: Optional[str] = None disable_logprobs_during_spec_decoding: Optional[bool] = None + qlora_adapter_name_or_path: Optional[str] = None show_hidden_metrics_for_version: Optional[str] = None otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None @@ -756,7 +759,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: const="True", help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens.') - + parser.add_argument('--speculative-config', + type=json.loads, + default=None, + help='The configurations for speculative decoding.' + ' Should be a JSON string.') parser.add_argument( '--speculative-model', type=nullable_str, @@ -1136,6 +1143,64 @@ def create_load_config(self) -> LoadConfig: ignore_patterns=self.ignore_patterns, ) + def create_speculative_config( + self, + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + enable_chunked_prefill: bool, + disable_log_stats: bool, + ) -> Optional["SpeculativeConfig"]: + + if self.speculative_config is None: + if (self.speculative_model is None + and self.num_speculative_tokens is None): + return None + + # TODO(Shangming): Deprecate this way of setting SpeculativeConfig, + # only allow '--speculative-config' after next release + logger.warning_once( + "Please use '--speculative-config' to set all configurations " + "related to speculative decoding. 
The current method of " + "specifying the model through '--speculative-model' and " + "adding related parameters (e.g., '--num-speculative-tokens') " + "separately will be deprecated in the next release.") + + spec_config_dict = { + "model": self.speculative_model, + "quantization": self.speculative_model_quantization, + "max_model_len": self.speculative_max_model_len, + "draft_tensor_parallel_size": + self.speculative_draft_tensor_parallel_size, + "num_speculative_tokens": self.num_speculative_tokens, + "disable_mqa_scorer": self.speculative_disable_mqa_scorer, + "disable_by_batch_size": + self.speculative_disable_by_batch_size, + "ngram_prompt_lookup_max": self.ngram_prompt_lookup_max, + "ngram_prompt_lookup_min": self.ngram_prompt_lookup_min, + "acceptance_method": self.spec_decoding_acceptance_method, + "typical_acceptance_sampler_posterior_threshold": + self.typical_acceptance_sampler_posterior_threshold, + "typical_acceptance_sampler_posterior_alpha": + self.typical_acceptance_sampler_posterior_alpha, + "disable_logprobs": self.disable_logprobs_during_spec_decoding, + } + + self.speculative_config = spec_config_dict + + # Note(Shangming): These parameters are not obtained from the cli arg + # '--speculative-config' and must be passed in when creating the engine + # config. + self.speculative_config.update({ + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + }) + speculative_config = SpeculativeConfig.from_dict( + self.speculative_config) + + return speculative_config + def create_engine_config(self, usage_context: Optional[UsageContext] = None ) -> VllmConfig: @@ -1226,31 +1291,11 @@ def create_engine_config(self, msg = "Chunked prefill is not supported for pooling models" raise ValueError(msg) - speculative_config = SpeculativeConfig.maybe_create_spec_config( + speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, - target_dtype=self.dtype, - speculative_model=self.speculative_model, - speculative_model_quantization = \ - self.speculative_model_quantization, - speculative_draft_tensor_parallel_size = \ - self.speculative_draft_tensor_parallel_size, - num_speculative_tokens=self.num_speculative_tokens, - speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, - speculative_disable_by_batch_size=self. - speculative_disable_by_batch_size, - speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, disable_log_stats=self.disable_log_stats, - ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, - ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, - draft_token_acceptance_method=\ - self.spec_decoding_acceptance_method, - typical_acceptance_sampler_posterior_threshold=self. - typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=self. 
- typical_acceptance_sampler_posterior_alpha, - disable_logprobs=self.disable_logprobs_during_spec_decoding, ) # Reminder: Please update docs/source/features/compatibility_matrix.md diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 871a3aee6306..79efabdae10a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -99,11 +99,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, - disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer, - disable_by_batch_size=speculative_config. - speculative_disable_by_batch_size, - draft_token_acceptance_method=speculative_config. - draft_token_acceptance_method, + disable_mqa_scorer=speculative_config.disable_mqa_scorer, + disable_by_batch_size=speculative_config.disable_by_batch_size, + draft_token_acceptance_method=speculative_config.acceptance_method, typical_acceptance_sampler_posterior_threshold=speculative_config. typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. From a61726d1f17759517090b57fdb9878ff67df8651 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 7 Mar 2025 14:20:15 +0800 Subject: [PATCH 02/21] fix verify Signed-off-by: Shangming Cai --- vllm/config.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 1c14dc7aaa8d..0ae0cb424f18 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1859,10 +1859,13 @@ class SpeculativeConfig: disable_logprobs: bool = True # required configuration params passed from engine - target_model_config: ModelConfig - target_parallel_config: ParallelConfig - enable_chunked_prefill: bool - disable_log_stats: bool + target_model_config: ModelConfig = field(default=None, + init=True) # type: ignore + target_parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore + enable_chunked_prefill: bool = field(default=None, + init=True) # type: ignore + disable_log_stats: bool = field(default=None, init=True) # type: ignore # params generated in the post-init stage draft_model_config: ModelConfig = field(default=None, @@ -2156,16 +2159,17 @@ def _verify_args(self) -> None: "rejection_sampler or typical_acceptance_sampler. Instead it " f"is {self.acceptance_method}") - if (self.typical_acceptance_sampler_posterior_threshold < 0 - or self.typical_acceptance_sampler_posterior_alpha < 0): - raise ValueError( - "Expected typical_acceptance_sampler_posterior_threshold " - "and typical_acceptance_sampler_posterior_alpha to be > 0. " - "Instead found " - f"typical_acceptance_sampler_posterior_threshold = " - f"{self.typical_acceptance_sampler_posterior_threshold} and " - f"typical_acceptance_sampler_posterior_alpha = " - f"{self.typical_acceptance_sampler_posterior_alpha}") + if self.acceptance_method == "typical_acceptance_sampler": + if (self.typical_acceptance_sampler_posterior_threshold < 0 + or self.typical_acceptance_sampler_posterior_alpha < 0): + raise ValueError( + "Expected typical_acceptance_sampler_posterior_threshold " + "and typical_acceptance_sampler_posterior_alpha to be > 0." 
+ " Instead found " + f"typical_acceptance_sampler_posterior_threshold = " + f"{self.typical_acceptance_sampler_posterior_threshold} " + f"and typical_acceptance_sampler_posterior_alpha = " + f"{self.typical_acceptance_sampler_posterior_alpha}") if (self.disable_by_batch_size is not None and self.disable_by_batch_size < 2): From 8fe1dc0ca9980382004471c17925a823e7e3e0f4 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 7 Mar 2025 14:57:55 +0800 Subject: [PATCH 03/21] fix cli Signed-off-by: Shangming Cai --- vllm/engine/arg_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 252302c16789..bce6ca090325 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -175,7 +175,7 @@ class EngineArgs: guided_decoding_backend: str = 'xgrammar' logits_processor_pattern: Optional[str] = None - speculative_config: Optional[Dict[str, Any]] = None + speculative_config: Optional[Union[str, Dict[str, Any]]] = None # TODO(Shangming): Deprecate these out-of-date params after next release speculative_model: Optional[str] = None @@ -760,7 +760,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens.') parser.add_argument('--speculative-config', - type=json.loads, + type=nullable_str, default=None, help='The configurations for speculative decoding.' ' Should be a JSON string.') @@ -1186,7 +1186,11 @@ def create_speculative_config( } self.speculative_config = spec_config_dict - + else: + if isinstance(self.speculative_config, str): + import ast + self.speculative_config = ast.literal_eval( + self.speculative_config) # Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine # config. 
From 2332547306856b83f5679be5d0c89c62a4c094fd Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 7 Mar 2025 17:24:16 +0800 Subject: [PATCH 04/21] modify e2e tests Signed-off-by: Shangming Cai --- tests/spec_decode/e2e/test_compatibility.py | 29 ++- .../spec_decode/e2e/test_eagle_correctness.py | 97 +++++---- tests/spec_decode/e2e/test_integration.py | 44 ++-- .../e2e/test_integration_dist_tp2.py | 77 +++---- .../e2e/test_integration_dist_tp4.py | 28 ++- tests/spec_decode/e2e/test_logprobs.py | 190 ++++++++++-------- .../e2e/test_medusa_correctness.py | 95 +++++---- tests/spec_decode/e2e/test_mlp_correctness.py | 89 ++++---- tests/spec_decode/e2e/test_mtp_correctness.py | 64 +++--- .../e2e/test_multistep_correctness.py | 177 ++++++++++------ .../spec_decode/e2e/test_ngram_correctness.py | 154 ++++++++------ tests/spec_decode/e2e/test_seed.py | 10 +- vllm/config.py | 36 ++-- 13 files changed, 631 insertions(+), 459 deletions(-) diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 83d1551afe5a..4fd52cf7e2cb 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -7,28 +7,39 @@ from .conftest import get_output_from_llm_generator -@pytest.mark.parametrize("common_llm_kwargs", [{ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, -}]) +@pytest.mark.parametrize("common_llm_kwargs", + [{ + "model": "meta-llama/Llama-3.2-1B-Instruct", + }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ { # Speculative max model len > overridden max model len should raise. + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 129, + }, "max_model_len": 128, - "speculative_max_model_len": 129, }, { # Speculative max model len > draft max model len should raise. # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 - "speculative_max_model_len": 2048 + 1, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 2048 + 1, + }, }, { # Speculative max model len > target max model len should raise. 
- # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18 - "speculative_max_model_len": 131072 + 1, + # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18 + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 131072 + 1, + }, }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index 42a84071d94d..eee535a146f4 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -57,8 +57,10 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_model": SPEC_MODEL, +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": False, + "disable_logprobs": False, }, - { - "speculative_model": SPEC_MODEL, +}, { + "speculative_config": { + "model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": True, + "disable_logprobs": True, }, -]) +}]) @pytest.mark.parametrize("output_len", [ 128, ]) @@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, batch_size: int, output_len: int, seed: int, logprobs: int): - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { 
- "speculative_model": SPEC_MODEL, - "num_speculative_tokens": k, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, } # Try a range of num. speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "yuhuili/EAGLE-llama2-chat-7B", - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": "yuhuili/EAGLE-llama2-chat-7B", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct", - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": "yuhuili/EAGLE-Qwen2-7B-Instruct", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index c67fa85146c6..6d6106f05db5 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -23,8 +23,10 @@ [ { # Identical models. - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, }, ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. 
"enforce_eager": True, }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", - "num_speculative_tokens": 5, - }, -]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", []) @pytest.mark.parametrize( "test_llm_kwargs", [ # Explicitly specify draft model quantization { - "speculative_model_quantization": "gptq", + "speculative_config": { + "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", + "num_speculative_tokens": 5, + "quantization": "gptq", + }, }, # Explicitly specify GPTQ-based draft model to use marlin quantization { - "speculative_model_quantization": "marlin", + "speculative_config": { + "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", + "num_speculative_tokens": 5, + "quantization": "marlin", + }, }, # Not explicitly specify draft model quantization { - "speculative_model_quantization": None, + "speculative_config": { + "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", + "num_speculative_tokens": 5, + "quantization": None, + }, }, ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index e5a542b6d84c..3729629dace7 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -27,18 +27,19 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("test_llm_kwargs", [ [ - "--speculative-model", - "JackFram/llama-68m", - "--num-speculative-tokens", - "3", + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }), ], [ - "--speculative-model", - "[ngram]", - "--num-speculative-tokens", - "5", - "--ngram-prompt-lookup-max", - "3", + "--speculative_config", + str({ + "model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }), ], ]) @pytest.mark.parametrize("batch_size", [2]) @@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, ]]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize("model, test_llm_kwargs", - [("JackFram/llama-68m", [ - "--speculative-model", - "JackFram/llama-68m", - "--num_speculative-tokens", - "5", - "--speculative-draft-tensor-parallel-size", - "1", - ]), - ("ibm-granite/granite-3b-code-instruct", [ - "--speculative-model", - "ibm-granite/granite-3b-code-instruct", - "--num_speculative-tokens", - "5", - "--speculative-draft-tensor-parallel-size", - "1", - ])]) +@pytest.mark.parametrize( + "model, test_llm_kwargs", + [("JackFram/llama-68m", [ + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + 
"draft_tensor_parallel_size": 1, + }), + ]), + ("ibm-granite/granite-3b-code-instruct", [ + "--speculative_config", + str({ + "model": "ibm-granite/granite-3b-code-instruct", + "num_speculative_tokens": 5, + "draft_tensor_parallel_size": 1, + }), + ])]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seed", [1]) def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, @@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("model, test_llm_kwargs", [("JackFram/llama-68m", [ - "--speculative-model", - "JackFram/llama-68m", - "--num_speculative-tokens", - "3", + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }), ]), ("JackFram/llama-68m", [ - "--speculative-model", - "JackFram/llama-68m", - "--num_speculative-tokens", - "3", - "--speculative-draft-tensor-parallel-size", - "1", + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "draft_tensor_parallel_size": 1, + }), ])]) @pytest.mark.parametrize("logprobs", [None, 2]) @pytest.mark.parametrize("batch_size", [2]) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index cb9c46dc7071..d42d9029fef6 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -24,12 +24,7 @@ "4", ]]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ - [ - "--speculative-model", - f"{SPEC_MODEL}", - "--num-speculative-tokens", - "5", - ], + [], ]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize( @@ -37,8 +32,12 @@ [ #TODO(wooyeon): add spec_draft_dp=2 case [ - "--speculative-draft-tensor-parallel-size", - "1", + "--speculative_config", + str({ + "model": f"{SPEC_MODEL}", + "num_speculative_tokens": 5, + "draft_tensor_parallel_size": 1, + }), ], ]) @pytest.mark.parametrize("batch_size", [2]) @@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, "test_llm_kwargs", [ [ - "--speculative-model", - f"{SPEC_MODEL}", - "--num-speculative-tokens", - "5", - # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. 
- "--speculative-max-model-len", - "32", + "--speculative_config", + str({ + "model": f"{SPEC_MODEL}", + "num_speculative_tokens": 5, + "max_model_len": 32, + }), ], ]) @pytest.mark.parametrize("batch_size", [8]) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 5991a8b02353..cb2dae541411 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -20,16 +20,19 @@ }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, +}, { + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": True, + }, +}]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( "output_len", @@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, as well as with and without chunked prefill. """ maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - }, { - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 6, - "disable_logprobs_during_spec_decoding": False, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, +}, { + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 6, + "disable_logprobs": False, + }, +}]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( "output_len", @@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, output_len: int, seed: int, logprobs: int): """Veriy logprob greedy equality with different speculation lens. 
""" - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize( "test_llm_kwargs", [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - - # Artificially limit the draft model max model len; this forces vLLM - # to skip speculation once the sequences grow beyond 32-k tokens. - "speculative_max_model_len": 32, + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + # Artificially limit the draft model max model len; this forces + # vLLM to skip speculation once the sequences grow beyond 32-k + # tokens. + "max_model_len": 32, + }, }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify logprobs greedy equality when some sequences skip speculation. """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, +}]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize( "output_len", @@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": True, + }, +}]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("batch_size", [4]) 
@pytest.mark.parametrize( @@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, """Check the behavior when logprobs are disabled. Token choices should match with the base model. """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 807f41cc9e5c..1be0e00384ee 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -60,8 +60,10 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": False, + }, }, { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, prefill_chunk_size: int): """Verify greedy equality with different batch size.""" maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -214,8 
+223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": k, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, } # Try a range of num. speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, # Main model "model_name": MAIN_MODEL, - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 59beca47acd0..3efda40066b3 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -62,7 +62,9 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "model": SPEC_MODEL, + "disable_logprobs": False, + }, }, { - "speculative_model": SPEC_MODEL, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "model": SPEC_MODEL, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [8]) @@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, # up sampling different tokens at the tail (ie top tokens don't change). # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected? 
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize("output_len", [2048]) @@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, # Main model "model_name": MAIN_MODEL, - # Speculative model - "speculative_model": SPEC_MODEL, + # Speculative config + "speculative_config": { + "model": SPEC_MODEL, + }, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) @@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize( @@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize( @@ -382,8 +397,10 @@ def patched_pad_vocab_size(vocab_size, pad_to=None): "test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": k, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, } # Try a range of num. speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": SPEC_MODEL, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "disable_by_batch_size": 4, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. 
"enforce_eager": True, - "speculative_model": SPEC_MODEL, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py index 0bad19f61d30..371e6834b639 100644 --- a/tests/spec_decode/e2e/test_mtp_correctness.py +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ -57,7 +57,9 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": False, + }, }, { - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, batch_size: int, output_len: int, seed: int, logprobs: int): - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { - "num_speculative_tokens": k, + "speculative_config": { + "num_speculative_tokens": k, + }, } # Try a range of num. 
speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4 + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index d396e52a9ddc..023e92d47551 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -61,15 +61,19 @@ "per_test_common_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { # Chunked prefill enabled with small value # to make sure we get mixed batches. - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, }, ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "enable_chunked_prefill": False, - "disable_logprobs_during_spec_decoding": False - }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4, - "disable_logprobs_during_spec_decoding": False - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "disable_logprobs": False, + }, + "enable_chunked_prefill": False, +}, { + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4, +}]) @pytest.mark.parametrize( "output_len", [ @@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( whether all speculative tokens are accepted. 
""" ensure_all_accepted = per_test_common_llm_kwargs.get( - "model_name") == test_llm_kwargs.get("speculative_model") + "model_name") == test_llm_kwargs.get("speculative_config")["model"] run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -514,13 +541,17 @@ def 
test_spec_decode_e2e_greedy_correctness_with_preemption( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, "test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. - "speculative_max_model_len": 32, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 32, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 32, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4, - "speculative_max_model_len": 32, }, ]) @pytest.mark.parametrize("batch_size", [8]) @@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_disable_by_batch_size": 2, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "disable_by_batch_size": 2, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_disable_by_batch_size": 2, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "disable_by_batch_size": 2, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4, @@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, "test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + }, "enable_chunked_prefill": False, } # Try a range of common k, as well as large speculation. for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] ] + [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4, @@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, "test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, - "spec_decoding_acceptance_method": "typical_acceptance_sampler", + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "acceptance_method": "typical_acceptance_sampler", + }, "enable_chunked_prefill": False } # Try a range of common k. 
for k in [1, 2, 3] ] + [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, - "spec_decoding_acceptance_method": "typical_acceptance_sampler", + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "acceptance_method": "typical_acceptance_sampler", + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 1aff53cb55c9..96f50f9ec43c 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -48,16 +48,20 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_mqa_scorer": False, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_mqa_scorer": False, + }, }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_mqa_scorer": True, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_mqa_scorer": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs": False, + }, }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, batch_size: int, output_len: int, seed: int, logprobs: int): """Verify greedy equality on a tiny model with different batch size.""" - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + 
"ngram_prompt_lookup_max": 3, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_mqa_scorer": True, + }, "enable_chunked_prefill": True, - "speculative_disable_mqa_scorer": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 }, @@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": k, - "ngram_prompt_lookup_max": 3, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 3, + }, } # Try a range of common k, as well as large speculation. for k in [1, 3, 5] ] + [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": k, - "ngram_prompt_lookup_max": 1, + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 1, + }, } # Try a range of common k, as well as large speculation. for k in [1, 3, 5] @@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_by_batch_size": 4 - }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_by_batch_size": 4, - "enable_chunked_prefill": True, - "speculative_disable_mqa_scorer": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_by_batch_size": 4 + }, +}, { + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_by_batch_size": 4, + "disable_mqa_scorer": True, + }, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -316,18 +336,18 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "proposer": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_by_batch_size": 4, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py index b7d279f2919b..3dc37172285e 100644 --- a/tests/spec_decode/e2e/test_seed.py +++ b/tests/spec_decode/e2e/test_seed.py @@ -19,11 +19,11 @@ # Skip cuda graph recording for fast test. 
"enforce_eager": True, - # speculative model - "speculative_model": "JackFram/llama-160m", - - # num speculative tokens - "num_speculative_tokens": 3, + # speculative config + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + }, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) diff --git a/vllm/config.py b/vllm/config.py index 0ae0cb424f18..d88379e7afc2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1841,7 +1841,8 @@ class SpeculativeConfig: stage times in speculative decoding. """ # speculative configs from cli args - num_speculative_tokens: int + num_speculative_tokens: Optional[int] = field(default=None, + init=True) # type: ignore model: Optional[str] = None proposer: Optional[str] = None quantization: Optional[str] = None @@ -1911,10 +1912,11 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: def __post_init__(self): if self.proposer is None and self.model is not None: # Note: After next release, the proposer parameter will be used - # to specify the speculative method, and the model parameter will - # be used when the draft model or head is needed. If users do not - # specify a proposer, the speculative method will be considered as - # the draft model by default. + # to specify the speculative method, which helps to extend the + # configuration of non-model-based proposers, and the model + # parameter will be used when the draft model or head is needed. + # If users do not specify the proposer, the speculative method will + # be considered as the model-based method by default. self.proposer = self.model if self.model is None and self.num_speculative_tokens is not None: @@ -1925,6 +1927,8 @@ def __post_init__(self): == "deepseek_v3"): # use the draft model from the same model: self.model = self.target_model_config.model + elif self.proposer in ["ngram", "[ngram]"]: + self.model = self.proposer else: raise ValueError("num_speculative_tokens was provided without " "speculative model.") @@ -2159,17 +2163,17 @@ def _verify_args(self) -> None: "rejection_sampler or typical_acceptance_sampler. Instead it " f"is {self.acceptance_method}") - if self.acceptance_method == "typical_acceptance_sampler": - if (self.typical_acceptance_sampler_posterior_threshold < 0 - or self.typical_acceptance_sampler_posterior_alpha < 0): - raise ValueError( - "Expected typical_acceptance_sampler_posterior_threshold " - "and typical_acceptance_sampler_posterior_alpha to be > 0." - " Instead found " - f"typical_acceptance_sampler_posterior_threshold = " - f"{self.typical_acceptance_sampler_posterior_threshold} " - f"and typical_acceptance_sampler_posterior_alpha = " - f"{self.typical_acceptance_sampler_posterior_alpha}") + if self.acceptance_method == "typical_acceptance_sampler" and ( + self.typical_acceptance_sampler_posterior_threshold < 0 + or self.typical_acceptance_sampler_posterior_alpha < 0): + raise ValueError( + "Expected typical_acceptance_sampler_posterior_threshold " + "and typical_acceptance_sampler_posterior_alpha to be > 0. 
" + "Instead found " + f"typical_acceptance_sampler_posterior_threshold = " + f"{self.typical_acceptance_sampler_posterior_threshold} and " + f"typical_acceptance_sampler_posterior_alpha = " + f"{self.typical_acceptance_sampler_posterior_alpha}") if (self.disable_by_batch_size is not None and self.disable_by_batch_size < 2): From 395b9d6600cc7588fbeae3c38be4c426c68812f8 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 7 Mar 2025 18:56:18 +0800 Subject: [PATCH 05/21] fix doc and typo Signed-off-by: Shangming Cai --- docs/source/features/spec_decode.md | 39 ++++++++++++++++++----------- vllm/config.py | 2 +- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md index cc8d6fceb7d6..bbf2944652fa 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/source/features/spec_decode.md @@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="facebook/opt-6.7b", tensor_parallel_size=1, - speculative_model="facebook/opt-125m", - num_speculative_tokens=5, + speculative_config={ + "model": "facebook/opt-125m", + "num_speculative_tokens": 5, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -45,10 +47,14 @@ To perform the same with an online mode launch the server: ```bash python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \ - --seed 42 -tp 1 --speculative_model facebook/opt-125m \ - --num_speculative_tokens 5 --gpu_memory_utilization 0.8 + --seed 42 -tp 1 --gpu_memory_utilization 0.8 \ + --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}' ``` +:::{warning} +Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release. +::: + Then use a client: ```python @@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="facebook/opt-6.7b", tensor_parallel_size=1, - speculative_model="[ngram]", - num_speculative_tokens=5, - ngram_prompt_lookup_max=4, + speculative_config={ + "proposer": "[ngram]", # Or you can also specify "model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 4, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="meta-llama/Meta-Llama-3.1-70B-Instruct", tensor_parallel_size=4, - speculative_model="ibm-ai-platform/llama3-70b-accelerator", - speculative_draft_tensor_parallel_size=1, + speculative_config={ + "model": "ibm-ai-platform/llama3-70b-accelerator", + "draft_tensor_parallel_size": 1, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="meta-llama/Meta-Llama-3-8B-Instruct", tensor_parallel_size=4, - speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - speculative_draft_tensor_parallel_size=1, + speculative_config={ + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "draft_tensor_parallel_size": 1, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models: be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304). 
If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, - and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using - the latest version of vLLM, please leave a comment or raise an issue. + and specify `model="path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue. 2. The EAGLE based draft models need to be run without tensor parallelism - (i.e. speculative_draft_tensor_parallel_size is set to 1), although + (i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although it is possible to run the main model using tensor parallelism (see example above). 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is diff --git a/vllm/config.py b/vllm/config.py index 955d339914e6..5f1880d3391e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2030,7 +2030,7 @@ def __post_init__(self): # Ensure divisibility for MTP module reuse. raise ValueError( f"num_speculative_tokens:{self.num_speculative_tokens}" - f"must be divisible by {n_predict=}") + f" must be divisible by {n_predict=}") self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( From cf3030ea8056844df1bba7a555d7c44023020eb1 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 7 Mar 2025 20:31:11 +0800 Subject: [PATCH 06/21] fix mypy Signed-off-by: Shangming Cai --- vllm/config.py | 176 ++++++++++++++++++++++++++----------------------- 1 file changed, 92 insertions(+), 84 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5f1880d3391e..da69f56af77c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1789,7 +1789,7 @@ class SpeculativeConfig: Configuration for speculative decoding. Configurable parameters include: - Top-level Speculative Decoding Control: - - num_speculative_tokens (Optional[int]): The number of speculative + - num_speculative_tokens (int): The number of speculative tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required. - proposer (Optional[str]): The name of the speculative method to use. @@ -1856,8 +1856,8 @@ class SpeculativeConfig: stage times in speculative decoding. """ # speculative configs from cli args - num_speculative_tokens: Optional[int] = field(default=None, - init=True) # type: ignore + num_speculative_tokens: int = field(default=None, + init=True) # type: ignore model: Optional[str] = None proposer: Optional[str] = None quantization: Optional[str] = None @@ -1973,83 +1973,88 @@ def __post_init__(self): self.ngram_prompt_lookup_max = 0 self.ngram_prompt_lookup_min = 0 - self.draft_model_config = ModelConfig( - model=self.model, - task="draft", - tokenizer=self.target_model_config.tokenizer, - tokenizer_mode=self.target_model_config.tokenizer_mode, - trust_remote_code=self.target_model_config.trust_remote_code, - allowed_local_media_path=self.target_model_config. - allowed_local_media_path, - dtype=self.target_model_config.dtype, - seed=self.target_model_config.seed, - revision=self.revision, - code_revision=self.code_revision, - tokenizer_revision=self.target_model_config.tokenizer_revision, - max_model_len=None, - spec_target_max_model_len=self.target_model_config. 
- max_model_len, - quantization=self.quantization, - enforce_eager=self.target_model_config.enforce_eager, - max_seq_len_to_capture=self.target_model_config. - max_seq_len_to_capture, - max_logprobs=self.target_model_config.max_logprobs, - hf_overrides=SpeculativeConfig.hf_config_override, - ) - - # Detect proposer type or EAGLE prefix to replace hf_config for - # EAGLE draft_model - if (self.proposer == "ealge" - or "eagle-" in self.draft_model_config.model.lower()): - if self.enable_chunked_prefill: - raise ValueError( - "Chunked prefill and EAGLE are not compatible.") + if self.model is not None: + self.draft_model_config = ModelConfig( + model=self.model, + task="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config. + trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config. + tokenizer_revision, + max_model_len=None, + spec_target_max_model_len=self.target_model_config. + max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_seq_len_to_capture=self.target_model_config. + max_seq_len_to_capture, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) - from vllm.transformers_utils.configs.eagle import EAGLEConfig - if isinstance(self.draft_model_config.hf_config, EAGLEConfig): - pass - else: - eagle_config = EAGLEConfig( - self.draft_model_config.hf_config) - self.draft_model_config.hf_config = eagle_config - - if (self.num_speculative_tokens is not None - and hasattr(self.draft_model_config.hf_config, - "num_lookahead_tokens")): - self.draft_model_config.hf_config.num_lookahead_tokens = \ - self.num_speculative_tokens - - n_predict = getattr(self.draft_model_config.hf_config, "n_predict", - None) - if n_predict is not None: - if self.num_speculative_tokens is None: - # Default to max value defined in draft model config. - self.num_speculative_tokens = n_predict - elif self.num_speculative_tokens > n_predict and \ - self.num_speculative_tokens % n_predict != 0: - # Ensure divisibility for MTP module reuse. 
- raise ValueError( - f"num_speculative_tokens:{self.num_speculative_tokens}" - f" must be divisible by {n_predict=}") - - self.draft_tensor_parallel_size = \ - SpeculativeConfig._verify_and_get_draft_tp( - self.target_parallel_config, - self.draft_tensor_parallel_size, - self.draft_model_config.hf_config - ) + # Detect proposer type or EAGLE prefix to replace hf_config for + # EAGLE draft_model + if (self.proposer == "ealge" + or "eagle-" in self.draft_model_config.model.lower()): + if self.enable_chunked_prefill: + raise ValueError( + "Chunked prefill and EAGLE are not compatible.") + + from vllm.transformers_utils.configs.eagle import ( + EAGLEConfig) + if isinstance(self.draft_model_config.hf_config, + EAGLEConfig): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config) + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, + "n_predict", None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. + raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) - self.draft_model_config.max_model_len = ( - SpeculativeConfig._maybe_override_draft_max_model_len( - self.max_model_len, - self.draft_model_config.max_model_len, - self.target_model_config.max_model_len, - )) + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) - self.draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - self.target_parallel_config, - self.draft_tensor_parallel_size)) + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) if self.acceptance_method == "typical_acceptance_sampler": if self.typical_acceptance_sampler_posterior_threshold is None: @@ -2179,15 +2184,17 @@ def _verify_args(self) -> None: f"is {self.acceptance_method}") if self.acceptance_method == "typical_acceptance_sampler" and ( - self.typical_acceptance_sampler_posterior_threshold < 0 - or self.typical_acceptance_sampler_posterior_alpha < 0): + (self.typical_acceptance_sampler_posterior_threshold is not None + and self.typical_acceptance_sampler_posterior_threshold < 0) or + (self.typical_acceptance_sampler_posterior_alpha is not None + and self.typical_acceptance_sampler_posterior_alpha < 0)): raise ValueError( "Expected typical_acceptance_sampler_posterior_threshold " - "and typical_acceptance_sampler_posterior_alpha to be > 0. " - "Instead found " + "and typical_acceptance_sampler_posterior_alpha to be > 0." 
+ " Instead found " f"typical_acceptance_sampler_posterior_threshold = " - f"{self.typical_acceptance_sampler_posterior_threshold} and " - f"typical_acceptance_sampler_posterior_alpha = " + f"{self.typical_acceptance_sampler_posterior_threshold} " + f"and typical_acceptance_sampler_posterior_alpha = " f"{self.typical_acceptance_sampler_posterior_alpha}") if (self.disable_by_batch_size is not None @@ -2207,7 +2214,8 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def __repr__(self) -> str: - if self.ngram_prompt_lookup_max > 0: + if (self.ngram_prompt_lookup_max is not None + and self.ngram_prompt_lookup_max > 0): draft_model = "[ngram]" else: draft_model = self.draft_model_config.model From e577c01d82ac576873afea7bd5a5948b82a0b5f8 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 7 Mar 2025 20:42:56 +0800 Subject: [PATCH 07/21] fix mypy Signed-off-by: Shangming Cai --- vllm/engine/arg_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8c0047d0f38f..793a88b2a38e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1237,6 +1237,8 @@ def create_speculative_config( # Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine # config. + + assert isinstance(self.speculative_config, dict) self.speculative_config.update({ "target_model_config": target_model_config, "target_parallel_config": target_parallel_config, From 1c5530a75d8cc5fddc58eb0d8f051e6eda5c6365 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Sat, 8 Mar 2025 01:17:54 +0800 Subject: [PATCH 08/21] fix typo Signed-off-by: Shangming Cai --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index da69f56af77c..94f4b49d31cb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2002,7 +2002,7 @@ def __post_init__(self): # Detect proposer type or EAGLE prefix to replace hf_config for # EAGLE draft_model - if (self.proposer == "ealge" + if (self.proposer == "eagle" or "eagle-" in self.draft_model_config.model.lower()): if self.enable_chunked_prefill: raise ValueError( From 52dfefd49f6c78996403caee73a8b1403de0bbc5 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Wed, 12 Mar 2025 16:38:38 +0800 Subject: [PATCH 09/21] Update doc. Signed-off-by: Shangming Cai --- docs/source/features/spec_decode.md | 2 +- examples/offline_inference/mlpspeculator.py | 4 +- .../spec_decode/e2e/test_ngram_correctness.py | 1 - tests/v1/e2e/test_ngram_spec_decode.py | 10 +++-- vllm/config.py | 38 +++++++++++-------- 5 files changed, 33 insertions(+), 22 deletions(-) diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md index bbf2944652fa..c71e2d1477c9 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/source/features/spec_decode.md @@ -206,7 +206,7 @@ A few important things to consider when using the EAGLE based draft models: be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304). If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, - and specify `model="path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue. 
+ and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue. 2. The EAGLE based draft models need to be run without tensor parallelism (i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index 61641245de83..380c53fab220 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -50,7 +50,9 @@ def time_generation(llm: LLM, prompts: list[str], # Create an LLM with spec decoding llm = LLM( model="meta-llama/Llama-2-13b-chat-hf", - speculative_model="ibm-ai-platform/llama-13b-accelerator", + speculative_config={ + "model": "ibm-ai-platform/llama-13b-accelerator", + }, ) print("With speculation") diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 96f50f9ec43c..418a95e33aa8 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -344,7 +344,6 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, "proposer": "[ngram]", "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, - "disable_by_batch_size": 4, "disable_mqa_scorer": True, }, }]) diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_ngram_spec_decode.py index 150caa150a59..2afb56ebacfc 100644 --- a/tests/v1/e2e/test_ngram_spec_decode.py +++ b/tests/v1/e2e/test_ngram_spec_decode.py @@ -37,10 +37,12 @@ def test_ngram_correctness(monkeypatch, test_prompts, sampling_config, del ref_llm spec_llm = LLM(model=model_name, - speculative_model='[ngram]', - ngram_prompt_lookup_max=5, - ngram_prompt_lookup_min=3, - num_speculative_tokens=3) + speculative_config={ + "proposer": "[ngram]", + "ngram_prompt_lookup_max": 5, + "ngram_prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }) spec_outputs = spec_llm.generate(test_prompts, sampling_config) for ref_output, spec_output in zip(ref_outputs, spec_outputs): assert ref_output.outputs[0].text == spec_output.outputs[0].text, \ diff --git a/vllm/config.py b/vllm/config.py index 94f4b49d31cb..c4aac28ab86e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1799,28 +1799,34 @@ class SpeculativeConfig: 'typical_acceptance_sampler' for RejectionSampler and TypicalAcceptanceSampler respectively. If not specified, it defaults to 'rejection_sampler'. + - draft_tensor_parallel_size (Optional[int]): The degree of the tensor + parallelism for the draft model. Can be 1 or match the target + model's tensor parallel size. - disable_logprobs (bool): If set to True, token log probabilities are not returned during speculative decoding. If set to False, token log probabilities are returned according to the log probability settings in SamplingParams. If not specified, it defaults to True. - Model Configuration: - - model (Optional[str]): The name of the speculative model, - if provided. + - model (Optional[str]): The name of the draft model, if provided. - quantization (Optional[str]): Quantization method that was used to - quantize the speculative model weights. If None, we assume the - model weights are not quantized. + quantize the draft model weights. If None, we assume the + model weights are not quantized. Note that it only takes effect + when using the draft model-based speculative method. 
- max_model_len (Optional[int]): The maximum model length of the - speculative model. Used when testing the ability to skip + draft model. Used when testing the ability to skip speculation for some sequences. - - draft_tensor_parallel_size (Optional[int]): The degree of the tensor - parallelism for the draft model. Can be 1 or match the target - model's tensor parallel size. + - revision: The specific model version to use for the draft model. It + can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version. + - code_revision: The specific revision to use for the draft model code + on Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. - Advanced Token Control: - - disable_mqa_scorer (bool): Disable the MQA scorer for the speculative - model and fall back to batch expansion for scoring. If not - specified, it defaults to False. + - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to + batch expansion for scoring proposals. If not specified, it + defaults to False. - disable_by_batch_size (Optional[int]): Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided. @@ -1858,21 +1864,23 @@ class SpeculativeConfig: # speculative configs from cli args num_speculative_tokens: int = field(default=None, init=True) # type: ignore - model: Optional[str] = None proposer: Optional[str] = None + acceptance_method: str = "rejection_sampler" + draft_tensor_parallel_size: Optional[int] = None + disable_logprobs: bool = True + + model: Optional[str] = None quantization: Optional[str] = None max_model_len: Optional[int] = None revision: Optional[str] = None code_revision: Optional[str] = None - draft_tensor_parallel_size: Optional[int] = None + disable_mqa_scorer: bool = False disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None - acceptance_method: str = "rejection_sampler" typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None - disable_logprobs: bool = True # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, From dbb695763533cc8c6f3632a7603dfcae76263124 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Mon, 17 Mar 2025 16:59:21 +0800 Subject: [PATCH 10/21] fix mypy due to PR 13726 Signed-off-by: Shangming Cai --- vllm/engine/arg_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 39975a4415a2..3801f155b3a8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1298,6 +1298,8 @@ def create_engine_config( else: self._set_default_args_v0(model_config) + assert self.enable_chunked_prefill is not None + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, From b5d59361fa0921300ea783cec894e90c78b6fd5f Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Wed, 19 Mar 2025 10:52:02 +0800 Subject: [PATCH 11/21] minor Signed-off-by: Shangming Cai --- vllm/config.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 8cde19f26748..c2d0d7c762c2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1811,14 +1811,14 @@ class SpeculativeConfig: TypicalAcceptanceSampler respectively. 
If not specified, it defaults to 'rejection_sampler'. - draft_tensor_parallel_size (Optional[int]): The degree of the tensor - parallelism for the draft model. Can be 1 or match the target - model's tensor parallel size. + parallelism for the draft model. Can only be 1 or the same as the + target model's tensor parallel size. - disable_logprobs (bool): If set to True, token log probabilities are not returned during speculative decoding. If set to False, token log probabilities are returned according to the log probability settings in SamplingParams. If not specified, it defaults to True. - - Model Configuration: + - Draft Model Configuration: - model (Optional[str]): The name of the draft model, if provided. - quantization (Optional[str]): Quantization method that was used to quantize the draft model weights. If None, we assume the @@ -1834,7 +1834,7 @@ class SpeculativeConfig: on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - - Advanced Token Control: + - Advanced Control: - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to batch expansion for scoring proposals. If not specified, it defaults to False. @@ -1956,18 +1956,17 @@ def __post_init__(self): if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting # mtp acceleration for more models besides deepseek_v3 - if (self.proposer == "mtp" - or self.target_model_config.hf_text_config.model_type - == "deepseek_v3"): + if self.target_model_config.hf_text_config.model_type \ + == "deepseek_v3": # use the draft model from the same model: self.model = self.target_model_config.model - elif self.proposer in ["ngram", "[ngram]"]: + elif "ngram" in self.proposer: self.model = self.proposer else: raise ValueError("num_speculative_tokens was provided without " "speculative model.") - if self.proposer in ["ngram", "[ngram]"]: + if "ngram" in self.proposer: if self.ngram_prompt_lookup_min is None: self.ngram_prompt_lookup_min = 1 if (self.ngram_prompt_lookup_max is None @@ -2190,7 +2189,7 @@ def _verify_args(self) -> None: self.draft_parallel_config) # Validate and set draft token acceptance related settings. - if (self.acceptance_method is None): + if self.acceptance_method is None: raise ValueError("acceptance_method is not set. 
" "Expected values are rejection_sampler or " "typical_acceptance_sampler.") From c8700d9ad7b782c4143823a6bb621deea345745d Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Wed, 19 Mar 2025 11:10:12 +0800 Subject: [PATCH 12/21] minor fix for mypy Signed-off-by: Shangming Cai --- vllm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c2d0d7c762c2..9f37110e1edd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1960,13 +1960,13 @@ def __post_init__(self): == "deepseek_v3": # use the draft model from the same model: self.model = self.target_model_config.model - elif "ngram" in self.proposer: + elif self.proposer in ("ngram", "[ngram]"): self.model = self.proposer else: raise ValueError("num_speculative_tokens was provided without " "speculative model.") - if "ngram" in self.proposer: + if self.proposer in ("ngram", "[ngram]"): if self.ngram_prompt_lookup_min is None: self.ngram_prompt_lookup_min = 1 if (self.ngram_prompt_lookup_max is None From 41de64662c995b4093eb0f15a45e41169e3acbf3 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Wed, 19 Mar 2025 11:56:37 +0800 Subject: [PATCH 13/21] add docstring Signed-off-by: Shangming Cai --- vllm/config.py | 3 ++- vllm/engine/arg_utils.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9f37110e1edd..34afb0baea5c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1804,7 +1804,8 @@ class SpeculativeConfig: tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required. - proposer (Optional[str]): The name of the speculative method to use. - Defaults to the model name if not provided. + If not provided, it assumes the speculative method is a model-based + method by default. - acceptance_method (str): The method to use for accepting draft tokens. This can take two possible values: 'rejection_sampler' and 'typical_acceptance_sampler' for RejectionSampler and diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3801f155b3a8..fd537918d55f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1195,7 +1195,19 @@ def create_speculative_config( enable_chunked_prefill: bool, disable_log_stats: bool, ) -> Optional["SpeculativeConfig"]: - + """Initializes and returns a SpeculativeConfig object based on + `speculative_config`. + + This function utilizes `speculative_config` to create a + SpeculativeConfig object. The `speculative_config` can either be + provided as a JSON string input via CLI arguments or directly as a + dictionary from the engine. If `speculative_config` is not set, this + function will attempt to construct a configuration dictionary using + certain parameters, which are scheduled for deprecation in the next + release. Note that in next releases, `speculative_config` must be + provided, and the deprecated standalone speculative-related parameters + will be removed. 
+ """ if self.speculative_config is None: if (self.speculative_model is None and self.num_speculative_tokens is None): From f2da77920f9184c77ce65de7c8ec6244744dfd97 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 13:01:52 +0800 Subject: [PATCH 14/21] refactor again Signed-off-by: Shangming Cai --- docs/source/features/spec_decode.md | 4 +- tests/spec_decode/e2e/test_integration.py | 2 +- .../e2e/test_integration_dist_tp2.py | 2 +- .../spec_decode/e2e/test_ngram_correctness.py | 48 +++--- tests/v1/e2e/test_ngram_spec_decode.py | 6 +- vllm/config.py | 158 ++++++++++-------- vllm/engine/arg_utils.py | 8 +- vllm/spec_decode/spec_decode_worker.py | 8 +- 8 files changed, 129 insertions(+), 107 deletions(-) diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md index d6c2b5b528b0..47ade3c1afff 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/source/features/spec_decode.md @@ -108,9 +108,9 @@ llm = LLM( model="facebook/opt-6.7b", tensor_parallel_size=1, speculative_config={ - "proposer": "[ngram]", # Or you can also specify "model": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 4, + "prompt_lookup_max": 4, }, ) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index 6d6106f05db5..9dfc1b2fd91e 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -137,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, seed: int): - """Verify that ngram speculative decoding generates the same output + """Verify that speculative decoding generates the same output with batch expansion scorer and mqa scorer. 
""" run_equality_correctness_test(vllm_runner, diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 3729629dace7..5c90dce57c1c 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -38,7 +38,7 @@ str({ "model": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, }), ], ]) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 418a95e33aa8..e2890814d499 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -49,17 +49,17 @@ @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_mqa_scorer": False, }, }, { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_mqa_scorer": True, }, }, @@ -106,17 +106,17 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_logprobs": False, }, }, { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_logprobs": True, }, }, @@ -169,17 +169,17 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, }, "enable_chunked_prefill": False, }, { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_mqa_scorer": True, }, "enable_chunked_prefill": True, @@ -228,9 +228,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption( [ { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": k, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, }, } # Try a range of common k, as well as large speculation. @@ -238,9 +238,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption( ] + [ { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": k, - "ngram_prompt_lookup_max": 1, + "prompt_lookup_max": 1, }, } # Try a range of common k, as well as large speculation. @@ -260,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, seed: int): """Verify that ngram speculative decoding produces exact equality to without spec decode with many different values of k and - different ngram_prompt_lookup_max. + different ngram prompt_lookup_max. 
""" run_equality_correctness_test(vllm_runner, common_llm_kwargs, @@ -285,16 +285,16 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_by_batch_size": 4 }, }, { "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_by_batch_size": 4, "disable_mqa_scorer": True, }, @@ -316,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, seed: int): """Verify that ngram speculative decoding produces exact equality to without spec decode with many different values of k and - different ngram_prompt_lookup_max. + different ngram prompt_lookup_max. """ run_equality_correctness_test(vllm_runner, common_llm_kwargs, @@ -341,9 +341,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ "speculative_config": { - "proposer": "[ngram]", + "method": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "prompt_lookup_max": 3, "disable_mqa_scorer": True, }, }]) diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_ngram_spec_decode.py index e2a6704a538a..58659cfd9473 100644 --- a/tests/v1/e2e/test_ngram_spec_decode.py +++ b/tests/v1/e2e/test_ngram_spec_decode.py @@ -73,9 +73,9 @@ def test_ngram_correctness( spec_llm = LLM( model=model_name, speculative_config={ - "proposer": "[ngram]", - "ngram_prompt_lookup_max": 5, - "ngram_prompt_lookup_min": 3, + "method": "[ngram]", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, "num_speculative_tokens": 3, }, max_model_len=1024, diff --git a/vllm/config.py b/vllm/config.py index 34afb0baea5c..f9c3cfc36659 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1803,14 +1803,42 @@ class SpeculativeConfig: - num_speculative_tokens (int): The number of speculative tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required. - - proposer (Optional[str]): The name of the speculative method to use. - If not provided, it assumes the speculative method is a model-based - method by default. + - model (Optional[str]): The name of the draft model, eagle head, + or additional weights, if provided. + - method (Optional[str]): The name of the speculative method to use. + If users provide and set the `model` param, the speculative method + type will be detected automatically, if `model` param is not + provided, the appropriate method name must be provided. + - Possible values: + - ngram: A light speculative method without the need for a + model, which is based on prompt lookup decoding. + - prompt_lookup_max (Optional[int]): + Maximum size of ngram token window when using Ngram + proposer, if provided. + - prompt_lookup_min (Optional[int]): + Minimum size of ngram token window when using Ngram + proposer, if provided. + - eagle + - medusa + - mlp_speculator + - draft_model - acceptance_method (str): The method to use for accepting draft tokens. This can take two possible values: 'rejection_sampler' and 'typical_acceptance_sampler' for RejectionSampler and TypicalAcceptanceSampler respectively. If not specified, it defaults to 'rejection_sampler'. 
+ - Possible values: + - rejection_sampler + - typical_acceptance_sampler + - posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the + posterior probability of a token in the target model + for it to be accepted. This threshold is used only + when we use the TypicalAcceptanceSampler for token + acceptance. + - posterior_alpha (Optional[float]): + Scaling factor for entropy-based threshold, applied + when using TypicalAcceptanceSampler. - draft_tensor_parallel_size (Optional[int]): The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size. @@ -1842,18 +1870,6 @@ class SpeculativeConfig: - disable_by_batch_size (Optional[int]): Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided. - - ngram_prompt_lookup_max (Optional[int]): Maximum size of ngram token - window when using Ngram proposer, if provided. - - ngram_prompt_lookup_min (Optional[int]): Minimum size of ngram token - window when using Ngram proposer, if provided. - - typical_acceptance_sampler_posterior_threshold (Optional[float]): - A threshold value that sets a lower bound on the posterior - probability of a token in the target model for it to be accepted. - This threshold is used only when we use the - TypicalAcceptanceSampler for token acceptance. - - typical_acceptance_sampler_posterior_alpha (Optional[float]): Scaling - factor for entropy-based threshold, applied when using - TypicalAcceptanceSampler. Non-configurable internal parameters include: - Model Configuration: @@ -1876,7 +1892,7 @@ class SpeculativeConfig: # speculative configs from cli args num_speculative_tokens: int = field(default=None, init=True) # type: ignore - proposer: Optional[str] = None + method: Optional[str] = None acceptance_method: str = "rejection_sampler" draft_tensor_parallel_size: Optional[int] = None disable_logprobs: bool = True @@ -1889,10 +1905,10 @@ class SpeculativeConfig: disable_mqa_scorer: bool = False disable_by_batch_size: Optional[int] = None - ngram_prompt_lookup_max: Optional[int] = None - ngram_prompt_lookup_min: Optional[int] = None - typical_acceptance_sampler_posterior_threshold: Optional[float] = None - typical_acceptance_sampler_posterior_alpha: Optional[float] = None + prompt_lookup_max: Optional[int] = None + prompt_lookup_min: Optional[int] = None + posterior_threshold: Optional[float] = None + posterior_alpha: Optional[float] = None # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, @@ -1945,14 +1961,28 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: return hf_config def __post_init__(self): - if self.proposer is None and self.model is not None: - # Note: After next release, the proposer parameter will be used - # to specify the speculative method, which helps to extend the - # configuration of non-model-based proposers, and the model - # parameter will be used when the draft model or head is needed. - # If users do not specify the proposer, the speculative method will - # be considered as the model-based method by default. - self.proposer = self.model + + # Note: After next release, the method parameter will be used to + # specify the speculative method, which helps to extend the + # configuration of non-model-based proposers, and the model parameter + # will be used when the draft model or head is needed. 
+ # If users do not specify the method, the speculative method will be + # considered as the draft-model-based method by default. + + if self.method is None and self.model is not None: + # Automatically set the method to ensure a smooth transition during + # configuration refactoring. + if self.model in ("ngram", "[ngram]"): + self.method = "ngram" + elif "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + else: + self.method = "draft_model" if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting @@ -1961,27 +1991,25 @@ def __post_init__(self): == "deepseek_v3": # use the draft model from the same model: self.model = self.target_model_config.model - elif self.proposer in ("ngram", "[ngram]"): - self.model = self.proposer + elif self.method in ("ngram", "[ngram]"): + self.model = self.method else: raise ValueError("num_speculative_tokens was provided without " "speculative model.") - if self.proposer in ("ngram", "[ngram]"): - if self.ngram_prompt_lookup_min is None: - self.ngram_prompt_lookup_min = 1 - if (self.ngram_prompt_lookup_max is None - or self.ngram_prompt_lookup_max < 1): - raise ValueError("ngram_prompt_lookup_max=" - f"{self.ngram_prompt_lookup_max} must be > 0") - if self.ngram_prompt_lookup_min < 1: - raise ValueError("ngram_prompt_lookup_min=" - f"{self.ngram_prompt_lookup_min} must be > 0") - if self.ngram_prompt_lookup_min > self.ngram_prompt_lookup_max: - raise ValueError( - f"ngram_prompt_lookup_min={self.ngram_prompt_lookup_min} " - "cannot be larger than ngram_prompt_lookup_max=" - f"{self.ngram_prompt_lookup_max}") + if self.method in ("ngram", "[ngram]"): + if self.prompt_lookup_min is None: + self.prompt_lookup_min = 1 + if self.prompt_lookup_max is None or self.prompt_lookup_max < 1: + raise ValueError("prompt_lookup_max=" + f"{self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min < 1: + raise ValueError("prompt_lookup_min=" + f"{self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError(f"prompt_lookup_min={self.prompt_lookup_min} " + "cannot be larger than prompt_lookup_max=" + f"{self.prompt_lookup_max}") # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set @@ -1989,8 +2017,8 @@ def __post_init__(self): self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config else: - self.ngram_prompt_lookup_max = 0 - self.ngram_prompt_lookup_min = 0 + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 if self.model is not None: self.draft_model_config = ModelConfig( @@ -2019,10 +2047,8 @@ def __post_init__(self): hf_overrides=SpeculativeConfig.hf_config_override, ) - # Detect proposer type or EAGLE prefix to replace hf_config for - # EAGLE draft_model - if (self.proposer == "eagle" - or "eagle-" in self.draft_model_config.model.lower()): + # Replace hf_config for EAGLE draft_model + if self.method == "eagle": if self.enable_chunked_prefill: raise ValueError( "Chunked prefill and EAGLE are not compatible.") @@ -2076,10 +2102,10 @@ def __post_init__(self): self.draft_tensor_parallel_size)) if self.acceptance_method == "typical_acceptance_sampler": - if 
self.typical_acceptance_sampler_posterior_threshold is None: - self.typical_acceptance_sampler_posterior_threshold = 0.09 - if self.typical_acceptance_sampler_posterior_alpha is None: - self.typical_acceptance_sampler_posterior_alpha = 0.3 + if self.posterior_threshold is None: + self.posterior_threshold = 0.09 + if self.posterior_alpha is None: + self.posterior_alpha = 0.3 self._verify_args() @@ -2203,18 +2229,15 @@ def _verify_args(self) -> None: f"is {self.acceptance_method}") if self.acceptance_method == "typical_acceptance_sampler" and ( - (self.typical_acceptance_sampler_posterior_threshold is not None - and self.typical_acceptance_sampler_posterior_threshold < 0) or - (self.typical_acceptance_sampler_posterior_alpha is not None - and self.typical_acceptance_sampler_posterior_alpha < 0)): + (self.posterior_threshold is not None + and self.posterior_threshold < 0) or + (self.posterior_alpha is not None and self.posterior_alpha < 0)): raise ValueError( - "Expected typical_acceptance_sampler_posterior_threshold " - "and typical_acceptance_sampler_posterior_alpha to be > 0." - " Instead found " - f"typical_acceptance_sampler_posterior_threshold = " - f"{self.typical_acceptance_sampler_posterior_threshold} " - f"and typical_acceptance_sampler_posterior_alpha = " - f"{self.typical_acceptance_sampler_posterior_alpha}") + "Expected the posterior_threshold and posterior_alpha of " + "typical_acceptance_sampler to be > 0. " + "Instead found posterior_threshold = " + f"{self.posterior_threshold} and posterior_alpha = " + f"{self.posterior_alpha}") if (self.disable_by_batch_size is not None and self.disable_by_batch_size < 2): @@ -2233,8 +2256,7 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def __repr__(self) -> str: - if (self.ngram_prompt_lookup_max is not None - and self.ngram_prompt_lookup_max > 0): + if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0: draft_model = "[ngram]" else: draft_model = self.draft_model_config.model diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fd537918d55f..36f05d4ff394 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1232,12 +1232,12 @@ def create_speculative_config( "disable_mqa_scorer": self.speculative_disable_mqa_scorer, "disable_by_batch_size": self.speculative_disable_by_batch_size, - "ngram_prompt_lookup_max": self.ngram_prompt_lookup_max, - "ngram_prompt_lookup_min": self.ngram_prompt_lookup_min, + "prompt_lookup_max": self.ngram_prompt_lookup_max, + "prompt_lookup_min": self.ngram_prompt_lookup_min, "acceptance_method": self.spec_decoding_acceptance_method, - "typical_acceptance_sampler_posterior_threshold": + "posterior_threshold": self.typical_acceptance_sampler_posterior_threshold, - "typical_acceptance_sampler_posterior_alpha": + "posterior_alpha": self.typical_acceptance_sampler_posterior_alpha, "disable_logprobs": self.disable_logprobs_during_spec_decoding, } diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index f6b201d0e548..a724beade129 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -92,8 +92,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": # Override draft-model specific worker args. 
draft_worker_kwargs.update( vllm_config=draft_worker_config, - ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, - ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, + ngram_prompt_lookup_max=speculative_config.prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.prompt_lookup_min, ) spec_decode_worker = SpecDecodeWorker.create_worker( @@ -103,9 +103,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": disable_by_batch_size=speculative_config.disable_by_batch_size, draft_token_acceptance_method=speculative_config.acceptance_method, typical_acceptance_sampler_posterior_threshold=speculative_config. - typical_acceptance_sampler_posterior_threshold, + posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. - typical_acceptance_sampler_posterior_alpha, + posterior_alpha, disable_logprobs=speculative_config.disable_logprobs, disable_log_stats=speculative_config.disable_log_stats, num_speculative_tokens=speculative_config.num_speculative_tokens, From ab66f51ff0d137137737ea5baf8758b68f543325 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 13:55:52 +0800 Subject: [PATCH 15/21] minor fix for docstring Signed-off-by: Shangming Cai --- vllm/config.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index f9c3cfc36659..4450c51db75e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1799,7 +1799,7 @@ class SpeculativeConfig: """ Configuration for speculative decoding. Configurable parameters include: - - Top-level Speculative Decoding Control: + - General Speculative Decoding Control: - num_speculative_tokens (int): The number of speculative tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required. @@ -1807,17 +1807,17 @@ class SpeculativeConfig: or additional weights, if provided. - method (Optional[str]): The name of the speculative method to use. If users provide and set the `model` param, the speculative method - type will be detected automatically, if `model` param is not - provided, the appropriate method name must be provided. + type will be detected automatically if possible, if `model` param + is not provided, the method name must be provided. - Possible values: - - ngram: A light speculative method without the need for a - model, which is based on prompt lookup decoding. + - ngram + Related additional configuration: - prompt_lookup_max (Optional[int]): Maximum size of ngram token window when using Ngram - proposer, if provided. + proposer, required when method is set to ngram. - prompt_lookup_min (Optional[int]): Minimum size of ngram token window when using Ngram - proposer, if provided. + proposer, if provided. Defaults to 1. - eagle - medusa - mlp_speculator @@ -1830,6 +1830,7 @@ class SpeculativeConfig: - Possible values: - rejection_sampler - typical_acceptance_sampler + Related additional configuration: - posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior probability of a token in the target model @@ -1871,6 +1872,9 @@ class SpeculativeConfig: for new incoming requests when the number of enqueued requests is larger than this value, if provided. + Although the parameters above are structured hierarchically, there is no + need to nest them during configuration. 
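# --- Illustrative sketch (not part of this diff): the flat layout implied by
# --- the note above. Sub-options documented under "method" and
# --- "acceptance_method" are still passed as top-level keys; the values here
# --- are placeholders taken from examples elsewhere in this patch series.
speculative_config = {
    "method": "ngram",
    "num_speculative_tokens": 5,
    "prompt_lookup_max": 4,                             # ngram sub-option
    "prompt_lookup_min": 2,                             # ngram sub-option
    "acceptance_method": "typical_acceptance_sampler",
    "posterior_threshold": 0.09,                        # sampler sub-option
    "posterior_alpha": 0.3,                             # sampler sub-option
}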
+ Non-configurable internal parameters include: - Model Configuration: - target_model_config (ModelConfig): The configuration of the target From 4fec12917d244d6e54d0080b77f2ad39e071271f Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 14:10:03 +0800 Subject: [PATCH 16/21] fix doc Signed-off-by: Shangming Cai --- vllm/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 4450c51db75e..0a042829cba1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1849,7 +1849,6 @@ class SpeculativeConfig: settings in SamplingParams. If not specified, it defaults to True. - Draft Model Configuration: - - model (Optional[str]): The name of the draft model, if provided. - quantization (Optional[str]): Quantization method that was used to quantize the draft model weights. If None, we assume the model weights are not quantized. Note that it only takes effect From 348b3f2d6dd0c82beeb482c8d4756fa2c5f25cf9 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 14:49:46 +0800 Subject: [PATCH 17/21] minor Signed-off-by: Shangming Cai --- vllm/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 0a042829cba1..913d58b18ed3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1969,8 +1969,10 @@ def __post_init__(self): # specify the speculative method, which helps to extend the # configuration of non-model-based proposers, and the model parameter # will be used when the draft model or head is needed. - # If users do not specify the method, the speculative method will be - # considered as the draft-model-based method by default. + # If users do not specify the method, the speculative method will + # be detected automatically if possible. If the speculative method can + # not be detected, it will be considered as the draft-model-based + # method by default. if self.method is None and self.model is not None: # Automatically set the method to ensure a smooth transition during From b4bcec17c40952cf301dc55ad5496bc2ac8ef4fd Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 15:21:12 +0800 Subject: [PATCH 18/21] minor Signed-off-by: Shangming Cai --- vllm/config.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 913d58b18ed3..dded5f2ccfb5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1979,15 +1979,8 @@ def __post_init__(self): # configuration refactoring. 
if self.model in ("ngram", "[ngram]"): self.method = "ngram" - elif "eagle-" in self.draft_model_config.model.lower(): + elif "eagle-" in self.model.lower(): self.method = "eagle" - elif self.draft_model_config.hf_config.model_type == "medusa": - self.method = "medusa" - elif (self.draft_model_config.hf_config.model_type == - "mlp_speculator"): - self.method = "mlp_speculator" - else: - self.method = "draft_model" if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting @@ -2052,6 +2045,14 @@ def __post_init__(self): hf_overrides=SpeculativeConfig.hf_config_override, ) + if self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + else: + self.method = "draft_model" + # Replace hf_config for EAGLE draft_model if self.method == "eagle": if self.enable_chunked_prefill: From 91ffe36bde75b11835e51c9c8ea26bbb861efd2d Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 16:39:36 +0800 Subject: [PATCH 19/21] fix v1 Signed-off-by: Shangming Cai --- docs/source/features/spec_decode.md | 2 +- tests/spec_decode/e2e/conftest.py | 2 +- vllm/config.py | 6 ++++-- vllm/v1/worker/gpu_model_runner.py | 7 +++---- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md index 47ade3c1afff..3e1f1d5be752 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/source/features/spec_decode.md @@ -108,7 +108,7 @@ llm = LLM( model="facebook/opt-6.7b", tensor_parallel_size=1, speculative_config={ - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 4, }, diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index fe4a1c13fc73..921081f3c3f2 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -56,7 +56,7 @@ def generate(): def maybe_assert_ngram_worker(llm): # Verify the proposer worker is ngram if ngram is specified. 
if (llm.llm_engine.speculative_config is not None - and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0): + and llm.llm_engine.speculative_config.method == "ngram"): from vllm.spec_decode.ngram_worker import NGramWorker assert isinstance( llm.llm_engine.model_executor.driver_worker.proposer_worker, diff --git a/vllm/config.py b/vllm/config.py index dded5f2ccfb5..f1cfa0bcc7ef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1990,12 +1990,14 @@ def __post_init__(self): # use the draft model from the same model: self.model = self.target_model_config.model elif self.method in ("ngram", "[ngram]"): - self.model = self.method + self.model = "ngram" else: raise ValueError("num_speculative_tokens was provided without " "speculative model.") if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" if self.prompt_lookup_min is None: self.prompt_lookup_min = 1 if self.prompt_lookup_max is None or self.prompt_lookup_max < 1: @@ -2263,7 +2265,7 @@ def num_lookahead_slots(self) -> int: def __repr__(self) -> str: if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0: - draft_model = "[ngram]" + draft_model = "ngram" else: draft_model = self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 66015382bfe8..c28225f6cdd2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -150,8 +150,7 @@ def __init__( if self.speculative_config: self.use_spec_decode = True self.rejection_sampler = RejectionSampler() - # TODO: find a better way to check if we are using ngram. - assert self.speculative_config.ngram_prompt_lookup_min, \ + assert self.speculative_config.method == "ngram", \ "Currently, only ngram spec decode is supported in V1." if get_pp_group().is_last_rank: self.drafter = NgramProposer() @@ -159,7 +158,7 @@ def __init__( # This usually takes less than 1 second. 
self.drafter.propose( np.zeros(1024, dtype=np.int32), - self.speculative_config.ngram_prompt_lookup_min, + self.speculative_config.prompt_lookup_min, self.speculative_config.num_speculative_tokens, ) @@ -1115,7 +1114,7 @@ def generate_draft_token_ids( self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx], - self.speculative_config.ngram_prompt_lookup_min, + self.speculative_config.prompt_lookup_min, self.speculative_config.num_speculative_tokens, ) if drafter_output is None or len(drafter_output) == 0: From a81e95388831f11477e44848a5da6681ccb9d0f4 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 16:54:25 +0800 Subject: [PATCH 20/21] unified test conifg Signed-off-by: Shangming Cai --- .../e2e/test_integration_dist_tp2.py | 2 +- .../spec_decode/e2e/test_ngram_correctness.py | 22 +++++++++---------- tests/v1/e2e/test_ngram_spec_decode.py | 2 +- vllm/engine/arg_utils.py | 5 +++-- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 5c90dce57c1c..b8a2631b9140 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -36,7 +36,7 @@ [ "--speculative_config", str({ - "model": "[ngram]", + "model": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, }), diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index e2890814d499..3af89dc74e7f 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -49,7 +49,7 @@ @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_mqa_scorer": False, @@ -57,7 +57,7 @@ }, { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_mqa_scorer": True, @@ -106,7 +106,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_logprobs": False, @@ -114,7 +114,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, }, { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_logprobs": True, @@ -169,7 +169,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, }, @@ -177,7 +177,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, }, { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_mqa_scorer": True, @@ -228,7 +228,7 @@ def test_ngram_e2e_greedy_correctness_with_preemption( [ { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": k, "prompt_lookup_max": 3, }, @@ -238,7 +238,7 @@ def test_ngram_e2e_greedy_correctness_with_preemption( ] + [ { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": k, "prompt_lookup_max": 
1, }, @@ -285,14 +285,14 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_by_batch_size": 4 }, }, { "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_by_batch_size": 4, @@ -341,7 +341,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ "speculative_config": { - "method": "[ngram]", + "method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 3, "disable_mqa_scorer": True, diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_ngram_spec_decode.py index 58659cfd9473..7c7c2f02c078 100644 --- a/tests/v1/e2e/test_ngram_spec_decode.py +++ b/tests/v1/e2e/test_ngram_spec_decode.py @@ -73,7 +73,7 @@ def test_ngram_correctness( spec_llm = LLM( model=model_name, speculative_config={ - "method": "[ngram]", + "method": "ngram", "prompt_lookup_max": 5, "prompt_lookup_min": 3, "num_speculative_tokens": 3, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 36f05d4ff394..10148f5dc9e0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1614,7 +1614,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if (self.speculative_model is not None or self.num_speculative_tokens is not None): # This is supported but experimental (handled below). - if self.speculative_model == "[ngram]": + if self.speculative_model in ("ngram", "[ngram]"): pass else: _raise_or_fallback(feature_name="Speculative Decoding", @@ -1654,7 +1654,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: return False # ngram is supported on V1, but off by default for now. - if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"): + if self.speculative_model in ( + "ngram", "[ngram]") and _warn_or_fallback("ngram"): return False # Non-CUDA is supported on V1, but off by default for now. From 346c8321e902cdbe7fb5910bad358b705b63441d Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 20 Mar 2025 19:45:50 +0800 Subject: [PATCH 21/21] minor fix for eagle config Signed-off-by: Shangming Cai --- vllm/config.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 703f3c99cd7e..59cf8ad3b898 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1982,14 +1982,6 @@ def __post_init__(self): # not be detected, it will be considered as the draft-model-based # method by default. - if self.method is None and self.model is not None: - # Automatically set the method to ensure a smooth transition during - # configuration refactoring. - if self.model in ("ngram", "[ngram]"): - self.method = "ngram" - elif "eagle-" in self.model.lower(): - self.method = "eagle" - if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting # mtp acceleration for more models besides deepseek_v3 @@ -2003,6 +1995,12 @@ def __post_init__(self): raise ValueError("num_speculative_tokens was provided without " "speculative model.") + # Automatically configure the ngram method during configuration + # refactoring to ensure a smooth transition. 
+ if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + if self.method in ("ngram", "[ngram]"): # Unified to "ngram" internally self.method = "ngram" @@ -2055,7 +2053,10 @@ def __post_init__(self): hf_overrides=SpeculativeConfig.hf_config_override, ) - if self.draft_model_config.hf_config.model_type == "medusa": + # Automatically detect the method + if "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" elif (self.draft_model_config.hf_config.model_type == "mlp_speculator"):
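# --- Illustrative sketch (not part of this diff): the method auto-detection
# --- precedence implemented by the hunks above, with error handling and the
# --- deepseek_v3/MTP special case omitted. The helper name is hypothetical
# --- and used only for illustration.
def _detect_method(model: str, hf_model_type: str) -> str:
    # "ngram" is recognized before any draft model config is built.
    if model in ("ngram", "[ngram]"):
        return "ngram"
    # EAGLE heads are detected from the model name prefix.
    if "eagle-" in model.lower():
        return "eagle"
    # Medusa and MLPSpeculator are detected from the HF model type.
    if hf_model_type == "medusa":
        return "medusa"
    if hf_model_type == "mlp_speculator":
        return "mlp_speculator"
    # Everything else falls back to a regular draft model.
    return "draft_model"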