From 81aadb6152ebba89c4d019b15206815e53bda7c9 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 14 Apr 2025 06:33:05 +0000 Subject: [PATCH 01/17] chore: migrate tokenizer init to manager only Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 25 +++++++++--- vllm/v1/structured_output/backend_guidance.py | 21 +++------- vllm/v1/structured_output/backend_types.py | 13 +++++++ vllm/v1/structured_output/backend_xgrammar.py | 39 +++++++------------ 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 218af43deb67..43cdcf2beffb 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,9 +7,11 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar) +from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: import numpy as np @@ -46,13 +48,26 @@ def grammar_init(self, request: Request) -> None: # backends on a per-request basis in V1 (for now, anyway...). if self.backend is None: backend_name = request.sampling_params.guided_decoding.backend_name + tokenizer_group = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + parallel_config=self.vllm_config.parallel_config, + lora_config=self.vllm_config.lora_config) + tokenizer_group.ping() + tokenizer = tokenizer_group.get_lora_tokenizer(None) + vocab_size = self.vllm_config.model_config.get_vocab_size() if backend_name == "xgrammar": - from vllm.v1.structured_output.backend_xgrammar import ( - XgrammarBackend) - - self.backend = XgrammarBackend(self.vllm_config) + self.backend = XgrammarBackend( + self.vllm_config, + tokenizer=tokenizer, + vocab_size=vocab_size, + ) elif backend_name == "guidance": - self.backend = GuidanceBackend(self.vllm_config) + self.backend = GuidanceBackend( + self.vllm_config, + tokenizer=tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend_name}") diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 9150a28570bd..e79eb898565c 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -1,15 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import os from dataclasses import dataclass from typing import TYPE_CHECKING, Optional import torch -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar, @@ -29,25 +29,16 @@ logger = init_logger(__name__) +@dataclass class GuidanceBackend(StructuredOutputBackend): - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, - lora_config=vllm_config.lora_config) # type: 
ignore[arg-type] - tokenizer_group.ping() - self.vllm_config = vllm_config - self.vocab_size = vllm_config.model_config.get_vocab_size() + def __post_init__(self): self.disable_any_whitespace = ( "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) + in self.vllm_config.decoding_config.guided_decoding_backend) - tokenizer = tokenizer_group.get_lora_tokenizer(None) self.ll_tokenizer = llguidance_hf.from_tokenizer( - tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size) def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 6dc2a92411de..873c24801707 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -1,10 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import enum from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING import torch +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.transformers_utils.tokenizer import AnyTokenizer + class StructuredOutputOptions(enum.Enum): JSON = enum.auto() @@ -60,9 +68,14 @@ def reset(self): """ +@dataclass class StructuredOutputBackend(ABC): """Engine-level backend for structured output requests.""" + vllm_config: VllmConfig + tokenizer: AnyTokenizer + vocab_size: int + @abstractmethod def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 83f2c6436ed2..174fbbb30988 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,14 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from dataclasses import dataclass, field from typing import TYPE_CHECKING import torch import vllm.envs -from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -23,58 +23,49 @@ logger = init_logger(__name__) +@dataclass class XgrammarBackend(StructuredOutputBackend): - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config + def __post_init__(self): self.disable_any_whitespace = ( "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, - lora_config=vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() - - tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.vocab_size = vllm_config.model_config.get_vocab_size() - if isinstance(tokenizer, MistralTokenizer): + in self.vllm_config.decoding_config.guided_decoding_backend) + if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. 
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 try: - if tokenizer.is_tekken: - encoded_vocab = tokenizer._vocab + if self.tokenizer.is_tekken: + encoded_vocab = self.tokenizer._vocab else: encoded_vocab = [ token for token, _ in sorted( - tokenizer.get_vocab().items(), + self.tokenizer.get_vocab().items(), key=lambda x: x[1], ) ] stop_token_ids = None if hasattr( - tokenizer, + self.tokenizer, "eos_token_id", - ) and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + ) and self.tokenizer.eos_token_id is not None: + stop_token_ids = [self.tokenizer.eos_token_id] except AttributeError as e: raise ValueError( f"Cannot get the vocabulary of the tokenizer " - f"{type(tokenizer)}. The tokenizer should have a " + f"{type(self.tokenizer)}. The tokenizer should have a " "get_vocab method.") from e tokenizer_info = xgr.TokenizerInfo( # type: ignore encoded_vocab=encoded_vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW - if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, vocab_size=self.vocab_size, stop_token_ids=stop_token_ids, add_prefix_space=True, ) else: tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer, + self.tokenizer, vocab_size=self.vocab_size, ) self.compiler = xgr.GrammarCompiler( From a97b1726fc89995a6ade59f58e8d44e146dd8610 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 17 Apr 2025 19:03:39 +0000 Subject: [PATCH 02/17] --wip-- Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 24 ++++++++------- vllm/v1/structured_output/__init__.py | 29 +++++++++++++++++++ vllm/v1/structured_output/backend_guidance.py | 3 ++ vllm/v1/structured_output/backend_types.py | 6 ++++ vllm/v1/structured_output/backend_xgrammar.py | 5 ++++ 5 files changed, 56 insertions(+), 11 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a81574875a5c..e02d54b43499 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -652,15 +652,16 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output: - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # check above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) - - # Get prompt logprobs for this request. - prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + # --- Jump-forward decoding for structured output requests --- + if request.use_structured_output: + batch_index = scheduler_output.structured_output_request_ids.get( + req_id, 0) + jump_tokens = self.structured_output_manager.jump_forward_tokens( + request, batch_index) + if jump_tokens: + new_token_ids.extend(jump_tokens) + # --- End jump-forward decoding --- + if new_token_ids: # Add EngineCoreOutput for this Request. 
outputs.append( @@ -669,12 +670,13 @@ def update_from_output( new_token_ids=new_token_ids, finish_reason=request.get_finished_reason(), new_logprobs=new_logprobs, - new_prompt_logprobs_tensors=prompt_logprobs_tensors, + new_prompt_logprobs_tensors=prompt_logprobs_dict.get( + req_id), stop_reason=request.stop_reason, events=request.take_events())) else: # Invariant: EngineCore returns no partial prefill outputs. - assert not prompt_logprobs_tensors + assert not prompt_logprobs_dict.get(req_id) self.scheduled_req_ids.remove(req_id) if not stopped: diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 43cdcf2beffb..266e9fe2af64 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -8,6 +8,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar) @@ -17,8 +18,11 @@ import numpy as np import numpy.typing as npt import torch + import xgrammar.testing as xgr_testing from vllm.v1.request import Request +else: + xgr_testing = LazyLoader('xgr_testing', globals(), 'xgrammar.testing') logger = init_logger(__name__) @@ -122,3 +126,28 @@ def grammar_bitmask( # np.ndarray, because that is much more efficient for serialization # and deserialization when sending this to the GPU workers. return bitmask_tensor.numpy() + + def jump_forward_tokens(self, request, batch_index) -> list[int]: + """ + For xgrammar-based structured output requests, repeatedly check if the grammar bitmask + is a single-token bitmask, and if so, advance the FSM and collect all jump-forward tokens. + Returns the list of jump-forward token IDs. + """ + so_request = request.structured_output_request + if so_request is None or so_request.grammar is None: + return [] + + jump_tokens = [] + bitmask = torch.zeros(so_request.grammar.vocab_size, dtype=torch.int32) + so_request.grammar.allocate_token_bitmask(1) + so_request.grammar.fill_bitmask(bitmask, 0) + is_single, unique_token_id = xgr_testing._is_single_token_bitmask( + bitmask, so_request.grammar.vocab_size, 0) + while is_single and unique_token_id != -1: + jump_tokens.append(unique_token_id) + so_request.grammar.accept_tokens(request.request_id, + [unique_token_id]) + so_request.grammar.fill_bitmask(bitmask, batch_index) + is_single, unique_token_id = xgr_testing._is_single_token_bitmask( + bitmask, so_request.grammar.vocab_size, 0) + return jump_tokens diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index e79eb898565c..96bc50811ad2 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -119,6 +119,9 @@ def reset(self): # This method may be not needed anymore? 
TODO self.ll_matcher.reset() + def find_jump_forward_tokens(self) -> list[int]: + raise NotImplementedError + def serialize_guidance_grammar(request_type: StructuredOutputOptions, grammar_spec: str, diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 873c24801707..a80d02966491 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -100,3 +100,9 @@ def allocate_token_bitmask(self, max_num_seqs: int): max_num_seqs (int): The maximum number of sequences for which to allocate the bitmask. """ + + @abstractmethod + def find_jump_forward_tokens(self) -> list[int]: + """ + Finds the tokens that can be used to jump forward in the grammar. + """ \ No newline at end of file diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 174fbbb30988..7ffed43fc5ef 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -147,3 +147,8 @@ def is_terminated(self) -> bool: def reset(self): self.num_processed_tokens = 0 self.matcher.reset() + + def find_jump_forward_tokens(self) -> list[int]: + jf_string = self.matcher.find_jump_forward_string() + return self.tokenizer.decode( + jf_string, skip_special_tokens=True) if jf_string else [] From 26f8a25f22f94b705a48c4f6437808b61e4f5d61 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 13:06:19 +0000 Subject: [PATCH 03/17] chore: remove unused functions Signed-off-by: Aaron Pham --- vllm/v1/structured_output/backend_guidance.py | 3 --- vllm/v1/structured_output/backend_types.py | 6 ------ vllm/v1/structured_output/backend_xgrammar.py | 5 ----- 3 files changed, 14 deletions(-) diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 066050e651e5..603083f65b2d 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -158,9 +158,6 @@ def reset(self): # This method may be not needed anymore? TODO self.ll_matcher.reset() - def find_jump_forward_tokens(self) -> list[int]: - raise NotImplementedError - def serialize_guidance_grammar( request_type: StructuredOutputOptions, diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index e7fbab5cf1e6..7ac8e40f93e7 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -101,12 +101,6 @@ def allocate_token_bitmask(self, max_num_seqs: int): to allocate the bitmask. """ - @abstractmethod - def find_jump_forward_tokens(self) -> list[int]: - """ - Finds the tokens that can be used to jump forward in the grammar. 
- """ - @abstractmethod def destroy(self): """ diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index b6d230ea394c..957283ee3161 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -164,11 +164,6 @@ def reset(self): self.num_processed_tokens = 0 self.matcher.reset() - def find_jump_forward_tokens(self) -> list[int]: - jf_string = self.matcher.find_jump_forward_string() - return self.tokenizer.decode( - jf_string, skip_special_tokens=True) if jf_string else [] - def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: """Check if JSON schema contains features unsupported by xgrammar.""" From fb4ae733c2cf92b6eabb03dfec1a495d74975188 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 13:21:03 +0000 Subject: [PATCH 04/17] --wip-- Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 10 +++++----- vllm/v1/structured_output/__init__.py | 8 +++----- vllm/v1/structured_output/backend_types.py | 4 ++-- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1822227d13c0..1c2b3281edb0 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -704,15 +704,16 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if request.use_structured_output: + if new_token_ids and request.use_structured_output: batch_index = scheduler_output.structured_output_request_ids.get( # noqa: E501 req_id, 0) jump_tokens = self.structured_output_manager.jump_forward_tokens( # noqa: E501 request, batch_index) if jump_tokens: new_token_ids.extend(jump_tokens) - # --- End jump-forward decoding --- + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: # Add EngineCoreOutput for this Request. outputs.append( @@ -721,13 +722,12 @@ def update_from_output( new_token_ids=new_token_ids, finish_reason=request.get_finished_reason(), new_logprobs=new_logprobs, - new_prompt_logprobs_tensors=prompt_logprobs_dict.get( - req_id), + new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, events=request.take_events())) else: # Invariant: EngineCore returns no partial prefill outputs. - assert not prompt_logprobs_dict.get(req_id) + assert not prompt_logprobs_tensors if not stopped: new_running.append(request) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index b8524ec7d18e..b58b85392e82 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -52,13 +52,11 @@ def grammar_init(self, request: Request) -> None: # backends on a per-request basis in V1 (for now, anyway...). 
if self.backend is None: backend_name = request.sampling_params.guided_decoding.backend_name - tokenizer_group = init_tokenizer_from_configs( + tokenizer = init_tokenizer_from_configs( model_config=self.vllm_config.model_config, scheduler_config=self.vllm_config.scheduler_config, - parallel_config=self.vllm_config.parallel_config, - lora_config=self.vllm_config.lora_config) - tokenizer_group.ping() - tokenizer = tokenizer_group.get_lora_tokenizer(None) + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) vocab_size = self.vllm_config.model_config.get_vocab_size() if backend_name == "xgrammar": self.backend = XgrammarBackend( diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 7ac8e40f93e7..db26261ba428 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -7,9 +7,9 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -import torch - if TYPE_CHECKING: + import torch + from vllm.config import VllmConfig from vllm.transformers_utils.tokenizer import AnyTokenizer From 4ddf58c4e42244c635d48151c9c42b479822b91a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 14:57:00 +0000 Subject: [PATCH 05/17] feat: working version Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 18 ++++++++--- vllm/v1/outputs.py | 17 ++++++++-- vllm/v1/structured_output/__init__.py | 25 ++++++--------- vllm/v1/worker/gpu_model_runner.py | 46 +++++++++++++++++++-------- vllm/v1/worker/tpu_model_runner.py | 20 +++++++++++- 5 files changed, 90 insertions(+), 36 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1c2b3281edb0..5b3c3fc5e621 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -625,6 +625,7 @@ def update_from_output( spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + structured_output_metadata = model_runner_output.structured_output_metadata # noqa: E501 num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] @@ -704,13 +705,22 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output: - batch_index = scheduler_output.structured_output_request_ids.get( # noqa: E501 - req_id, 0) + if new_token_ids and request.use_structured_output and ( + (grammar_bitmask := + structured_output_metadata['grammar_bitmask']) is not None + ) and ((struct_out_req_batch_indices := + structured_output_metadata['struct_out_req_batch_indices']) + is not None) and req_id in struct_out_req_batch_indices: jump_tokens = self.structured_output_manager.jump_forward_tokens( # noqa: E501 - request, batch_index) + request, + bitmask=grammar_bitmask, + batch_index=struct_out_req_batch_indices[req_id], + ) if jump_tokens: + print(jump_tokens) new_token_ids.extend(jump_tokens) + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) # Get prompt logprobs for this request. 
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2732b933c28a..a99aa359f62f 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from dataclasses import dataclass -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, TypedDict +import numpy as np import torch @@ -41,7 +44,7 @@ def tolists(self): @staticmethod def empty_cpu(num_positions: int, - num_tokens_per_position: int) -> "LogprobsTensors": + num_tokens_per_position: int) -> LogprobsTensors: """Create empty LogprobsTensors on CPU.""" logprob_token_ids = torch.empty( @@ -70,6 +73,11 @@ class SamplerOutput: logprobs_tensors: Optional[LogprobsTensors] +class ModelRunnerStructuredOutputMetadata(TypedDict): + grammar_bitmask: Optional[np.ndarray] + struct_out_req_batch_indices: Optional[dict[str, int]] + + # ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use list instead. @dataclass @@ -99,6 +107,7 @@ class ModelRunnerOutput: # [prompt_len, num_prompt_logprobs] # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] + structured_output_metadata: ModelRunnerStructuredOutputMetadata EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( @@ -108,4 +117,8 @@ class ModelRunnerOutput: spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + structured_output_metadata=ModelRunnerStructuredOutputMetadata( + grammar_bitmask=None, + struct_out_req_batch_indices=None, + ), ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index b58b85392e82..a8c0578c4893 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -22,6 +22,7 @@ from vllm.v1.request import Request else: + torch = LazyLoader('torch', globals(), 'torch') xgr_testing = LazyLoader('xgr_testing', globals(), 'xgrammar.testing') logger = init_logger(__name__) @@ -125,30 +126,24 @@ def grammar_bitmask( # and deserialization when sending this to the GPU workers. return bitmask_tensor.numpy() - def jump_forward_tokens(self, request, batch_index) -> list[int]: + def jump_forward_tokens(self, request: Request, bitmask: np.ndarray, + batch_index: int) -> list[int]: """ - For xgrammar-based structured output requests, repeatedly + For structured output requests, repeatedly check if the grammar bitmask is a single-token bitmask, and if so, advance the FSM and collect all jump-forward tokens. Returns the list of jump-forward token IDs. 
""" - so_request = request.structured_output_request - if so_request is None or so_request.grammar is None: - return [] + if TYPE_CHECKING: + assert request.structured_output_request is not None + assert request.structured_output_request.grammar is not None + assert self.backend is not None jump_tokens = [] - bitmask = torch.zeros(so_request.grammar.vocab_size, dtype=torch.int32) - so_request.grammar.allocate_token_bitmask(1) - so_request.grammar.fill_bitmask(bitmask, 0) is_single, unique_token_id = xgr_testing._is_single_token_bitmask( - bitmask, so_request.grammar.vocab_size, 0) - while is_single and unique_token_id != -1: + torch.from_numpy(bitmask), self.backend.vocab_size, batch_index) + if is_single: jump_tokens.append(unique_token_id) - so_request.grammar.accept_tokens(request.request_id, - [unique_token_id]) - so_request.grammar.fill_bitmask(bitmask, batch_index) - is_single, unique_token_id = xgr_testing._is_single_token_bitmask( - bitmask, so_request.grammar.vocab_size, 0) return jump_tokens def clear_backend(self) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b9b4ce4d19ac..e3cf49651db1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import gc import time @@ -35,7 +36,8 @@ KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) + ModelRunnerOutput, + ModelRunnerStructuredOutputMetadata) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -282,7 +284,7 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + def _update_states(self, scheduler_output: SchedulerOutput) -> None: """Update the cached states and the persistent batch with the scheduler output. 
@@ -487,7 +489,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, - scheduler_output: "SchedulerOutput", + scheduler_output: SchedulerOutput, ) -> tuple[FlashAttentionMetadata, torch.Tensor, Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -714,7 +716,7 @@ def _compute_cascade_attn_prefix_len( ) return common_prefix_len if use_cascade else 0 - def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + def _calc_mrope_positions(self, scheduler_output: SchedulerOutput): mrope_pos_ptr = 0 for index, req_id in enumerate(self.input_batch.req_ids): req = self.requests[req_id] @@ -841,7 +843,7 @@ def _calc_spec_decode_metadata( ) return metadata - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + def _execute_mm_encoder(self, scheduler_output: SchedulerOutput): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return @@ -905,7 +907,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): def _gather_mm_embeddings( self, - scheduler_output: "SchedulerOutput", + scheduler_output: SchedulerOutput, ) -> list[torch.Tensor]: mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: @@ -954,14 +956,15 @@ def get_model(self) -> nn.Module: def apply_grammar_bitmask( self, - scheduler_output: "SchedulerOutput", + scheduler_output: SchedulerOutput, logits: torch.Tensor, - ): + ) -> tuple[np.ndarray | None, dict[str, int]]: # Serialization of np.ndarray is much more efficient than a tensor, # so we receive it in that format. grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is None: - return + # Should not happen if called correctly, but return empty for safety + return None, {} # We receive the structured output bitmask from the scheduler, but the # indices of the requests in the batch may not match the indices of @@ -989,20 +992,26 @@ def apply_grammar_bitmask( req_id] sorted_bitmask[batch_index] = grammar_bitmask[orig_index] grammar_bitmask = sorted_bitmask + # Keep a reference before converting to tensor, as the original numpy + # array might be needed for return + grammar_bitmask_np = grammar_bitmask - grammar_bitmask = torch.from_numpy(grammar_bitmask) + grammar_bitmask_tensor = torch.from_numpy(grammar_bitmask).to( + self.device, non_blocking=True) # TODO: compatibility with spec decode xgr.apply_token_bitmask_inplace( logits, - grammar_bitmask.to(self.device, non_blocking=True), + grammar_bitmask_tensor, indices=list(struct_out_req_batch_indices.values()), ) + # Return the potentially reordered numpy array and the index mapping + return grammar_bitmask_np, struct_out_req_batch_indices @torch.inference_mode() def execute_model( self, - scheduler_output: "SchedulerOutput", + scheduler_output: SchedulerOutput, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: # Update KVConnector with the KVConnector metadata forward(). 
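For context: the jump-forward path introduced in this revision decides whether a decode position is fully constrained by asking xgrammar's `_is_single_token_bitmask` helper whether exactly one token survives the grammar bitmask row. A rough, self-contained equivalent over the packed int32 mask is sketched below; the bit-packing layout and the helper name `single_allowed_token` are assumptions for illustration, not code from this patch.

    import numpy as np

    def single_allowed_token(bitmask: np.ndarray, vocab_size: int,
                             batch_index: int) -> tuple[bool, int]:
        # Assumed layout: one int32 word covers 32 tokens, bit i of word w
        # corresponding to token id w * 32 + i; a set bit means "allowed".
        row = bitmask[batch_index].view(np.uint32)
        allowed = -1
        for word_idx, word in enumerate(row):
            w = int(word)
            while w:
                bit = (w & -w).bit_length() - 1   # index of lowest set bit
                token_id = word_idx * 32 + bit
                if token_id < vocab_size:
                    if allowed != -1:
                        return False, -1          # more than one candidate
                    allowed = token_id
                w &= w - 1                        # clear that bit
        return allowed != -1, allowed

When exactly one candidate survives, the scheduler can extend the request with that forced token directly, which is what the bitmask and batch-index plumbing added in this patch exists to support.
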
@@ -1102,8 +1111,13 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Apply structured output bitmasks if present + grammar_bitmask_np: np.ndarray | None = None + struct_out_req_batch_indices: dict[str, int] | None = None if scheduler_output.grammar_bitmask is not None: - self.apply_grammar_bitmask(scheduler_output, logits) + grammar_bitmask_np, struct_out_req_batch_indices = self.apply_grammar_bitmask( # noqa: E501 + scheduler_output, + logits, + ) # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata @@ -1277,6 +1291,10 @@ def execute_model( spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + structured_output_metadata=ModelRunnerStructuredOutputMetadata( + grammar_bitmask=grammar_bitmask_np, + struct_out_req_batch_indices=struct_out_req_batch_indices, + ), ) def generate_draft_token_ids( @@ -1343,7 +1361,7 @@ def load_model(self) -> None: def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, - scheduler_output: "SchedulerOutput", + scheduler_output: SchedulerOutput, ) -> dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 98b0ddcccb5d..0b2650742dfd 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -34,7 +34,8 @@ KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) + ModelRunnerOutput, + ModelRunnerStructuredOutputMetadata) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache @@ -784,6 +785,19 @@ def execute_model( if scheduler_output.grammar_bitmask is not None: require_struct_decoding, grammar_bitmask_padded, arange = \ self.prepare_structured_decoding_input(logits, scheduler_output) + # Reconstruct the reordered numpy bitmask and index map for output + # We can reuse the grammar_bitmask_cpu buffer populated above. 
+ num_reqs = logits.shape[0] # Get actual number of reqs + grammar_bitmask_np = self.grammar_bitmask_cpu[:num_reqs].numpy() + # Calculate the index map needed for the output + struct_out_req_batch_indices = {} + for req_id in self.input_batch.req_ids[:num_reqs]: + mask_idx = scheduler_output.structured_output_request_ids.get( + req_id) + if mask_idx is None: + continue + batch_idx = self.input_batch.req_id_to_index[req_id] + struct_out_req_batch_indices[req_id] = batch_idx logits = self.structured_decode(require_struct_decoding, grammar_bitmask_padded, logits, arange) @@ -862,6 +876,10 @@ def execute_model( spec_token_ids=None, logprobs=None, prompt_logprobs_dict=prompt_logprobs_dict, + structured_output_metadata=ModelRunnerStructuredOutputMetadata( + grammar_bitmask=grammar_bitmask_np, + struct_out_req_batch_indices=struct_out_req_batch_indices, + ), ) # Check there are no new graphs compiled - all the graphs should be From 7c41ce089761b3803a6c2f9d60f853193916ff03 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 16:51:52 +0000 Subject: [PATCH 06/17] chore: remove debug print Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 25 +++++++++++++------------ vllm/v1/outputs.py | 7 +++---- vllm/v1/structured_output/__init__.py | 20 +++++++++++++------- vllm/v1/worker/gpu_model_runner.py | 20 +++++++++----------- 4 files changed, 38 insertions(+), 34 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5b3c3fc5e621..d3f24fe7177f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -5,7 +5,7 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.factory import ( @@ -625,7 +625,7 @@ def update_from_output( spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict - structured_output_metadata = model_runner_output.structured_output_metadata # noqa: E501 + so_metadata = model_runner_output.structured_output_metadata # noqa: E501 num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] @@ -705,22 +705,23 @@ def update_from_output( # the outer lists can be of length > 1. 
new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output and ( - (grammar_bitmask := - structured_output_metadata['grammar_bitmask']) is not None - ) and ((struct_out_req_batch_indices := - structured_output_metadata['struct_out_req_batch_indices']) - is not None) and req_id in struct_out_req_batch_indices: + grammar_bitmask = so_metadata['grammar_bitmask'] + so_req_batch_indices = so_metadata['struct_out_req_batch_indices'] + if new_token_ids and request.use_structured_output and grammar_bitmask is not None and so_req_batch_indices is not None and req_id in so_req_batch_indices: # noqa: E501 + if TYPE_CHECKING: + assert request.structured_output_request is not None + assert request.structured_output_request.grammar is not None jump_tokens = self.structured_output_manager.jump_forward_tokens( # noqa: E501 request, bitmask=grammar_bitmask, - batch_index=struct_out_req_batch_indices[req_id], + batch_index=so_req_batch_indices[req_id], ) if jump_tokens: - print(jump_tokens) new_token_ids.extend(jump_tokens) - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + request.structured_output_request.grammar.accept_tokens( + req_id, + new_token_ids, + ) # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index a99aa359f62f..e2437b635498 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - from dataclasses import dataclass from typing import NamedTuple, Optional, TypedDict import numpy as np +import numpy.typing as npt import torch @@ -44,7 +43,7 @@ def tolists(self): @staticmethod def empty_cpu(num_positions: int, - num_tokens_per_position: int) -> LogprobsTensors: + num_tokens_per_position: int) -> "LogprobsTensors": """Create empty LogprobsTensors on CPU.""" logprob_token_ids = torch.empty( @@ -74,7 +73,7 @@ class SamplerOutput: class ModelRunnerStructuredOutputMetadata(TypedDict): - grammar_bitmask: Optional[np.ndarray] + grammar_bitmask: Optional[npt.NDArray[np.int32]] struct_out_req_batch_indices: Optional[dict[str, int]] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index a8c0578c4893..091a1aff53d1 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -126,25 +126,31 @@ def grammar_bitmask( # and deserialization when sending this to the GPU workers. return bitmask_tensor.numpy() - def jump_forward_tokens(self, request: Request, bitmask: np.ndarray, - batch_index: int) -> list[int]: + def jump_forward_tokens( + self, + request: Request, + bitmask: npt.NDArray[np.int32], + batch_index: int, + ) -> list[int] | None: """ For structured output requests, repeatedly check if the grammar bitmask is a single-token bitmask, and if so, advance the FSM and collect all jump-forward tokens. Returns the list of jump-forward token IDs. + + We can also consider to perform jump_and_retokenize here as well. 
""" if TYPE_CHECKING: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None assert self.backend is not None - jump_tokens = [] is_single, unique_token_id = xgr_testing._is_single_token_bitmask( - torch.from_numpy(bitmask), self.backend.vocab_size, batch_index) - if is_single: - jump_tokens.append(unique_token_id) - return jump_tokens + torch.from_numpy(bitmask), + vocab_size=self.backend.vocab_size, + index=batch_index, + ) + return [unique_token_id] if is_single else None def clear_backend(self) -> None: if self.backend is not None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e3cf49651db1..a54861a3992f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - import gc import time import weakref @@ -284,7 +282,7 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: SchedulerOutput) -> None: + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -489,7 +487,7 @@ def _update_states(self, scheduler_output: SchedulerOutput) -> None: def _prepare_inputs( self, - scheduler_output: SchedulerOutput, + scheduler_output: "SchedulerOutput", ) -> tuple[FlashAttentionMetadata, torch.Tensor, Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -716,7 +714,7 @@ def _compute_cascade_attn_prefix_len( ) return common_prefix_len if use_cascade else 0 - def _calc_mrope_positions(self, scheduler_output: SchedulerOutput): + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 for index, req_id in enumerate(self.input_batch.req_ids): req = self.requests[req_id] @@ -843,7 +841,7 @@ def _calc_spec_decode_metadata( ) return metadata - def _execute_mm_encoder(self, scheduler_output: SchedulerOutput): + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return @@ -907,7 +905,7 @@ def _execute_mm_encoder(self, scheduler_output: SchedulerOutput): def _gather_mm_embeddings( self, - scheduler_output: SchedulerOutput, + scheduler_output: "SchedulerOutput", ) -> list[torch.Tensor]: mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: @@ -956,9 +954,9 @@ def get_model(self) -> nn.Module: def apply_grammar_bitmask( self, - scheduler_output: SchedulerOutput, + scheduler_output: "SchedulerOutput", logits: torch.Tensor, - ) -> tuple[np.ndarray | None, dict[str, int]]: + ) -> tuple[Optional[np.ndarray], dict[str, int]]: # Serialization of np.ndarray is much more efficient than a tensor, # so we receive it in that format. 
grammar_bitmask = scheduler_output.grammar_bitmask @@ -1111,8 +1109,8 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Apply structured output bitmasks if present - grammar_bitmask_np: np.ndarray | None = None - struct_out_req_batch_indices: dict[str, int] | None = None + grammar_bitmask_np: Optional[np.ndarray] = None + struct_out_req_batch_indices: Optional[dict[str, int]] = None if scheduler_output.grammar_bitmask is not None: grammar_bitmask_np, struct_out_req_batch_indices = self.apply_grammar_bitmask( # noqa: E501 scheduler_output, From 812d684c525c810b7bbbae21657ae85fd311e1c0 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sun, 27 Apr 2025 17:43:33 +0000 Subject: [PATCH 07/17] revert: use scheduler_output bitmask Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 14 +++++++----- vllm/v1/outputs.py | 14 +----------- vllm/v1/worker/gpu_model_runner.py | 34 ++++++++---------------------- vllm/v1/worker/tpu_model_runner.py | 20 +----------------- 4 files changed, 20 insertions(+), 62 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7ac0bc4e7338..9db7a41b2136 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -631,7 +631,6 @@ def update_from_output( spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict - so_metadata = model_runner_output.structured_output_metadata # noqa: E501 num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] @@ -711,18 +710,23 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - grammar_bitmask = so_metadata['grammar_bitmask'] - so_req_batch_indices = so_metadata['struct_out_req_batch_indices'] - if new_token_ids and request.use_structured_output and grammar_bitmask is not None and so_req_batch_indices is not None and req_id in so_req_batch_indices: # noqa: E501 + if new_token_ids and request.use_structured_output: + grammar_bitmask = scheduler_output.grammar_bitmask + batch_index = scheduler_output.structured_output_request_ids.get( # noqa: E501 + req_id, + 0, + ) if TYPE_CHECKING: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None + assert grammar_bitmask is not None jump_tokens = self.structured_output_manager.jump_forward_tokens( # noqa: E501 request, bitmask=grammar_bitmask, - batch_index=so_req_batch_indices[req_id], + batch_index=batch_index, ) if jump_tokens: + print(jump_tokens) new_token_ids.extend(jump_tokens) request.structured_output_request.grammar.accept_tokens( req_id, diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index e2437b635498..2732b933c28a 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import NamedTuple, Optional, TypedDict +from typing import NamedTuple, Optional -import numpy as np -import numpy.typing as npt import torch @@ -72,11 +70,6 @@ class SamplerOutput: logprobs_tensors: Optional[LogprobsTensors] -class ModelRunnerStructuredOutputMetadata(TypedDict): - grammar_bitmask: Optional[npt.NDArray[np.int32]] - struct_out_req_batch_indices: Optional[dict[str, int]] - - # ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use list instead. 
@dataclass @@ -106,7 +99,6 @@ class ModelRunnerOutput: # [prompt_len, num_prompt_logprobs] # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] - structured_output_metadata: ModelRunnerStructuredOutputMetadata EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( @@ -116,8 +108,4 @@ class ModelRunnerOutput: spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, - structured_output_metadata=ModelRunnerStructuredOutputMetadata( - grammar_bitmask=None, - struct_out_req_batch_indices=None, - ), ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5d2bdc2f39e9..e3d8b94fe9d7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 + import gc import time import weakref @@ -34,8 +35,7 @@ KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput, - ModelRunnerStructuredOutputMetadata) + ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -956,13 +956,12 @@ def apply_grammar_bitmask( self, scheduler_output: "SchedulerOutput", logits: torch.Tensor, - ) -> tuple[Optional[np.ndarray], dict[str, int]]: + ): # Serialization of np.ndarray is much more efficient than a tensor, # so we receive it in that format. grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is None: - # Should not happen if called correctly, but return empty for safety - return None, {} + return # We receive the structured output bitmask from the scheduler, but the # indices of the requests in the batch may not match the indices of @@ -990,26 +989,20 @@ def apply_grammar_bitmask( req_id] sorted_bitmask[batch_index] = grammar_bitmask[orig_index] grammar_bitmask = sorted_bitmask - # Keep a reference before converting to tensor, as the original numpy - # array might be needed for return - grammar_bitmask_np = grammar_bitmask - grammar_bitmask_tensor = torch.from_numpy(grammar_bitmask).to( - self.device, non_blocking=True) + grammar_bitmask = torch.from_numpy(grammar_bitmask) # TODO: compatibility with spec decode xgr.apply_token_bitmask_inplace( logits, - grammar_bitmask_tensor, + grammar_bitmask.to(self.device, non_blocking=True), indices=list(struct_out_req_batch_indices.values()), ) - # Return the potentially reordered numpy array and the index mapping - return grammar_bitmask_np, struct_out_req_batch_indices @torch.inference_mode() def execute_model( self, - scheduler_output: SchedulerOutput, + scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: # Update KVConnector with the KVConnector metadata forward(). @@ -1117,13 +1110,8 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Apply structured output bitmasks if present - grammar_bitmask_np: Optional[np.ndarray] = None - struct_out_req_batch_indices: Optional[dict[str, int]] = None if scheduler_output.grammar_bitmask is not None: - grammar_bitmask_np, struct_out_req_batch_indices = self.apply_grammar_bitmask( # noqa: E501 - scheduler_output, - logits, - ) + self.apply_grammar_bitmask(scheduler_output, logits) # Sample the next token and get logprobs if needed. 
sampling_metadata = self.input_batch.sampling_metadata @@ -1297,10 +1285,6 @@ def execute_model( spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, - structured_output_metadata=ModelRunnerStructuredOutputMetadata( - grammar_bitmask=grammar_bitmask_np, - struct_out_req_batch_indices=struct_out_req_batch_indices, - ), ) def generate_draft_token_ids( @@ -1367,7 +1351,7 @@ def load_model(self) -> None: def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, - scheduler_output: SchedulerOutput, + scheduler_output: "SchedulerOutput", ) -> dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e3d48860b84d..67f8af29db0e 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -34,8 +34,7 @@ KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput, - ModelRunnerStructuredOutputMetadata) + ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache @@ -784,19 +783,6 @@ def execute_model( if scheduler_output.grammar_bitmask is not None: require_struct_decoding, grammar_bitmask_padded, arange = \ self.prepare_structured_decoding_input(logits, scheduler_output) - # Reconstruct the reordered numpy bitmask and index map for output - # We can reuse the grammar_bitmask_cpu buffer populated above. - num_reqs = logits.shape[0] # Get actual number of reqs - grammar_bitmask_np = self.grammar_bitmask_cpu[:num_reqs].numpy() - # Calculate the index map needed for the output - struct_out_req_batch_indices = {} - for req_id in self.input_batch.req_ids[:num_reqs]: - mask_idx = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_idx is None: - continue - batch_idx = self.input_batch.req_id_to_index[req_id] - struct_out_req_batch_indices[req_id] = batch_idx logits = self.structured_decode(require_struct_decoding, grammar_bitmask_padded, logits, arange) @@ -875,10 +861,6 @@ def execute_model( spec_token_ids=None, logprobs=None, prompt_logprobs_dict=prompt_logprobs_dict, - structured_output_metadata=ModelRunnerStructuredOutputMetadata( - grammar_bitmask=grammar_bitmask_np, - struct_out_req_batch_indices=struct_out_req_batch_indices, - ), ) # Check there are no new graphs compiled - all the graphs should be From bf5c46ca4ab53ecbf70a3cf409e305d0d7146f9a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sun, 27 Apr 2025 17:55:37 +0000 Subject: [PATCH 08/17] chore: remove debug print Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9db7a41b2136..b580f0ce8b86 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -726,7 +726,6 @@ def update_from_output( batch_index=batch_index, ) if jump_tokens: - print(jump_tokens) new_token_ids.extend(jump_tokens) request.structured_output_request.grammar.accept_tokens( req_id, From a1ae3ac6e9ff3da32e802f12b6dcec686a767260 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 28 Apr 2025 15:49:33 +0000 Subject: [PATCH 09/17] chore: move tokenizer to __init__ Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 14 
+++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 091a1aff53d1..a2505a4e86ca 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -42,6 +42,11 @@ def __init__(self, vllm_config: VllmConfig): # compilation, so we set it to half the number of CPUs. max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.tokenizer = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: @@ -53,22 +58,17 @@ def grammar_init(self, request: Request) -> None: # backends on a per-request basis in V1 (for now, anyway...). if self.backend is None: backend_name = request.sampling_params.guided_decoding.backend_name - tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config, - scheduler_config=self.vllm_config.scheduler_config, - lora_config=self.vllm_config.lora_config, - ).get_lora_tokenizer(None) vocab_size = self.vllm_config.model_config.get_vocab_size() if backend_name == "xgrammar": self.backend = XgrammarBackend( self.vllm_config, - tokenizer=tokenizer, + tokenizer=self.tokenizer, vocab_size=vocab_size, ) elif backend_name == "guidance": self.backend = GuidanceBackend( self.vllm_config, - tokenizer=tokenizer, + tokenizer=self.tokenizer, vocab_size=vocab_size, ) else: From 535e06ecce8d19124f6044837bf107d976d58c88 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 28 Apr 2025 15:50:29 +0000 Subject: [PATCH 10/17] --wip retokenize-- Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 4 ++++ vllm/v1/structured_output/backend_guidance.py | 3 +++ vllm/v1/structured_output/backend_types.py | 4 ++++ vllm/v1/structured_output/backend_xgrammar.py | 3 +++ 4 files changed, 14 insertions(+) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index a2505a4e86ca..61594869836f 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -150,6 +150,10 @@ def jump_forward_tokens( vocab_size=self.backend.vocab_size, index=batch_index, ) + s = request.structured_output_request.grammar.find_jf_string() + if s: + jf_tokens = self.tokenizer.encode(s, add_special_tokens=False) + print(jf_tokens) return [unique_token_id] if is_single else None def clear_backend(self) -> None: diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 291a6bdf3ee2..a17c7b8e1d0b 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -158,6 +158,9 @@ def reset(self): # This method may be not needed anymore? TODO self.ll_matcher.reset() + def find_jf_string(self): + pass + def serialize_guidance_grammar( request_type: StructuredOutputOptions, diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 163df4afff7b..b196b4791606 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -68,6 +68,10 @@ def reset(self): Resets the state of the structured output grammar. """ + @abstractmethod + def find_jf_string(self): + ... 
+ @dataclass class StructuredOutputBackend(ABC): diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index c9cb20317bd0..1161d745a6a5 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -174,6 +174,9 @@ def reset(self): self.num_processed_tokens = 0 self.matcher.reset() + def find_jf_string(self): + return self.matcher.find_jump_forward_string() + def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: """Check if JSON schema contains features unsupported by xgrammar.""" From 511db485720bf9229a525d0fc266ca284a8d8954 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 10:56:30 +0000 Subject: [PATCH 11/17] feat: jump_and_retokenize Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 47 ++++++----- vllm/v1/structured_output/__init__.py | 84 +++++++++++++------ vllm/v1/structured_output/backend_guidance.py | 9 +- vllm/v1/structured_output/backend_types.py | 22 ++++- vllm/v1/structured_output/backend_xgrammar.py | 20 ++++- 5 files changed, 125 insertions(+), 57 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b580f0ce8b86..d2fd6b56679d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -690,6 +690,31 @@ def update_from_output( new_logprobs = None new_token_ids = generated_token_ids + # NOTE: We will need to first advance the FSM + # given that we apply bitmask in first pass + # and we only perform jump-forward posteriori. + initial_advancement = True + if new_token_ids and request.use_structured_output: + so_request = request.structured_output_request + if TYPE_CHECKING: + assert so_request is not None + assert so_request.grammar is not None + so_request.grammar.accept_tokens(request.request_id, + new_token_ids) + initial_advancement = False + + jump_tokens: list[int] | None = None + if initial_advancement and new_token_ids and request.use_structured_output: # noqa: E501 + # NOTE: We are performing retokenization to handle + # tokenizer boundary edge cases. There will be some + # general overhead incur here. Note that we already + # handle the state of the grammar within + # jump_forward_tokens. + jump_tokens = self.structured_output_manager.jump_forward_tokens( # noqa: E501 + request) + if jump_tokens: + new_token_ids += jump_tokens + # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner # to return empty token ids for the request. @@ -710,28 +735,6 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output: - grammar_bitmask = scheduler_output.grammar_bitmask - batch_index = scheduler_output.structured_output_request_ids.get( # noqa: E501 - req_id, - 0, - ) - if TYPE_CHECKING: - assert request.structured_output_request is not None - assert request.structured_output_request.grammar is not None - assert grammar_bitmask is not None - jump_tokens = self.structured_output_manager.jump_forward_tokens( # noqa: E501 - request, - bitmask=grammar_bitmask, - batch_index=batch_index, - ) - if jump_tokens: - new_token_ids.extend(jump_tokens) - request.structured_output_request.grammar.accept_tokens( - req_id, - new_token_ids, - ) - # Get prompt logprobs for this request. 
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids:
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 61594869836f..b3527d295677 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+import itertools
 import multiprocessing
 from concurrent.futures import ThreadPoolExecutor
 from typing import TYPE_CHECKING, Optional
@@ -18,12 +19,10 @@
     import numpy as np
     import numpy.typing as npt
     import torch
-    import xgrammar.testing as xgr_testing
 
     from vllm.v1.request import Request
 else:
     torch = LazyLoader('torch', globals(), 'torch')
-    xgr_testing = LazyLoader('xgr_testing', globals(), 'xgrammar.testing')
 
 logger = init_logger(__name__)
 
@@ -126,35 +125,68 @@ def grammar_bitmask(
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()
 
-    def jump_forward_tokens(
-        self,
-        request: Request,
-        bitmask: npt.NDArray[np.int32],
-        batch_index: int,
-    ) -> list[int] | None:
+    def jump_forward_tokens(self, request: Request) -> list[int] | None:
         """
-        For structured output requests, repeatedly
-        check if the grammar bitmask is a single-token bitmask, and if so,
-        advance the FSM and collect all jump-forward tokens.
-        Returns the list of jump-forward token IDs.
-
-        We can also consider to perform jump_and_retokenize here as well.
+        For structured output requests, compute the grammar's jump-forward
+        string and retokenize it to handle possible tokenization divergence.
         """
+        so_request = request.structured_output_request
         if TYPE_CHECKING:
-            assert request.structured_output_request is not None
-            assert request.structured_output_request.grammar is not None
+            assert so_request is not None
+            assert so_request.grammar is not None
             assert self.backend is not None
 
-        is_single, unique_token_id = xgr_testing._is_single_token_bitmask(
-            torch.from_numpy(bitmask),
-            vocab_size=self.backend.vocab_size,
-            index=batch_index,
-        )
-        s = request.structured_output_request.grammar.find_jf_string()
-        if s:
-            jf_tokens = self.tokenizer.encode(s, add_special_tokens=False)
-            print(jf_tokens)
-        return [unique_token_id] if is_single else None
+        jf_string = so_request.grammar.find_jump_string()
+        if not jf_string:
+            return None
+
+        original_output_ids = list(request.output_token_ids)
+
+        # NOTE: max_rollback_window determines how many of the trailing
+        # tokens from all_token_ids are used for retokenization.
+        # We don't need the whole token_ids sequence,
+        # for performance reasons (the tokenizer call is blocking).
+        # TODO: handle token fusion
+        # max_rollback_window = 10
+
+        current_text_str = self.tokenizer.decode(request.all_token_ids)
+        all_text = current_text_str + jf_string
+        combined_all_token_ids = self.tokenizer.encode(
+            all_text, add_special_tokens=False)
+        retokenized_output_ids = combined_all_token_ids[request.
+                                                        num_prompt_tokens:]
+
+        # Find the prefix match length
+        k = sum(1 for _ in itertools.takewhile(
+            lambda pair: pair[0] == pair[1],
+            zip(original_output_ids, retokenized_output_ids)))
+        num_original_suffix = len(original_output_ids) - k
+        retokenized_suffix = retokenized_output_ids[k:]
+        if num_original_suffix > 0:
+            so_request.grammar.rollback(num_original_suffix)
+
+        # Validate tokens one by one
+        accepted_tokens: list[int] = []
+        num_validated_in_suffix = 0
+        validation_ok = True
+        for token in retokenized_suffix:
+            if so_request.grammar.accept_tokens(request.request_id, [token]):
+                accepted_tokens.append(token)
+                num_validated_in_suffix += 1
+            else:
+                if num_validated_in_suffix > 0:
+                    so_request.grammar.rollback(num_validated_in_suffix)
+                validation_ok = False
+                break
+
+        if validation_ok:
+            return accepted_tokens
+
+        original_suffix_tokens = original_output_ids[num_validated_in_suffix:]
+        if original_suffix_tokens and not so_request.grammar.accept_tokens(
+                request.request_id, original_suffix_tokens):
+            so_request.grammar.rollback(len(original_suffix_tokens))
+        return None
 
     def clear_backend(self) -> None:
         if self.backend is not None:
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
index a17c7b8e1d0b..a2ecfaa6b10b 100644
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -158,8 +158,13 @@ def reset(self):
         # This method may be not needed anymore? TODO
         self.ll_matcher.reset()
 
-    def find_jf_string(self):
-        pass
+    def find_jump_string(self) -> str | None:
+        ff_string = self.ll_matcher.compute_ff_bytes()
+        return ff_string.decode() if ff_string else None
+
+    def rollback(self, num_tokens: int) -> None:
+        self.ll_matcher.rollback(num_tokens)
+        self.check_error()
 
 
 def serialize_guidance_grammar(
diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py
index b196b4791606..1afee9cceb9b 100644
--- a/vllm/v1/structured_output/backend_types.py
+++ b/vllm/v1/structured_output/backend_types.py
@@ -29,6 +29,15 @@ class StructuredOutputOptions(enum.Enum):
 class StructuredOutputGrammar(ABC):
     """Request-level backend for structured output requests."""
 
+    @abstractmethod
+    def find_jump_string(self) -> str | None:
+        """
+        Find the jump-forward string based on the current grammar state.
+
+        Returns:
+            Optional str: the jump-forward string, or None if there is none.
+        """
+
     @abstractmethod
     def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
         """
@@ -43,6 +52,15 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
         bool: True if the tokens are accepted, False otherwise.
         """
 
+    @abstractmethod
+    def rollback(self, num_tokens: int) -> None:
+        """
+        Rolls back the state of the grammar by a specified number of tokens.
+        Will also revert counters for the number of processed tokens.
+        Args:
+            num_tokens (int): The number of tokens to roll back.
+        """
+
     @abstractmethod
     def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
         """
@@ -68,10 +86,6 @@ def reset(self):
         Resets the state of the structured output grammar.
         """
 
-    @abstractmethod
-    def find_jf_string(self):
-        ...
- @dataclass class StructuredOutputBackend(ABC): diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 1161d745a6a5..70dde8aef9a6 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -88,6 +88,12 @@ def __post_init__(self): cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024, ) + self.num_speculative_tokens = \ + self.vllm_config.scheduler_config.max_num_seqs + if self.vllm_config.speculative_config is not None: + self.num_speculative_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens + def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: @@ -119,7 +125,10 @@ def compile_grammar(self, request_type: StructuredOutputOptions, f"grammar is not of valid supported types. ({request_type!s})") return XgrammarGrammar( - matcher=xgr.GrammarMatcher(ctx), + matcher=xgr.GrammarMatcher( + ctx, + max_rollback_tokens=self.num_speculative_tokens, + ), vocab_size=self.vocab_size, ctx=ctx, ) @@ -174,8 +183,13 @@ def reset(self): self.num_processed_tokens = 0 self.matcher.reset() - def find_jf_string(self): - return self.matcher.find_jump_forward_string() + def find_jump_string(self) -> str | None: + jf_string = self.matcher.find_jump_forward_string() + return jf_string if jf_string else None + + def rollback(self, num_tokens: int) -> None: + self.matcher.rollback(num_tokens) + self.num_processed_tokens -= num_tokens def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: From ba6d499ecd6bd93f48b9e1144915a30a0991c21c Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 11:16:02 +0000 Subject: [PATCH 12/17] fix: set default rollback to 0 Signed-off-by: Aaron Pham --- vllm/v1/structured_output/backend_xgrammar.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 70dde8aef9a6..f644c82a1684 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -88,8 +88,7 @@ def __post_init__(self): cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024, ) - self.num_speculative_tokens = \ - self.vllm_config.scheduler_config.max_num_seqs + self.num_speculative_tokens = 0 if self.vllm_config.speculative_config is not None: self.num_speculative_tokens = \ self.vllm_config.speculative_config.num_speculative_tokens From ffb0324a0dd81761283765f4176e021a574d2c44 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 11:32:47 +0000 Subject: [PATCH 13/17] chore: implement static max_rollback_window Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 19 ++++++++----------- vllm/v1/structured_output/__init__.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d2fd6b56679d..7ee497397529 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -703,17 +703,14 @@ def update_from_output( new_token_ids) initial_advancement = False - jump_tokens: list[int] | None = None - if initial_advancement and new_token_ids and request.use_structured_output: # noqa: E501 - # NOTE: We are performing retokenization to handle - # tokenizer boundary edge cases. There will be some - # general overhead incur here. 
Note that we already - # handle the state of the grammar within - # jump_forward_tokens. - jump_tokens = self.structured_output_manager.jump_forward_tokens( # noqa: E501 - request) - if jump_tokens: - new_token_ids += jump_tokens + # NOTE: We are performing retokenization to handle + # tokenizer boundary. There will be some + # overhead here. + if initial_advancement and new_token_ids and request.use_structured_output and ( # noqa: E501 + jump_tokens := + self.structured_output_manager.jump_forward_tokens(request) + ): + new_token_ids += jump_tokens # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index b3527d295677..a8b8effd6042 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -146,15 +146,17 @@ def jump_forward_tokens(self, request: Request) -> list[int] | None: # of the tokenes from all_token_ids to be used for retokenization. # Note that we don't need to whole token_ids # for performance reason (tokenizer is blocking) - # TODO: handle token fusion - # max_rollback_window = 10 + max_rollback_window = 10 - current_text_str = self.tokenizer.decode(request.all_token_ids) + current_text_str = self.tokenizer.decode( + request.all_token_ids[-max_rollback_window:]) all_text = current_text_str + jf_string - combined_all_token_ids = self.tokenizer.encode( + retokenized_output_ids = self.tokenizer.encode( all_text, add_special_tokens=False) - retokenized_output_ids = combined_all_token_ids[request. - num_prompt_tokens:] + if request.prompt_token_ids[-1] in retokenized_output_ids: + retokenized_output_ids = retokenized_output_ids[ + retokenized_output_ids.index(request.prompt_token_ids[-1]) + + 1:] # Find the prefix match length k = sum(1 for _ in itertools.takewhile( From a7c80705e130d3054573d5696675d664732d1fdf Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 17:51:45 +0000 Subject: [PATCH 14/17] chore: add a mock test case --wip-- Signed-off-by: Aaron Pham --- tests/v1/core/test_scheduler.py | 84 +++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 12 ++-- vllm/v1/structured_output/__init__.py | 25 ++++---- 3 files changed, 105 insertions(+), 16 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index ee4e95856f23..05b6bfeaea97 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1165,3 +1165,87 @@ def test_kv_connector_handles_preemption(): # All memory should be freed since nothing is running. assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ == NUM_BLOCKS - 1 + + +def test_scheduler_jump_forward(): + scheduler = create_scheduler() + so_manager = scheduler.structured_output_manager + + # Mock the tokenizer used by StructuredOutputManager + mock_tokenizer = Mock() + # Example tokenization behavior: + # decode([10, 20]) -> "ab" + # encode("abc") -> [10, 25] (Simulates 'c' being token 25) + # encode("ab" + "c") -> [10, 25] + mock_tokenizer.decode.side_effect = lambda ids: { + tuple([10, 20]): "ab", + # Add more cases as needed + }.get(tuple(ids), "decode_fallback") + mock_tokenizer.encode.side_effect = lambda text, add_special_tokens: { + "abc": [10, 25], + }.get(text, [999]) + + so_manager.tokenizer = mock_tokenizer + + # 2. 
Create a request using structured output + request = create_requests(num_requests=1)[0] + request.use_structured_output = True + request.structured_output_request = Mock() + request.structured_output_request.structured_output_key = ("json", "{}" + ) # Dummy key + + # Mock the grammar object + mock_grammar = Mock() + mock_grammar.find_jump_string.return_value = "c" # The jump string + mock_grammar.accept_tokens.return_value = True # Assume validation passes for now + mock_grammar.is_terminated.return_value = False + mock_grammar.rollback = Mock() # To track calls + + request.structured_output_request.grammar = mock_grammar + + # 3. Simulate scheduling and initial output + scheduler.add_request(request) + output = scheduler.schedule() # Schedule the prompt + + # Simulate model outputting initial tokens (e.g., [10, 20] for "ab") + initial_output_tokens = [10, 20] + model_runner_output = ModelRunnerOutput( + req_ids=[request.request_id], + req_id_to_index={request.request_id: 0}, + sampled_token_ids=[initial_output_tokens], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # 4. Call update_from_output - this triggers jump_forward_tokens + # First, manually update computed tokens as schedule() does + request.num_computed_tokens += output.num_scheduled_tokens[ + request.request_id] + # Update output_token_ids with initial output before jump-forward call + request.append_output_token_ids(initial_output_tokens) + + # Now manually call jump_forward_tokens (since it's complex to mock the whole update flow) + # In reality, update_from_output calls this internally + jump_tokens = so_manager.jump_forward_tokens(request) + + # 5. Assertions + # Based on mock tokenizer: decode([10, 20]) -> "ab" + # jf_string = "c" + # text = "ab" + "c" = "abc" + # encode("abc") -> [10, 25] + # original_output_ids = [10, 20] + # retokenized_output_ids = [10, 25] + # k = 1 (common prefix is [10]) + # num_original_suffix = 1 ([20]) + # retokenized_suffix = [25] + # rollback(1) should be called + # accept_tokens([25]) should be called and return True + # expected jump_tokens = [25] + + assert jump_tokens == [ + 25 + ], f"Expected jump tokens [25], but got {jump_tokens}" + mock_grammar.rollback.assert_called_once_with(1) + mock_grammar.accept_tokens.assert_called_once_with(request.request_id, + [25]) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7ee497397529..5aaf83b3f684 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -693,20 +693,22 @@ def update_from_output( # NOTE: We will need to first advance the FSM # given that we apply bitmask in first pass # and we only perform jump-forward posteriori. - initial_advancement = True + first_pass = True if new_token_ids and request.use_structured_output: so_request = request.structured_output_request if TYPE_CHECKING: assert so_request is not None assert so_request.grammar is not None - so_request.grammar.accept_tokens(request.request_id, - new_token_ids) - initial_advancement = False + so_request.grammar.accept_tokens( + request.request_id, + new_token_ids, + ) + first_pass = False # NOTE: We are performing retokenization to handle # tokenizer boundary. There will be some # overhead here. 
- if initial_advancement and new_token_ids and request.use_structured_output and ( # noqa: E501 + if first_pass and new_token_ids and request.use_structured_output and ( # noqa: E501 jump_tokens := self.structured_output_manager.jump_forward_tokens(request) ): diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index a8b8effd6042..ac4e4f7baa1b 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -148,24 +148,25 @@ def jump_forward_tokens(self, request: Request) -> list[int] | None: # for performance reason (tokenizer is blocking) max_rollback_window = 10 - current_text_str = self.tokenizer.decode( + rollback_text_str = self.tokenizer.decode( request.all_token_ids[-max_rollback_window:]) - all_text = current_text_str + jf_string retokenized_output_ids = self.tokenizer.encode( - all_text, add_special_tokens=False) + rollback_text_str + jf_string, + add_special_tokens=False, + ) if request.prompt_token_ids[-1] in retokenized_output_ids: - retokenized_output_ids = retokenized_output_ids[ - retokenized_output_ids.index(request.prompt_token_ids[-1]) + - 1:] + prompt_boundary = retokenized_output_ids.index( + request.prompt_token_ids[-1]) + 1 + retokenized_output_ids = retokenized_output_ids[prompt_boundary:] # Find the prefix match length k = sum(1 for _ in itertools.takewhile( lambda pair: pair[0] == pair[1], - zip(original_output_ids, retokenized_output_ids))) - num_original_suffix = len(original_output_ids) - k + zip(original_output_ids, retokenized_output_ids), + )) retokenized_suffix = retokenized_output_ids[k:] - if num_original_suffix > 0: - so_request.grammar.rollback(num_original_suffix) + if k < len(original_output_ids): + so_request.grammar.rollback(len(original_output_ids) - k) # Validate tokens one by one accepted_tokens: list[int] = [] @@ -186,7 +187,9 @@ def jump_forward_tokens(self, request: Request) -> list[int] | None: original_suffix_tokens = original_output_ids[num_validated_in_suffix:] if original_suffix_tokens and not so_request.grammar.accept_tokens( - request.request_id, original_suffix_tokens): + request.request_id, + original_suffix_tokens, + ): so_request.grammar.rollback(len(original_suffix_tokens)) return None From 13b6c192045f519b9be9aed7e7741df16bc1fd3c Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 20:44:03 +0000 Subject: [PATCH 15/17] fix: align output_ids to correct retokenized windows Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index ac4e4f7baa1b..d1d1e1a7899b 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -159,6 +159,10 @@ def jump_forward_tokens(self, request: Request) -> list[int] | None: request.prompt_token_ids[-1]) + 1 retokenized_output_ids = retokenized_output_ids[prompt_boundary:] + original_output_ids = original_output_ids[ + max(0, + len(original_output_ids) - len(retokenized_output_ids)):] + # Find the prefix match length k = sum(1 for _ in itertools.takewhile( lambda pair: pair[0] == pair[1], From 93cd93f5d4d75a7f5b9da069824c7ab3faa34a00 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 15 May 2025 00:25:44 +0000 Subject: [PATCH 16/17] fix: revert bad merge Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 4 ++-- vllm/v1/structured_output/backend_xgrammar.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git 
a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 31f520cb1a6b..7f579d322d77 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -5,7 +5,7 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import Any, Optional +from typing import Any, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch @@ -844,7 +844,7 @@ def add_request(self, request: Request) -> None: def finish_requests( self, - request_ids: str | Iterable[str], + request_ids: Union[str, Iterable[str]], finished_status: RequestStatus, ) -> None: """Handles the finish signal from outside the scheduler. diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 2396ce590857..9d3b6433d723 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -49,10 +49,10 @@ def __post_init__(self): ) ] stop_token_ids = None - if hasattr( + if (hasattr( self.tokenizer, "eos_token_id", - ) and self.tokenizer.eos_token_id is not None: + ) and self.tokenizer.eos_token_id is not None): stop_token_ids = [self.tokenizer.eos_token_id] except AttributeError as e: raise ValueError( From f580263c2aa320995a5b94570831c74174d588b5 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 15 May 2025 00:28:01 +0000 Subject: [PATCH 17/17] revert: remove jump forward tests implementation for now Signed-off-by: Aaron Pham --- tests/v1/core/test_scheduler.py | 60 +-------------------------------- 1 file changed, 1 insertion(+), 59 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index d4e3cebd98cd..f40d477a0036 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -185,7 +185,7 @@ def test_get_num_unfinished_requests(): ]) def test_schedule(enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int]): - '''Test scheduling. + '''Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs ''' scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) @@ -1168,64 +1168,6 @@ def test_kv_connector_handles_preemption(): == NUM_BLOCKS - 1 -def test_scheduler_jump_forward(): - scheduler = create_scheduler() - so_manager = scheduler.structured_output_manager - - mock_tokenizer = Mock() - mock_tokenizer.decode.side_effect = lambda ids: { - tuple([10, 20]): "ab", - }.get(tuple(ids), "decode_fallback") - mock_tokenizer.encode.side_effect = lambda text, add_special_tokens: { - "abc": [10, 25], - }.get(text, [999]) - - so_manager.tokenizer = mock_tokenizer - - request = create_requests(num_requests=1)[0] - request.use_structured_output = True - request.structured_output_request = Mock() - request.structured_output_request.structured_output_key = ("json", "{}" - ) # Dummy key - - mock_grammar = Mock() - mock_grammar.find_jump_string.return_value = "c" - mock_grammar.accept_tokens.return_value = True - mock_grammar.is_terminated.return_value = False - mock_grammar.rollback = Mock() - - request.structured_output_request.grammar = mock_grammar - - # 3. 
Simulate scheduling and initial output - scheduler.add_request(request) - output = scheduler.schedule() # Schedule the prompt - - initial_output_tokens = [10, 20] - request.num_computed_tokens += output.num_scheduled_tokens[ - request.request_id] - request.append_output_token_ids(initial_output_tokens) - - jump_tokens = so_manager.jump_forward_tokens(request) - # jf_string = "c" - # text = "ab" + "c" = "abc" - # encode("abc") -> [10, 25] - # original_output_ids = [10, 20] - # retokenized_output_ids = [10, 25] - # k = 1 (common prefix is [10]) - # num_original_suffix = 1 ([20]) - # retokenized_suffix = [25] - # rollback(1) should be called - # accept_tokens([25]) should be called and return True - # expected jump_tokens = [25] - if jump_tokens != [25]: - pytest.fail(f"Expected jump tokens [25], but got {jump_tokens}") - mock_grammar.rollback.assert_called_once_with(1) - mock_grammar.accept_tokens.assert_called_once_with( - request.request_id, - [25], - ) - - def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running],
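
For reference, the jump-forward-and-retokenize flow that patches 11 through 15 build up in StructuredOutputManager.jump_forward_tokens boils down to the following. This is an illustrative, self-contained sketch rather than the vLLM implementation: GrammarLike, TokenizerLike, jump_forward, and the window default are hypothetical stand-ins for the real StructuredOutputGrammar, tokenizer, and max_rollback_window pieces, and prompt-boundary trimming plus error handling are omitted.

# Standalone sketch (not vLLM code) of the retokenization-based
# jump-forward step.  The Protocols below are hypothetical stand-ins
# for the real grammar/tokenizer objects used in the patches above.
from itertools import takewhile
from typing import Optional, Protocol


class GrammarLike(Protocol):
    def find_jump_string(self) -> Optional[str]: ...
    def accept_tokens(self, tokens: list[int]) -> bool: ...
    def rollback(self, num_tokens: int) -> None: ...


class TokenizerLike(Protocol):
    def decode(self, ids: list[int]) -> str: ...
    def encode(self, text: str) -> list[int]: ...


def jump_forward(grammar: GrammarLike,
                 tokenizer: TokenizerLike,
                 output_ids: list[int],
                 window: int = 10) -> Optional[list[int]]:
    """Return extra token ids implied by the grammar, or None."""
    jump = grammar.find_jump_string()
    if not jump:
        return None

    # Re-tokenize the last `window` output tokens together with the
    # jump-forward string: tokens may merge across the text boundary.
    tail = output_ids[-window:]
    retok = tokenizer.encode(tokenizer.decode(tail) + jump)

    # Longest common prefix between the old tail and the re-tokenized ids.
    k = sum(1 for _ in takewhile(lambda p: p[0] == p[1], zip(tail, retok)))

    # Roll the grammar back over the tokens that changed, then re-validate
    # the new suffix one token at a time.
    if k < len(tail):
        grammar.rollback(len(tail) - k)

    accepted: list[int] = []
    for tok in retok[k:]:
        if not grammar.accept_tokens([tok]):
            # The grammar rejects the merged tokenization: undo everything
            # and restore the original suffix so the FSM state is unchanged.
            grammar.rollback(len(accepted))
            grammar.accept_tokens(tail[k:])
            return None
        accepted.append(tok)

    # `accepted` covers the re-tokenized suffix plus the jump-forward
    # tokens; how it is merged back into the output ids is up to the caller.
    return accepted

Validating the re-tokenized suffix token by token lets the grammar reject a divergent merge and restore its previous state with a single rollback, which is why both the guidance and xgrammar backends grow a rollback() method in patch 11.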