diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py
index ad726fa8ce51..3d9dfa50093a 100644
--- a/tests/entrypoints/llm/test_guided_generate.py
+++ b/tests/entrypoints/llm/test_guided_generate.py
@@ -23,11 +23,23 @@
 ]
 
 
-@pytest.fixture(scope="module")
-def llm():
+@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
+def llm(request):
+
+    def get_llm_kwargs(mode: str):
+        if mode == "autoregressive":
+            return {}
+        return {
+            "speculative_config": {
+                "model": "Qwen/Qwen2.5-0.5B-Instruct",
+                "num_speculative_tokens": 3,
+            },
+        }
+
+    test_llm_kwargs = get_llm_kwargs(request.param)
     # pytest caches the fixture so we use weakref.proxy to
     # enable garbage collection
-    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
+    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0, **test_llm_kwargs)
 
     with llm.deprecate_legacy_api():
         yield weakref.proxy(llm)
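A minimal usage sketch of the combination the new "speculative" fixture parameter exercises: guided decoding requested against an engine that also runs a draft model. The model names and prompt below are placeholders (not the test's MODEL_NAME), and the sketch is illustrative rather than code from this PR.

```python
# Sketch only: guided decoding on an engine that also runs speculative decoding,
# mirroring what the "speculative" fixture parameter sets up. Model names and
# the prompt are assumptions for illustration.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # stand-in for the test's MODEL_NAME
    max_model_len=1024,
    seed=0,
    speculative_config={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "num_speculative_tokens": 3,
    },
)

params = SamplingParams(
    temperature=0.0,
    max_tokens=4,
    guided_decoding=GuidedDecodingParams(choice=["Positive", "Negative"]),
)
outputs = llm.generate("Sentiment of 'Great movie!':", params)
print(outputs[0].outputs[0].text)  # constrained to one of the two choices
```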
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 6cc9b881464e..affdd5ee68b4 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -13,7 +13,8 @@
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VllmConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VllmConfig)
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
@@ -511,7 +512,8 @@ async def add_request_async(
                 default_guided_backend=self.decoding_config.
                 guided_decoding_backend,
                 reasoning_backend=self.decoding_config.reasoning_backend,
-                model_config=self.model_config)
+                model_config=self.model_config,
+                speculative_config=self.speculative_config)
 
         self._add_processed_request(
             request_id=request_id,
@@ -536,9 +538,13 @@ async def collective_rpc_async(self,
 
 
 async def build_guided_decoding_logits_processor_async(
-        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
-        default_guided_backend: str, reasoning_backend: Optional[str],
-        model_config: ModelConfig) -> SamplingParams:
+        sampling_params: SamplingParams,
+        tokenizer: AnyTokenizer,
+        default_guided_backend: str,
+        reasoning_backend: Optional[str],
+        model_config: ModelConfig,
+        speculative_config: Optional[SpeculativeConfig] = None
+) -> SamplingParams:
     """Constructs logits processors based on the guided_decoding,
     logits_bias, and allowed_token_ids fields in sampling_params. Deletes
     those fields and adds the constructed logits processors to the
@@ -564,7 +570,8 @@ async def build_guided_decoding_logits_processor_async(
             guided_params=guided_decoding,
             tokenizer=tokenizer,
             reasoning_backend=reasoning_backend,
-            model_config=model_config)
+            model_config=model_config,
+            speculative_config=speculative_config)
 
     if processor:
         if sampling_params.logits_processors is None:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index c23530990611..15278e269ce3 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2102,7 +2102,7 @@ def _build_logits_processors(
                 tokenizer=tokenizer,
                 model_config=self.model_config,
                 reasoning_backend=self.decoding_config.reasoning_backend,
-            )
+                speculative_config=self.speculative_config)
             if processor:
                 logits_processors.append(processor)
 
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index eb3ae89394ec..81ca0a52b089 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -96,6 +96,7 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
         self.vllm_config = engine_config
         self.model_config = engine_config.model_config
         self.decoding_config = engine_config.decoding_config
+        self.speculative_config = engine_config.speculative_config
 
         # Create the tokenizer group.
         self.tokenizer = init_tokenizer_from_configs(
@@ -620,6 +621,7 @@ async def _process_request(
                 else DecodingConfig.guided_decoding_backend),
                 model_config=self.model_config,
                 reasoning_backend=self.decoding_config.reasoning_backend,
+                speculative_config=self.speculative_config,
             )
 
         # 1) Create output queue for this requests.
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 8fdcdcafa980..8c8a0b49d22d 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.utils import (
@@ -13,7 +13,7 @@
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
-    from vllm.config import ModelConfig
+    from vllm.config import ModelConfig, SpeculativeConfig
     from vllm.logits_process import LogitsProcessor
     from vllm.sampling_params import GuidedDecodingParams
 
@@ -100,6 +100,7 @@ async def get_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
         model_config: ModelConfig,
+        speculative_config: Optional[SpeculativeConfig] = None,
         reasoning_backend: str | None = None) -> LogitsProcessor | None:
     reasoner = None
 
@@ -126,12 +127,14 @@ async def get_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config, reasoner)
+            guided_params, tokenizer, model_config, reasoner,
+            speculative_config)
     if guided_params.backend_name == 'guidance':
         from vllm.model_executor.guided_decoding.guidance_decoding import (
             get_local_guidance_guided_decoding_logits_processor)
         return get_local_guidance_guided_decoding_logits_processor(
             guided_params, tokenizer)
+
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
         "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
@@ -139,10 +142,12 @@ async def get_guided_decoding_logits_processor(
 
 
 def get_local_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams,
-        tokenizer: PreTrainedTokenizer,
-        model_config: ModelConfig,
-        reasoning_backend: str | None = None) -> LogitsProcessor | None:
+    guided_params: GuidedDecodingParams,
+    tokenizer: PreTrainedTokenizer,
+    model_config: ModelConfig,
+    reasoning_backend: str | None = None,
+    speculative_config: Optional[SpeculativeConfig] = None
+) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
 
     reasoner = None
@@ -167,7 +172,8 @@ def get_local_guided_decoding_logits_processor(
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
-            guided_params, tokenizer, model_config, reasoner)
+            guided_params, tokenizer, model_config, reasoner,
+            speculative_config)
     if guided_params.backend_name == 'guidance':
         from vllm.model_executor.guided_decoding.guidance_decoding import (
             get_local_guidance_guided_decoding_logits_processor)
diff --git a/vllm/model_executor/guided_decoding/guidance_logits_processors.py b/vllm/model_executor/guided_decoding/guidance_logits_processors.py
index 26fcafe31c76..dcbe2a3da9e7 100644
--- a/vllm/model_executor/guided_decoding/guidance_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/guidance_logits_processors.py
@@ -36,6 +36,7 @@ def __init__(
         self.tokenizer_name = tokenizer.name_or_path
         self.new_sampling = False
         self.initialized = False
+        self.num_processed_tokens = 0
 
     def _initialize(self):
         if self.initialized:
@@ -69,7 +70,17 @@ def __call__(
         # to avoid pickling ll_tokenizer and ll_interpreter
         self._initialize()
 
+        if self.num_processed_tokens > 0 and self.num_processed_tokens >= len(
+                input_ids):
+            diff = self.num_processed_tokens - len(input_ids) + 1
+            self.ll_matcher.rollback(diff)
+            self.num_processed_tokens -= diff
+
         if self.new_sampling and len(input_ids) > 0:
+            # The tokens are not truly consumed when the matcher is stopped,
+            # despite consume_token returning True. This is a workaround.
+            self.num_processed_tokens += 1 if not self.ll_matcher.is_stopped(
+            ) else 0
             self.ll_matcher.consume_token(input_ids[-1])
             err = self.ll_matcher.get_error()
             if err:
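The rollback bookkeeping shared by this processor and the xgrammar one below is easiest to follow with concrete numbers. The following is a standalone sketch, not vLLM code: a stub matcher stands in for ll_matcher / xgr.GrammarMatcher, and on_step mirrors one __call__ for a single sequence.

```python
# Standalone illustration of the resync rule: if the matcher has already
# consumed at least as many tokens as the engine now presents, the surplus
# draft tokens were rejected by speculative verification and must be rolled
# back before consuming input_ids[-1].
class StubMatcher:

    def __init__(self):
        self.consumed = []

    def consume_token(self, tok):
        self.consumed.append(tok)

    def rollback(self, n):
        del self.consumed[len(self.consumed) - n:]


class GuidedState:
    """Per-sequence bookkeeping, mirroring the processors in this diff."""

    def __init__(self):
        self.matcher = StubMatcher()
        self.num_processed_tokens = 0

    def on_step(self, input_ids):
        if (self.num_processed_tokens > 0
                and self.num_processed_tokens >= len(input_ids)):
            diff = self.num_processed_tokens - len(input_ids) + 1
            self.matcher.rollback(diff)
            self.num_processed_tokens -= diff
        if input_ids:
            self.num_processed_tokens += 1
            self.matcher.consume_token(input_ids[-1])


state = GuidedState()
# Draft steps propose tokens 11, 12, 13, one per call.
state.on_step([11])
state.on_step([11, 12])
state.on_step([11, 12, 13])
# Verification keeps 11, rejects 12 and 13, and the target model samples 21.
# 3 >= 2, so diff = 3 - 2 + 1 = 2: roll back two tokens, then consume 21.
state.on_step([11, 21])
assert state.matcher.consumed == [11, 21]
assert state.num_processed_tokens == 2
```

The "+ 1" in the computation reflects that input_ids[-1] is consumed again later in the same call, so the matcher has to be rolled back to exactly the first len(input_ids) - 1 tokens it had consumed.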
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index ff223c3c9b83..20783f9acbb1 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -6,7 +6,7 @@
 import json
 import re
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING, Any, List, Optional
 
 import torch
 
@@ -27,7 +27,7 @@
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
-    from vllm.config import ModelConfig
+    from vllm.config import ModelConfig, SpeculativeConfig
     from vllm.reasoning import ReasoningParser
     from vllm.sampling_params import GuidedDecodingParams
 
@@ -39,11 +39,14 @@ def get_local_xgrammar_guided_decoding_logits_processor(
         tokenizer: PreTrainedTokenizer,
         model_config: ModelConfig,
         reasoner: ReasoningParser | None,
+        speculative_config: Optional[SpeculativeConfig] = None,
         max_threads: int = 8):
-    config = GrammarConfig.from_guided_params(guided_params=guided_params,
-                                              model_config=model_config,
-                                              tokenizer=tokenizer,
-                                              max_threads=max_threads)
+    config = GrammarConfig.from_guided_params(
+        guided_params=guided_params,
+        model_config=model_config,
+        tokenizer=tokenizer,
+        speculative_config=speculative_config,
+        max_threads=max_threads)
     return XGrammarLogitsProcessor(config, reasoner)
 
 
@@ -154,13 +157,16 @@ class GrammarConfig:
     any_whitespace: bool = True
     regex_str: str | None = None
    max_threads: int = 8
+    num_lookahead_slots: int | None = None
 
     @classmethod
-    def from_guided_params(cls,
-                           guided_params: GuidedDecodingParams,
-                           model_config: ModelConfig,
-                           tokenizer: PreTrainedTokenizer,
-                           max_threads: int = 8) -> GrammarConfig:
+    def from_guided_params(
+            cls,
+            guided_params: GuidedDecodingParams,
+            model_config: ModelConfig,
+            tokenizer: PreTrainedTokenizer,
+            speculative_config: Optional[SpeculativeConfig] = None,
+            max_threads: int = 8) -> GrammarConfig:
         tokenizer_hash = hash(tokenizer)
 
         tokenizer_data = TokenizerDataCache.get_tokenizer_data(
@@ -168,6 +174,8 @@ def from_guided_params(cls,
             tokenizer_hash=tokenizer_hash,
             vocab_size=model_config.hf_text_config.vocab_size,
         )
+        num_lookahead_slots = (speculative_config.num_lookahead_slots
+                               if speculative_config else None)
 
         if guided_params.json:
             if not isinstance(guided_params.json, str):
@@ -209,7 +217,8 @@ def from_guided_params(cls,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
                        tokenizer_data=tokenizer_data,
-                       any_whitespace=any_whitespace)
+                       any_whitespace=any_whitespace,
+                       num_lookahead_slots=num_lookahead_slots)
         elif guided_params.grammar:
             # XGrammar only supports GBNF grammars, so we must convert Lark
             if grammar_is_likely_lark(guided_params.grammar):
@@ -235,13 +244,15 @@ def from_guided_params(cls,
             return cls(grammar_str=grammar_str,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
-                       tokenizer_data=tokenizer_data)
+                       tokenizer_data=tokenizer_data,
+                       num_lookahead_slots=num_lookahead_slots)
         elif guided_params.json_object:
             return cls(
                 json_object=True,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
+                num_lookahead_slots=num_lookahead_slots,
             )
         elif guided_params.choice:
             choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
@@ -255,6 +266,7 @@ def from_guided_params(cls,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
+                num_lookahead_slots=num_lookahead_slots,
             )
         elif guided_params.regex:
             return cls(
@@ -262,6 +274,7 @@ def from_guided_params(cls,
                 tokenizer_hash=tokenizer_hash,
                 max_threads=max_threads,
                 tokenizer_data=tokenizer_data,
+                num_lookahead_slots=num_lookahead_slots,
             )
         else:
             raise ValueError(
@@ -301,7 +314,7 @@ class XGrammarLogitsProcessor:
     token_bitmask: torch.Tensor = None  # type: ignore[assignment]
     matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
     batch_size: int = field(default=1)
-    prefilled: bool = field(default=False)
+    num_processed_tokens: list[int] = field(default_factory=list)
 
     def __post_init__(self):
         self.tokenizer_info = self.config.tokenizer_info(
@@ -320,7 +333,7 @@ def __setstate__(self, state: dict[str, Any]):
         self.matchers = []
         self.batch_size = 1
         self.token_bitmask = None  # type: ignore[assignment]
-        self.prefilled = False
+        self.num_processed_tokens = []
 
     def _ensure_ctx(self):
         """Lazily initialize the processor in the worker process"""
@@ -358,18 +371,28 @@ def __call__(self, input_ids: list[int],
         self._ensure_ctx()
 
         if len(self.matchers) == 0:
+            max_rollback_tokens = (self.config.num_lookahead_slots
+                                   if self.config.num_lookahead_slots else 0)
             self.matchers = [
-                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
+                xgr.GrammarMatcher(self.ctx,
+                                   max_rollback_tokens=max_rollback_tokens)
+                for _ in range(self.batch_size)
             ]
+            self.num_processed_tokens = [0 for _ in range(self.batch_size)]
             self.token_bitmask = xgr.allocate_token_bitmask(
                 self.batch_size, self.tokenizer_info.vocab_size)
 
-        if not self.prefilled:
-            # Have not sampled a token yet
-            self.prefilled = True
-        else:
+        for i in range(len(self.matchers)):
+            if self.num_processed_tokens[i] > 0 and self.num_processed_tokens[
+                    i] >= len(input_ids):
+                diff = self.num_processed_tokens[i] - len(input_ids) + 1
+                self.num_processed_tokens[i] -= diff
+                self.matchers[i].rollback(diff)
+
+        if len(input_ids) > 0:
             for i, matcher in enumerate(self.matchers):
                 if not matcher.is_terminated():
+                    self.num_processed_tokens[i] += 1
                     sampled_token = input_ids[-1]
                     assert self.matchers[i].accept_token(sampled_token)
 
@@ -392,8 +415,15 @@ def __call__(self, input_ids: list[int],
         # Note: In this method, if the tensors have different dimensions
         # on CPU device fails, but on GPU it runs without error. Hence the
         # unsqueeze above for scores, to match the token bitmask shape
+
+        # A non-blocking copy causes incorrect behaviour in speculative
+        # decoding. Therefore, we use a blocking copy in the speculative
+        # decoding case.
+        speculative_decoding_disabled = self.config.num_lookahead_slots is None
         xgr.apply_token_bitmask_inplace(
-            scores, self.token_bitmask.to(scores.device, non_blocking=True))
+            scores,
+            self.token_bitmask.to(scores.device,
+                                  non_blocking=speculative_decoding_disabled))
 
         if device_type != "cuda":
             scores = scores.to(dtype).to(device_type).squeeze()
@@ -407,10 +437,18 @@ def clone(self) -> XGrammarLogitsProcessor:
         # Share the compiled grammar context (immutable after compilation)
         new_processor.ctx = self.ctx
 
-        # Create fresh matchers for the new sequence
+        # Create fresh matchers for the new sequence and reset
+        # num_processed_tokens for new sequence
         if self.ctx is not None:
+            max_rollback_tokens = (self.config.num_lookahead_slots
+                                   if self.config.num_lookahead_slots else 0)
             new_processor.matchers = [
-                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
+                xgr.GrammarMatcher(self.ctx,
+                                   max_rollback_tokens=max_rollback_tokens)
+                for _ in range(self.batch_size)
+            ]
+            new_processor.num_processed_tokens = [
+                0 for _ in range(self.batch_size)
             ]
 
         # Create a new token bitmask with the same size
@@ -419,7 +457,5 @@ def clone(self) -> XGrammarLogitsProcessor:
 
         # Copy simple attributes
         new_processor.batch_size = self.batch_size
-        # Reset prefilled state for new sequence
-        new_processor.prefilled = False
 
         return new_processor
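To make the new matcher configuration concrete, here is a small standalone sketch of the xgrammar calls the processor relies on, assuming xgrammar's public API as used in this diff (GrammarCompiler, GrammarMatcher with max_rollback_tokens, accept_token, rollback). The tokenizer name and the drafted text are placeholders; this is not code from the PR.

```python
# Sketch, not PR code: a GrammarMatcher built with max_rollback_tokens can undo
# draft tokens that speculative verification later rejects, which is what
# num_lookahead_slots is plumbed through for above.
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
compiled = compiler.compile_builtin_json_grammar()

# num_lookahead_slots bounds how many draft tokens can be outstanding, so it
# becomes max_rollback_tokens for every matcher in the processor.
matcher = xgr.GrammarMatcher(compiled, max_rollback_tokens=3)

# Pretend these token ids were proposed by the draft model for a JSON prefix.
draft = tokenizer.encode('{"ok": tr', add_special_tokens=False)

accepted = 0
for tok in draft:
    if not matcher.accept_token(tok):
        break  # a rejected token would simply end the illustration
    accepted += 1

# Suppose verification rejected the last two drafted tokens: undo them, then
# generation continues from the surviving prefix.
if accepted >= 2:
    matcher.rollback(2)
print("accepted:", accepted, "terminated:", matcher.is_terminated())
```

Passing max_rollback_tokens=0 when no speculative config is present, as the processor does, keeps the non-speculative path's behaviour unchanged.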
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 24095ef2a567..c11ac0a0b0ec 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -54,8 +54,8 @@ def __init__(self, model_runner: ModelRunnerBase):
 
         self.indices_of_seq_with_bonus_tokens = None
 
-    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
-                                  num_queries):
+    def _update_sampling_metadata(self, sampling_metadata, sampled_token_ids,
+                                  num_seqs, num_queries):
 
         assert sampling_metadata.num_prompts == 0
         assert len(sampling_metadata.seq_groups) == num_queries
@@ -71,6 +71,15 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
             assert seq_group.prompt_logprob_indices == []  # No prompt
             assert seq_group.sample_indices == [i]  # Simple
 
+            # Add draft tokens to the output for structured output only
+            logits_processors = seq_group.sampling_params.logits_processors
+            for seq_id in seq_group.seq_ids:
+                if logits_processors is not None:
+                    seq_group.seq_data[seq_id].output_token_ids = (
+                        *seq_group.seq_data[seq_id].output_token_ids,
+                        sampled_token_ids[i],
+                    )
+
     def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                           last_output: SamplerOutput) -> ModelRunnerInputBase:
         # Currently, we expect "decode mode" only
@@ -93,8 +102,8 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
 
         # Update sampling_metadata
         sampling_metadata = model_input.sampling_metadata
-        self._update_sampling_metadata(sampling_metadata, num_seqs,
-                                       num_queries)
+        self._update_sampling_metadata(sampling_metadata, sampled_token_ids,
+                                       num_seqs, num_queries)
 
         # Create new input
         new_model_input = self._model_input_cls(
diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py
index 6275c460ecef..3f3c520e7739 100644
--- a/vllm/spec_decode/mqa_scorer.py
+++ b/vllm/spec_decode/mqa_scorer.py
@@ -28,6 +28,10 @@ def score_proposals(
                 target_seq_group_metadata_list.append(seq_group_metadata)
                 continue
 
+            assert (seq_group_metadata.sampling_params.logits_processors
+                    is None), ("MQAScorer does not support structured output. "
+                               "Use BatchExpansionTop1Scorer instead.")
+
             seq_data_dict = seq_group_metadata.seq_data
             assert len(seq_data_dict) == 1
             seq_id = next(iter(seq_data_dict.keys()))
diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
index 57ae173af674..5030fb1073b9 100644
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -86,8 +86,16 @@ def sampler_output(
         token_prob_list: List[Optional[torch.Tensor]] = []
         for idx, seq_group_metadata in enumerate(
                 execute_model_req.seq_group_metadata_list):
-            seq_data = next(iter(seq_group_metadata.seq_data.values()))
+            # Turn off NGram speculative proposals when guided output is turned
+            # on because they propose non-structured tokens, which lead to
+            # complex logic and negatively affect performance.
+            if seq_group_metadata.sampling_params.logits_processors is not None:
+                token_id_list.append(None)
+                token_prob_list.append(None)
+                continue
+
+            seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
             seq_len = seq_data.get_len()
             # When seq_len is less than 3072 (3K), we use CPU to perform
             # the ngram match. Otherwise, we use the device specified in