Commit a97b172

--wip--

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

1 parent: 81aadb6

File tree: 5 files changed (+56 −11 lines)


vllm/v1/core/sched/scheduler.py (13 additions, 11 deletions)
@@ -652,15 +652,16 @@ def update_from_output(
             # the outer lists can be of length > 1.
             new_logprobs = logprobs.slice(req_index, req_index + 1)
 
-            if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids)
-
-            # Get prompt logprobs for this request.
-            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
+            # --- Jump-forward decoding for structured output requests ---
+            if request.use_structured_output:
+                batch_index = scheduler_output.structured_output_request_ids.get(
+                    req_id, 0)
+                jump_tokens = self.structured_output_manager.jump_forward_tokens(
+                    request, batch_index)
+                if jump_tokens:
+                    new_token_ids.extend(jump_tokens)
+            # --- End jump-forward decoding ---
+
             if new_token_ids:
                 # Add EngineCoreOutput for this Request.
                 outputs.append(
@@ -669,12 +670,13 @@ def update_from_output(
                         new_token_ids=new_token_ids,
                         finish_reason=request.get_finished_reason(),
                         new_logprobs=new_logprobs,
-                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
+                        new_prompt_logprobs_tensors=prompt_logprobs_dict.get(
+                            req_id),
                         stop_reason=request.stop_reason,
                         events=request.take_events()))
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
-                assert not prompt_logprobs_tensors
+                assert not prompt_logprobs_dict.get(req_id)
 
             self.scheduled_req_ids.remove(req_id)
             if not stopped:
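The hunk above wires jump-forward decoding into the scheduler's per-request
output loop: after the sampled tokens for a request are known, the
structured-output manager is asked for any tokens the grammar now forces, and
those are appended to new_token_ids without spending a model forward pass on
them. A minimal sketch of that control flow, where FakeGrammar and
forced_next_token are hypothetical stand-ins for the real vllm grammar and
manager objects, not part of the vllm API:

# Toy FSM over token IDs: some states force exactly one next token.
class FakeGrammar:

    def __init__(self, forced: dict[int, int]):
        self.state = 0
        self.forced = forced  # state -> the single token ID it allows

    def forced_next_token(self) -> int | None:
        return self.forced.get(self.state)

    def accept(self, token_id: int) -> None:
        self.state += 1


def jump_forward(grammar: FakeGrammar) -> list[int]:
    # Collect every forced token, advancing the FSM as we go.
    jump_tokens: list[int] = []
    while (tok := grammar.forced_next_token()) is not None:
        jump_tokens.append(tok)
        grammar.accept(tok)
    return jump_tokens


# State 0 is unconstrained (the model samples freely); states 1 and 2 each
# force one token, so those tokens are emitted without another forward pass.
grammar = FakeGrammar(forced={1: 90, 2: 34})
new_token_ids = [7]                # token sampled by the model this step
grammar.accept(new_token_ids[0])
new_token_ids.extend(jump_forward(grammar))  # mirrors the hunk above
print(new_token_ids)               # [7, 90, 34]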

vllm/v1/structured_output/__init__.py (29 additions, 0 deletions)
@@ -8,6 +8,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_guidance import GuidanceBackend
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
@@ -17,8 +18,12 @@
     import numpy as np
     import numpy.typing as npt
     import torch
+    import xgrammar.testing as xgr_testing
 
     from vllm.v1.request import Request
+else:
+    torch = LazyLoader('torch', globals(), 'torch')
+    xgr_testing = LazyLoader('xgr_testing', globals(), 'xgrammar.testing')
 
 logger = init_logger(__name__)
 
@@ -122,3 +126,28 @@ def grammar_bitmask(
         # np.ndarray, because that is much more efficient for serialization
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()
+
+    def jump_forward_tokens(self, request, batch_index) -> list[int]:
+        """
+        For xgrammar-based structured output requests, repeatedly check if the
+        grammar bitmask allows exactly one token and, if so, advance the FSM
+        and collect all jump-forward tokens. Returns the jump-forward token IDs.
+        """
+        so_request = request.structured_output_request
+        if so_request is None or so_request.grammar is None:
+            return []
+
+        jump_tokens: list[int] = []
+        # Packed bitmask: one allow/deny bit per vocab entry, 32 per int32.
+        bitmask = torch.zeros((1, (so_request.grammar.vocab_size + 31) // 32),
+                              dtype=torch.int32)
+        so_request.grammar.fill_bitmask(bitmask, 0)
+        is_single, unique_token_id = xgr_testing._is_single_token_bitmask(
+            bitmask, so_request.grammar.vocab_size, 0)
+        while is_single and unique_token_id != -1:
+            jump_tokens.append(unique_token_id)
+            so_request.grammar.accept_tokens(request.request_id,
+                                             [unique_token_id])
+            so_request.grammar.fill_bitmask(bitmask, 0)
+            is_single, unique_token_id = xgr_testing._is_single_token_bitmask(
+                bitmask, so_request.grammar.vocab_size, 0)
+        return jump_tokens
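The loop above hinges on xgrammar's packed token bitmask: one allow/deny bit
per vocabulary entry, 32 entries per int32 word, plus a helper that reports
whether exactly one bit is set. A pure-Python sketch of that check, mirroring
the semantics of xgr_testing._is_single_token_bitmask as used here (an
assumption about its behavior, not its actual implementation):

def is_single_token_bitmask(words: list[int],
                            vocab_size: int) -> tuple[bool, int]:
    # Scan the packed words and collect every token whose bit is set.
    allowed = [
        tok for tok in range(vocab_size)
        if (words[tok // 32] >> (tok % 32)) & 1
    ]
    # Exactly one allowed token means the grammar forces it.
    return (True, allowed[0]) if len(allowed) == 1 else (False, -1)


vocab_size = 64
words = [0, 0]                       # 64 tokens -> two int32 words
words[41 // 32] |= 1 << (41 % 32)    # allow only token 41
print(is_single_token_bitmask(words, vocab_size))  # (True, 41)
words[3 // 32] |= 1 << (3 % 32)      # allow token 3 as well
print(is_single_token_bitmask(words, vocab_size))  # (False, -1)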

vllm/v1/structured_output/backend_guidance.py (3 additions, 0 deletions)
@@ -119,6 +119,9 @@ def reset(self):
         # This method may be not needed anymore? TODO
         self.ll_matcher.reset()
 
+    def find_jump_forward_tokens(self) -> list[int]:
+        raise NotImplementedError
+
 
 def serialize_guidance_grammar(request_type: StructuredOutputOptions,
                                grammar_spec: str,

vllm/v1/structured_output/backend_types.py (6 additions, 0 deletions)
@@ -100,3 +100,9 @@ def allocate_token_bitmask(self, max_num_seqs: int):
             max_num_seqs (int): The maximum number of sequences for which
                 to allocate the bitmask.
         """
+
+    @abstractmethod
+    def find_jump_forward_tokens(self) -> list[int]:
+        """
+        Finds the tokens that can be used to jump forward in the grammar.
+        """

vllm/v1/structured_output/backend_xgrammar.py (5 additions, 0 deletions)
@@ -147,3 +147,8 @@ def is_terminated(self) -> bool:
     def reset(self):
         self.num_processed_tokens = 0
         self.matcher.reset()
+
+    def find_jump_forward_tokens(self) -> list[int]:
+        jf_string = self.matcher.find_jump_forward_string()
+        return self.tokenizer.encode(
+            jf_string, add_special_tokens=False) if jf_string else []
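Note that matcher.find_jump_forward_string() returns text rather than token
IDs, so the method above has to encode it before the scheduler can splice the
result into new_token_ids. A sketch of that conversion, using a HuggingFace
tokenizer as a stand-in for self.tokenizer; the model name and sample string
are illustrative only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

jf_string = '{"name": "'   # e.g. a prefix the grammar forces next
jf_tokens = tokenizer.encode(jf_string, add_special_tokens=False)
print(jf_tokens)           # token IDs ready to append to the output stream

# Caveat: this assumes the forced text tokenizes the same way in isolation
# as it would mid-sequence; a complete implementation has to handle
# retokenization at the boundary.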
