Skip to content

Commit ec7ceda

Browse files
committed
[Core] Streamline some structured output related code
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 314285d commit ec7ceda

File tree

9 files changed

+84
-99
lines changed

9 files changed

+84
-99
lines changed

vllm/v1/core/sched/output.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,8 @@ class SchedulerOutput:
165165
# freed from the encoder cache.
166166
free_encoder_mm_hashes: list[str]
167167

168-
# Dict of request ids to their index within the batch
169-
# for filling the next token bitmask
170-
structured_output_request_ids: dict[str, int]
168+
# ids of structured outputs requests included in the bitmask, in order.
169+
structured_output_request_ids: list[str]
171170
# the bitmask for the whole batch
172171
grammar_bitmask: "npt.NDArray[np.int32] | None"
173172

vllm/v1/core/sched/scheduler.py

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from collections.abc import Iterable
88
from typing import Any
99

10+
import numpy as np
11+
from pandas._typing import npt
12+
1013
from vllm.config import VllmConfig
1114
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
1215
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
@@ -610,11 +613,8 @@ def schedule(self) -> SchedulerOutput:
610613
scheduled_spec_decode_tokens,
611614
req_to_new_blocks,
612615
)
613-
scheduled_requests = (
614-
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
615-
)
616616
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
617-
scheduled_requests, scheduled_spec_decode_tokens
617+
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
618618
)
619619
scheduler_output = SchedulerOutput(
620620
scheduled_new_reqs=new_reqs_data,
@@ -878,32 +878,28 @@ def _try_schedule_encoder_inputs(
878878

879879
def get_grammar_bitmask(
880880
self,
881-
requests: list[Request],
881+
scheduled_request_ids: Iterable[str],
882882
scheduled_spec_decode_tokens: dict[str, list[int]],
883-
):
884-
# NOTE: structured_output_request_ids maps
885-
# a request's (request that uses structured output)
886-
# request_id to its index in the batch.
887-
# This will help us determine to slice the grammar bitmask
888-
# and only applies valid mask for requests that
889-
# uses structured decoding.
890-
structured_output_request_ids: dict[str, int] = {}
891-
for i, req in enumerate(requests):
892-
if req.use_structured_output:
893-
# PERF: in case of chunked prefill,
894-
# request might not include any new tokens.
895-
# Therefore, we might introduce some additional
896-
# cycle to fill in the bitmask, which could be a big no-op.
897-
structured_output_request_ids[req.request_id] = i
898-
883+
) -> tuple[list[str], npt.NDArray[np.int32] | None]:
884+
# Collect list of scheduled request ids that use structured output.
885+
# The corresponding rows of the bitmask will be in this order.
886+
# PERF: in case of chunked prefill,
887+
# request might not include any new tokens.
888+
# Therefore, we might introduce some additional
889+
# cycle to fill in the bitmask, which could be a big no-op.
890+
structured_output_request_ids = [
891+
req_id
892+
for req_id in scheduled_request_ids
893+
if (req := self.requests.get(req_id)) and req.use_structured_output
894+
]
899895
if not structured_output_request_ids:
900-
bitmask = None
901-
else:
902-
bitmask = self.structured_output_manager.grammar_bitmask(
903-
self.requests,
904-
structured_output_request_ids,
905-
scheduled_spec_decode_tokens,
906-
)
896+
return structured_output_request_ids, None
897+
898+
bitmask = self.structured_output_manager.grammar_bitmask(
899+
self.requests,
900+
structured_output_request_ids,
901+
scheduled_spec_decode_tokens,
902+
)
907903
return structured_output_request_ids, bitmask
908904

909905
def update_from_output(
@@ -1011,12 +1007,10 @@ def update_from_output(
10111007
new_logprobs = logprobs.slice(req_index, req_index + 1)
10121008

10131009
if new_token_ids and self.structured_output_manager.should_advance(request):
1014-
# NOTE: structured_output_request
1015-
# should not be None if use_structured_output, we have
1016-
# checked above, so safe to ignore type warning
1017-
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
1018-
req_id, new_token_ids
1019-
)
1010+
struct_output_request = request.structured_output_request
1011+
assert struct_output_request is not None
1012+
assert struct_output_request.grammar is not None
1013+
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
10201014

10211015
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
10221016
request.num_nans_in_logits = num_nans_in_logits[req_id]

vllm/v1/request.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(
4040
prompt_embeds: torch.Tensor | None = None,
4141
mm_features: list[MultiModalFeatureSpec] | None = None,
4242
lora_request: Optional["LoRARequest"] = None,
43-
structured_output_request: Optional["StructuredOutputRequest"] = None,
4443
cache_salt: str | None = None,
4544
priority: int = 0,
4645
trace_headers: Mapping[str, str] | None = None,
@@ -54,11 +53,12 @@ def __init__(
5453
# Because of LoRA, the eos token id can be different for each request.
5554
self.eos_token_id = eos_token_id
5655
self.lora_request = lora_request
57-
self.structured_output_request = structured_output_request
56+
self.structured_output_request = StructuredOutputRequest.from_sampling_params(
57+
sampling_params
58+
)
5859
self.arrival_time = arrival_time if arrival_time is not None else time.time()
5960

6061
self.status = RequestStatus.WAITING
61-
self.use_structured_output = False
6262
self.events: list[EngineCoreEvent] = []
6363
self.stop_reason: int | str | None = None
6464

@@ -72,9 +72,8 @@ def __init__(
7272
# Generative models.
7373
assert sampling_params.max_tokens is not None
7474
self.max_tokens = sampling_params.max_tokens
75-
if sampling_params.structured_outputs is not None:
75+
if self.structured_output_request is not None:
7676
self.status = RequestStatus.WAITING_FOR_FSM
77-
self.use_structured_output = True
7877

7978
if sampling_params.extra_args is not None:
8079
self.kv_transfer_params = sampling_params.extra_args.get(
@@ -145,11 +144,6 @@ def from_engine_core_request(
145144
eos_token_id=request.eos_token_id,
146145
arrival_time=request.arrival_time,
147146
lora_request=request.lora_request,
148-
structured_output_request=StructuredOutputRequest(
149-
sampling_params=request.sampling_params
150-
)
151-
if request.sampling_params
152-
else None,
153147
cache_salt=request.cache_salt,
154148
priority=request.priority,
155149
trace_headers=request.trace_headers,
@@ -170,6 +164,10 @@ def append_output_token_ids(
170164
if self.get_hash_new_full_blocks is not None:
171165
self.block_hashes.extend(self.get_hash_new_full_blocks())
172166

167+
@property
168+
def use_structured_output(self) -> bool:
169+
return self.structured_output_request is not None
170+
173171
@property
174172
def is_output_corrupted(self) -> bool:
175173
return self.num_nans_in_logits > 0

vllm/v1/structured_output/__init__.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _async_submit_fill_bitmask(
167167
def grammar_bitmask(
168168
self,
169169
requests: dict[str, Request],
170-
structured_output_request_ids: dict[str, int],
170+
structured_output_request_ids: list[str],
171171
scheduled_spec_decode_tokens: dict[str, list[int]],
172172
) -> "npt.NDArray[np.int32] | None":
173173
# Prepare the structured output bitmask for this batch.
@@ -196,17 +196,16 @@ def grammar_bitmask(
196196
# masks for each request, one for each possible bonus token position.
197197
# These are stored inline in the tensor and unpacked by the gpu runner.
198198
cumulative_index = 0
199-
ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])
200199

201200
# Optimized parallel filling of bitmasks for
202201
# non-spec, large-batch-size cases
203202
if (
204-
len(ordered_seq) > self.fill_bitmask_parallel_threshold
203+
len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
205204
and max_num_spec_tokens == 0
206205
):
207206
promises = []
208207
batch = []
209-
for req_id, _ in ordered_seq:
208+
for req_id in structured_output_request_ids:
210209
request = requests[req_id]
211210
structured_output_request = request.structured_output_request
212211
if TYPE_CHECKING:
@@ -230,7 +229,7 @@ def grammar_bitmask(
230229
promise.result()
231230
else:
232231
# Fallback to serial filling of bitmasks for small-batch-size cases
233-
for req_id, _ in ordered_seq:
232+
for req_id in structured_output_request_ids:
234233
request = requests[req_id]
235234
structured_output_request = request.structured_output_request
236235

@@ -295,21 +294,20 @@ def should_advance(self, request: Request) -> bool:
295294
assert request.structured_output_request.grammar is not None
296295
# by default, we should always advance
297296
# for cases that don't use thinking mode.
298-
if self.reasoner is not None:
299-
structured_req = request.structured_output_request
297+
if self.reasoner is None:
298+
return True
300299

301-
if structured_req.reasoning_ended:
302-
return True
300+
structured_req = request.structured_output_request
301+
if structured_req.reasoning_ended:
302+
return True
303303

304-
# Check if reasoning ends in *this* step
305-
if self.reasoner.is_reasoning_end(request.all_token_ids):
306-
# Reasoning just ended, so we shouldn't advance til
307-
# next pass
308-
structured_req.reasoning_ended = True
304+
# Check if reasoning ends in *this* step
305+
if self.reasoner.is_reasoning_end(request.all_token_ids):
306+
# Reasoning just ended, so we shouldn't advance til
307+
# next pass
308+
structured_req.reasoning_ended = True
309309

310-
return False
311-
else:
312-
return True
310+
return False
313311

314312
def clear_backend(self) -> None:
315313
if self.backend is not None:

vllm/v1/structured_output/backend_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _process_schema(
252252
def validate_guidance_grammar(
253253
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
254254
) -> None:
255-
tp, grm = get_structured_output_key(sampling_params)
255+
tp, grm = get_structured_output_key(sampling_params.structured_outputs)
256256
guidance_grm = serialize_guidance_grammar(tp, grm)
257257
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
258258
if err:

vllm/v1/structured_output/request.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from concurrent.futures._base import TimeoutError
88
from typing import cast
99

10-
from vllm.sampling_params import SamplingParams
10+
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
1111
from vllm.v1.structured_output.backend_types import (
1212
StructuredOutputGrammar,
1313
StructuredOutputKey,
@@ -17,10 +17,19 @@
1717

1818
@dataclasses.dataclass
1919
class StructuredOutputRequest:
20-
sampling_params: SamplingParams
20+
params: StructuredOutputsParams
2121
_grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
2222
reasoning_ended: bool | None = None
2323

24+
@staticmethod
25+
def from_sampling_params(
26+
sampling_params: SamplingParams | None,
27+
) -> "StructuredOutputRequest | None":
28+
if sampling_params is None:
29+
return None
30+
params = sampling_params.structured_outputs
31+
return StructuredOutputRequest(params=params) if params else None
32+
2433
def _check_grammar_completion(self) -> bool:
2534
# NOTE: We have to lazy import to gate circular imports
2635
from vllm.v1.request import RequestStatus
@@ -53,31 +62,28 @@ def grammar(
5362

5463
@functools.cached_property
5564
def structured_output_key(self) -> StructuredOutputKey:
56-
return get_structured_output_key(self.sampling_params)
65+
return get_structured_output_key(self.params)
5766

5867

59-
def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey:
60-
params = sampling_params.structured_outputs
61-
assert params is not None, "params can't be None."
68+
def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
6269
if params.json is not None:
6370
if not isinstance(params.json, str):
6471
json_str = json.dumps(params.json)
6572
else:
6673
json_str = params.json
67-
return (StructuredOutputOptions.JSON, json_str)
68-
elif params.json_object:
69-
return (StructuredOutputOptions.JSON_OBJECT, "")
70-
elif params.regex is not None:
71-
return (StructuredOutputOptions.REGEX, params.regex)
72-
elif params.choice is not None:
74+
return StructuredOutputOptions.JSON, json_str
75+
if params.json_object:
76+
return StructuredOutputOptions.JSON_OBJECT, ""
77+
if params.regex is not None:
78+
return StructuredOutputOptions.REGEX, params.regex
79+
if params.choice is not None:
7380
if not isinstance(params.choice, str):
7481
json_str = json.dumps(params.choice)
7582
else:
7683
json_str = params.choice
77-
return (StructuredOutputOptions.CHOICE, json_str)
78-
elif params.grammar is not None:
79-
return (StructuredOutputOptions.GRAMMAR, params.grammar)
80-
elif params.structural_tag is not None:
81-
return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
82-
else:
83-
raise ValueError("No valid structured output parameter found")
84+
return StructuredOutputOptions.CHOICE, json_str
85+
if params.grammar is not None:
86+
return StructuredOutputOptions.GRAMMAR, params.grammar
87+
if params.structural_tag is not None:
88+
return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
89+
raise ValueError("No valid structured output parameter found")

vllm/v1/structured_output/utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def apply_grammar_bitmask(
4747
scheduler_output: SchedulerOutput,
4848
input_batch: InputBatch,
4949
logits: torch.Tensor,
50-
device: torch.device,
5150
) -> None:
5251
"""
5352
Apply grammar bitmask to output logits of the model with xgrammar function.
@@ -91,10 +90,7 @@ def apply_grammar_bitmask(
9190
dtype=grammar_bitmask.dtype,
9291
)
9392
cumulative_index = 0
94-
seq = sorted(
95-
scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]
96-
)
97-
for req_id, _ in seq:
93+
for req_id in scheduler_output.structured_output_request_ids:
9894
num_spec_tokens = len(
9995
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
10096
)
@@ -117,7 +113,7 @@ def apply_grammar_bitmask(
117113

118114
xgr.apply_token_bitmask_inplace(
119115
logits,
120-
grammar_bitmask.to(device, non_blocking=True),
116+
grammar_bitmask.to(logits.device, non_blocking=True),
121117
indices=out_indices if not skip_out_indices else None,
122118
)
123119

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2570,10 +2570,8 @@ def execute_model(
25702570
logits = model_output_broadcast_data["logits"]
25712571

25722572
# Apply structured output bitmasks if present
2573-
if scheduler_output.grammar_bitmask is not None:
2574-
apply_grammar_bitmask(
2575-
scheduler_output, self.input_batch, logits, self.device
2576-
)
2573+
if scheduler_output.structured_output_request_ids:
2574+
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
25772575

25782576
with record_function_or_nullcontext("Sample"):
25792577
sampler_output = self._sample(logits, spec_decode_metadata)

vllm/v1/worker/tpu_model_runner.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,12 +1963,8 @@ def prepare_structured_decoding_input(
19631963
self.grammar_bitmask_cpu.zero_()
19641964
self.require_structured_out_cpu.zero_()
19651965

1966-
sorted_struct_requests = sorted(
1967-
scheduler_output.structured_output_request_ids.items(),
1968-
key=lambda item: item[1],
1969-
)
19701966
cumulative_mask_idx = 0
1971-
for req_id, _ in sorted_struct_requests:
1967+
for req_id in scheduler_output.structured_output_request_ids:
19721968
if req_id not in self.input_batch.req_id_to_index:
19731969
continue
19741970
batch_index = self.input_batch.req_id_to_index[req_id]

0 commit comments

Comments
 (0)