Draft
Changes from all commits
26 commits
81aadb6
chore: migrate tokenizer init to manager only
aarnphm Apr 14, 2025
a97b172
--wip--
aarnphm Apr 17, 2025
b15d00f
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm Apr 23, 2025
d612f85
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm Apr 26, 2025
26f8a25
chore: remove unused functions
aarnphm Apr 26, 2025
fb4ae73
--wip--
aarnphm Apr 26, 2025
4ddf58c
feat: working version
aarnphm Apr 26, 2025
c744d62
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm Apr 26, 2025
7c41ce0
chore: remove debug print
aarnphm Apr 26, 2025
19e2a5c
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm Apr 27, 2025
812d684
revert: use scheduler_output bitmask
aarnphm Apr 27, 2025
bf5c46c
chore: remove debug print
aarnphm Apr 27, 2025
a1ae3ac
chore: move tokenizer to __init__
aarnphm Apr 28, 2025
535e06e
--wip retokenize--
aarnphm Apr 28, 2025
dbc9455
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm Apr 29, 2025
511db48
feat: jump_and_retokenize
aarnphm Apr 29, 2025
ba6d499
fix: set default rollback to 0
aarnphm Apr 29, 2025
ffb0324
chore: implement static max_rollback_window
aarnphm Apr 29, 2025
a7c8070
chore: add a mock test case --wip--
aarnphm Apr 29, 2025
13b6c19
fix: align output_ids to correct retokenized windows
aarnphm Apr 29, 2025
7d26f48
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm Apr 29, 2025
372bcda
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm Apr 30, 2025
1262acc
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm May 14, 2025
d89a660
merge: branch 'main' of github.com:vllm-project/vllm into feat/jump-f…
aarnphm May 15, 2025
93cd93f
fix: revert bad merge
aarnphm May 15, 2025
f580263
revert: remove jump forward tests implementation for now
aarnphm May 15, 2025
30 changes: 22 additions & 8 deletions vllm/v1/core/sched/scheduler.py
@@ -738,6 +738,28 @@ def update_from_output(
new_token_ids = generated_token_ids
kv_transfer_params = None

# NOTE: We need to advance the FSM first, given that
# the bitmask is applied in the first pass and
# jump-forward is only performed afterwards.
first_pass = True
if new_token_ids and self.structured_output_manager.should_advance(
request):
# NOTE: structured_output_request
# should not be None if use_structured_output; we have
# checked this above, so it is safe to ignore the type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)
first_pass = False

# NOTE: We perform retokenization to handle
# tokenizer boundaries, which adds some
# overhead here.
if first_pass and new_token_ids and request.use_structured_output and ( # noqa: E501
jump_tokens :=
self.structured_output_manager.jump_forward_tokens(request)
):
new_token_ids += jump_tokens

# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
@@ -758,14 +780,6 @@ def update_from_output(
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)

if new_token_ids and self.structured_output_manager.should_advance(
request):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)

# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if self.structured_output_manager.should_advance(request):
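In rough terms, the new scheduler code advances the grammar with the freshly sampled tokens first and, only when no grammar advance happened in this pass, appends jump-forward tokens so the stop/length handling below treats them as ordinary output. A compressed sketch of that control flow (stand-in names such as `manager`, not the actual scheduler attributes):

```python
# Sketch only: `manager` stands in for self.structured_output_manager and
# `request` for the vLLM Request; this mirrors the hunk above, nothing more.
def advance_grammar_and_jump(manager, request, req_id: str,
                             new_token_ids: list[int]) -> list[int]:
    advanced = False
    if new_token_ids and manager.should_advance(request):
        # First pass: advance the FSM with the tokens sampled under the bitmask.
        request.structured_output_request.grammar.accept_tokens(
            req_id, new_token_ids)
        advanced = True

    # Only when the grammar was not advanced in this pass, try to jump forward;
    # any returned tokens are appended to the regular output stream.
    if (not advanced and new_token_ids and request.use_structured_output
            and (jump := manager.jump_forward_tokens(request))):
        new_token_ids = new_token_ids + jump
    return new_token_ids
```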
71 changes: 71 additions & 0 deletions vllm/v1/structured_output/__init__.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import itertools
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional
@@ -183,6 +184,76 @@ def grammar_bitmask(
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()

def jump_forward_tokens(self, request: Request) -> list[int] | None:
"""
For structured output requests, we will perform
jump_and_retokenize possible divergence based on grammar state
"""
so_request = request.structured_output_request
if TYPE_CHECKING:
assert so_request is not None
assert so_request.grammar is not None
assert self.backend is not None

jf_string = so_request.grammar.find_jump_string()
if not jf_string:
return None

# NOTE: max_rollback_window determines how many tokens
# from all_token_ids are used for retokenization.
# We don't need the whole token_ids,
# for performance reasons (the tokenizer call is blocking).
max_rollback_window = 10

rollback_text_str = self.tokenizer.decode(
request.all_token_ids[-max_rollback_window:])
retokenized_output_ids = self.tokenizer.encode(
rollback_text_str + jf_string,
add_special_tokens=False,
)
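# If the retokenized window reaches back into the prompt, drop everything
# up to and including the last prompt token so that only output tokens are
# compared against original_output_ids below.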
if request.prompt_token_ids[-1] in retokenized_output_ids:
prompt_boundary = retokenized_output_ids.index(
request.prompt_token_ids[-1]) + 1
retokenized_output_ids = retokenized_output_ids[prompt_boundary:]

original_output_ids = request.output_token_ids[
max(0,
len(request.output_token_ids) - len(retokenized_output_ids)):]

# Find the prefix match length
k = sum(1 for _ in itertools.takewhile(
lambda pair: pair[0] == pair[1],
zip(original_output_ids, retokenized_output_ids),
))
retokenized_suffix = retokenized_output_ids[k:]
if k < len(original_output_ids):
so_request.grammar.rollback(len(original_output_ids) - k)

# Validate tokens one by one
accepted_tokens: list[int] = []
num_validated_in_suffix = 0
validation_ok = True
for token in retokenized_suffix:
if so_request.grammar.accept_tokens(request.request_id, [token]):
accepted_tokens.append(token)
num_validated_in_suffix += 1
else:
if num_validated_in_suffix > 0:
so_request.grammar.rollback(num_validated_in_suffix)
validation_ok = False
break

if validation_ok:
return accepted_tokens

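# Validation failed partway: restore the grammar by re-accepting the
# original (non-jumped) suffix tokens; if that also fails, roll them
# back and give up on jump-forward for this step.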
original_suffix_tokens = original_output_ids[num_validated_in_suffix:]
if original_suffix_tokens and not so_request.grammar.accept_tokens(
request.request_id,
original_suffix_tokens,
):
so_request.grammar.rollback(len(original_suffix_tokens))
return None

def should_advance(self, request: Request) -> bool:
if not request.use_structured_output:
return False
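To see why the retokenization and prefix-match bookkeeping above is needed, here is a self-contained toy example (hypothetical greedy tokenizer and vocabulary, no real tokenizer or grammar involved): appending the forced jump string to the decoded text can merge across the old token boundary, so the retokenized ids diverge from the original ones and part of the grammar state has to be rolled back.

```python
import itertools

# Toy greedy longest-match tokenizer -- only to illustrate boundary merging.
VOCAB = {'{"': 0, 'answer': 1, '"': 2, '": "': 3, ': "': 4}

def encode(text: str) -> list[int]:
    ids, i = [], 0
    while i < len(text):
        match = max((t for t in VOCAB if text.startswith(t, i)), key=len)
        ids.append(VOCAB[match])
        i += len(match)
    return ids

generated_text = '{"answer"'   # text produced so far
jump_string = ': "'            # continuation forced by the grammar

original_output_ids = encode(generated_text)                   # [0, 1, 2]
retokenized_output_ids = encode(generated_text + jump_string)  # [0, 1, 3]

# Prefix-match step, as in jump_forward_tokens above.
k = sum(1 for _ in itertools.takewhile(
    lambda pair: pair[0] == pair[1],
    zip(original_output_ids, retokenized_output_ids)))
suffix = retokenized_output_ids[k:]          # tokens to validate: [3]
rollback = len(original_output_ids) - k      # grammar steps to undo: 1
assert (k, suffix, rollback) == (2, [3], 1)
```

Here the grammar is rolled back one step (the lone `'"'` token), and the single retokenized suffix token is then validated and accepted, which is exactly what the token-by-token loop above does.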
4 changes: 4 additions & 0 deletions vllm/v1/structured_output/backend_guidance.py
@@ -170,6 +170,10 @@ def reset(self):
# This method may be not needed anymore? TODO
self.ll_matcher.reset()

def find_jump_string(self) -> str | None:
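# llguidance reports the forced continuation as bytes; decode it to a
# string for the manager, returning None when nothing is forced.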
ff_string = self.ll_matcher.compute_ff_bytes()
return ff_string.decode() if ff_string else None


def serialize_guidance_grammar(
request_type: StructuredOutputOptions,
9 changes: 9 additions & 0 deletions vllm/v1/structured_output/backend_types.py
@@ -29,6 +29,15 @@ class StructuredOutputOptions(enum.Enum):
class StructuredOutputGrammar(ABC):
"""Request-level backend for structured output requests."""

@abstractmethod
def find_jump_string(self) -> str | None:
"""
Find the jump-forward string based on the current grammar state.

Returns:
Optional str: the jump-forward string, or None if there is none
"""

@abstractmethod
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""
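As a purely illustrative reading of this contract (toy class, not a real vLLM backend): a grammar whose next piece of text is fully determined returns that text, and one whose continuation is unconstrained returns None.

```python
# Hypothetical stand-in, only to illustrate the find_jump_string contract.
class ToyForcedGrammar:
    def __init__(self, forced_text: str | None) -> None:
        self._forced_text = forced_text

    def find_jump_string(self) -> str | None:
        # Empty or missing forced text means there is nothing to jump over.
        return self._forced_text or None

assert ToyForcedGrammar('": "').find_jump_string() == '": "'
assert ToyForcedGrammar("").find_jump_string() is None
```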
4 changes: 4 additions & 0 deletions vllm/v1/structured_output/backend_xgrammar.py
@@ -194,6 +194,10 @@ def reset(self):
self.num_processed_tokens = 0
self.matcher.reset()

def find_jump_string(self) -> str | None:
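# xgrammar returns an empty string when no jump-forward text is available;
# normalize that to None for the manager.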
jf_string = self.matcher.find_jump_forward_string()
return jf_string if jf_string else None


def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""
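For reference, a rough standalone sketch of the xgrammar calls this backend builds on (assuming xgrammar's documented `GrammarCompiler`/`GrammarMatcher` API and an arbitrary public tokenizer; exact signatures and the forced string may vary between versions):

```python
import xgrammar as xgr
from transformers import AutoTokenizer

# Tokenizer choice is arbitrary; any HF tokenizer works for the sketch.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tok)
compiler = xgr.GrammarCompiler(tokenizer_info)
compiled = compiler.compile_json_schema(
    '{"type": "object", "properties": {"a": {"type": "integer"}}, '
    '"required": ["a"]}')
matcher = xgr.GrammarMatcher(compiled)

# Feed the tokens generated so far, then ask for any forced continuation.
for token_id in tok.encode('{"a', add_special_tokens=False):
    matcher.accept_token(token_id)
print(matcher.find_jump_forward_string())  # e.g. '": ' if the grammar forces it
```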