@@ -2128,139 +2128,113 @@ def __post_init__(self):
21282128 self .device = torch .device (self .device_type )
21292129
21302130
# Closed set of speculative-decoding proposal methods accepted by
# `SpeculativeConfig.method`.
SpeculativeMethod = Literal[
    "ngram",
    "eagle",
    "medusa",
    "mlp_speculator",
    "draft_model",
]
# Closed set of draft-token acceptance samplers accepted by
# `SpeculativeConfig.acceptance_method`.
SpeculativeAcceptanceMethod = Literal[
    "rejection_sampler",
    "typical_acceptance_sampler",
]
@config
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    # -- General speculative decoding control ------------------------------
    num_speculative_tokens: int = field(default=None,
                                        init=True)  # type: ignore
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present, otherwise, it is required."""
    model: Optional[str] = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: Optional[SpeculativeMethod] = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible, if `model` param is not provided, the method name must be
    provided.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
    """The method to use for accepting draft tokens:\n
    - "rejection_sampler" maps to `RejectionSampler`.\n
    - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.

    If using `typical_acceptance_sampler`, the related configuration
    `posterior_threshold` and `posterior_alpha` should be considered."""
    draft_tensor_parallel_size: Optional[int] = None
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
    disable_logprobs: bool = True
    """If set to True, token log probabilities are not returned during
    speculative decoding. If set to False, token log probabilities are returned
    according to the log probability settings in SamplingParams."""

    # -- Draft model configuration -----------------------------------------
    quantization: Optional[str] = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: Optional[int] = None
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: Optional[str] = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: Optional[str] = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # -- Advanced control --------------------------------------------------
    disable_mqa_scorer: bool = False
    """Disable the MQA scorer and fall back to batch expansion for scoring
    proposals."""
    disable_by_batch_size: Optional[int] = None
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""

    # -- Ngram proposer configuration --------------------------------------
    prompt_lookup_max: Optional[int] = None
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
    prompt_lookup_min: Optional[int] = None
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

    # -- Typical acceptance sampler configuration --------------------------
    posterior_threshold: Optional[float] = None
    """A threshold value that sets a lower bound on the posterior probability
    of a token in the target model for it to be accepted. This threshold is
    used only when we use the `TypicalAcceptanceSampler` for token acceptance.
    """
    posterior_alpha: Optional[float] = None
    """Scaling factor for entropy-based threshold, applied when using
    `TypicalAcceptanceSampler`."""

    # -- Required configuration params passed from engine ------------------
    target_model_config: ModelConfig = field(default=None,
                                             init=True)  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: ParallelConfig = field(default=None,
                                                   init=True)  # type: ignore
    """The parallel configuration for the target model."""
    enable_chunked_prefill: bool = field(default=None,
                                         init=True)  # type: ignore
    """Whether vLLM is configured to use chunked prefill or not. Used for
    raising an error since it's not yet compatible with speculative decode."""
    disable_log_stats: bool = field(default=None, init=True)  # type: ignore
    """Whether to disable the periodic printing of stage times in speculative
    decoding."""

    # -- Params generated in the post-init stage ---------------------------
    draft_model_config: ModelConfig = field(default=None,
                                            init=True)  # type: ignore
    """The configuration of the draft model initialized internal."""
    draft_parallel_config: ParallelConfig = field(default=None,
                                                  init=True)  # type: ignore
    """The parallel configuration for the draft model initialized internal."""
22652239 def compute_hash (self ) -> str :
22662240 """
0 commit comments