Skip to content

Commit a2c0511

Browse files
njhillxuebwang-amd
authored andcommitted
[Core] Streamline some structured output related code (vllm-project#26737)
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent a0a8296 commit a2c0511

File tree

13 files changed

+119
-136
lines changed

13 files changed

+119
-136
lines changed

tests/v1/core/test_scheduler.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
3131
from vllm.v1.request import Request, RequestStatus
3232
from vllm.v1.structured_output import StructuredOutputManager
33-
from vllm.v1.structured_output.request import StructuredOutputRequest
3433

3534
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
3635

@@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
335334
requests[0].request_id: [],
336335
requests[1].request_id: [10],
337336
},
338-
num_common_prefix_blocks=0,
337+
num_common_prefix_blocks=[],
339338
finished_req_ids=set(),
340339
free_encoder_mm_hashes=[],
341-
structured_output_request_ids={},
340+
structured_output_request_ids=[],
342341
grammar_bitmask=None,
343342
)
344343

@@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
383382
requests[0].request_id: [10, 42],
384383
requests[1].request_id: [13],
385384
},
386-
num_common_prefix_blocks=0,
385+
num_common_prefix_blocks=[],
387386
finished_req_ids=set(),
388387
free_encoder_mm_hashes=[],
389-
structured_output_request_ids={},
388+
structured_output_request_ids=[],
390389
grammar_bitmask=None,
391390
)
392391

@@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
429428
requests[0].request_id: [10, 11],
430429
requests[1].request_id: [],
431430
},
432-
num_common_prefix_blocks=0,
431+
num_common_prefix_blocks=[],
433432
finished_req_ids=set(),
434433
free_encoder_mm_hashes=[],
435-
structured_output_request_ids={},
434+
structured_output_request_ids=[],
436435
grammar_bitmask=None,
437436
)
438437

@@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
470469
total_num_scheduled_tokens=3,
471470
scheduled_encoder_inputs={},
472471
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
473-
num_common_prefix_blocks=0,
472+
num_common_prefix_blocks=[],
474473
finished_req_ids=set(),
475474
free_encoder_mm_hashes=[],
476-
structured_output_request_ids={},
475+
structured_output_request_ids=[],
477476
grammar_bitmask=None,
478477
)
479478

@@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
19411940
sampling_params=sampling_params,
19421941
pooling_params=None,
19431942
eos_token_id=EOS_TOKEN_ID,
1944-
structured_output_request=StructuredOutputRequest(sampling_params),
19451943
)
19461944
scheduler.add_request(request)
19471945
output = scheduler.schedule()

tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
2626
num_common_prefix_blocks=[],
2727
finished_req_ids=set(),
2828
free_encoder_mm_hashes=[],
29-
structured_output_request_ids={},
29+
structured_output_request_ids=[],
3030
grammar_bitmask=None,
3131
kv_connector_metadata=SharedStorageConnectorMetadata(),
3232
)

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
8989
total_num_scheduled_tokens=total_num_scheduled_tokens,
9090
scheduled_spec_decode_tokens={},
9191
scheduled_encoder_inputs={},
92-
num_common_prefix_blocks=0,
92+
num_common_prefix_blocks=[],
9393
finished_req_ids=set(),
9494
free_encoder_mm_hashes=[],
95-
structured_output_request_ids={},
95+
structured_output_request_ids=[],
9696
grammar_bitmask=None,
9797
)
9898

@@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
168168
total_num_scheduled_tokens=0,
169169
scheduled_spec_decode_tokens={},
170170
scheduled_encoder_inputs={},
171-
num_common_prefix_blocks=0,
171+
num_common_prefix_blocks=[],
172172
finished_req_ids={req_id},
173173
free_encoder_mm_hashes=[],
174-
structured_output_request_ids={},
174+
structured_output_request_ids=[],
175175
grammar_bitmask=None,
176176
)
177177

@@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
198198
total_num_scheduled_tokens=0,
199199
scheduled_spec_decode_tokens={},
200200
scheduled_encoder_inputs={},
201-
num_common_prefix_blocks=0,
201+
num_common_prefix_blocks=[],
202202
finished_req_ids=set(),
203203
free_encoder_mm_hashes=[],
204-
structured_output_request_ids={},
204+
structured_output_request_ids=[],
205205
grammar_bitmask=None,
206206
)
207207

@@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
225225
total_num_scheduled_tokens=1,
226226
scheduled_spec_decode_tokens={},
227227
scheduled_encoder_inputs={},
228-
num_common_prefix_blocks=0,
228+
num_common_prefix_blocks=[],
229229
finished_req_ids=set(),
230230
free_encoder_mm_hashes=[],
231-
structured_output_request_ids={},
231+
structured_output_request_ids=[],
232232
grammar_bitmask=None,
233233
)
234234

@@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
256256
total_num_scheduled_tokens=1,
257257
scheduled_spec_decode_tokens={},
258258
scheduled_encoder_inputs={},
259-
num_common_prefix_blocks=0,
259+
num_common_prefix_blocks=[],
260260
finished_req_ids=set(),
261261
free_encoder_mm_hashes=[],
262-
structured_output_request_ids={},
262+
structured_output_request_ids=[],
263263
grammar_bitmask=None,
264264
)
265265

@@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
291291
total_num_scheduled_tokens=1,
292292
scheduled_spec_decode_tokens={},
293293
scheduled_encoder_inputs={},
294-
num_common_prefix_blocks=0,
294+
num_common_prefix_blocks=[],
295295
finished_req_ids=set(),
296296
free_encoder_mm_hashes=[],
297-
structured_output_request_ids={},
297+
structured_output_request_ids=[],
298298
grammar_bitmask=None,
299299
)
300300

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
146146
total_num_scheduled_tokens=total_num_scheduled_tokens,
147147
scheduled_spec_decode_tokens={},
148148
scheduled_encoder_inputs={},
149-
num_common_prefix_blocks=0,
149+
num_common_prefix_blocks=[],
150150
finished_req_ids=set(),
151151
free_encoder_mm_hashes=[],
152-
structured_output_request_ids={},
152+
structured_output_request_ids=[],
153153
grammar_bitmask=None,
154154
)
155155

@@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
212212
total_num_scheduled_tokens=0,
213213
scheduled_spec_decode_tokens={},
214214
scheduled_encoder_inputs={},
215-
num_common_prefix_blocks=0,
215+
num_common_prefix_blocks=[],
216216
finished_req_ids={req_id},
217217
free_encoder_mm_hashes=[],
218-
structured_output_request_ids={},
218+
structured_output_request_ids=[],
219219
grammar_bitmask=None,
220220
)
221221

@@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
244244
total_num_scheduled_tokens=0,
245245
scheduled_spec_decode_tokens={},
246246
scheduled_encoder_inputs={},
247-
num_common_prefix_blocks=0,
247+
num_common_prefix_blocks=[],
248248
finished_req_ids=set(),
249249
free_encoder_mm_hashes=[],
250-
structured_output_request_ids={},
250+
structured_output_request_ids=[],
251251
grammar_bitmask=None,
252252
)
253253

@@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
273273
total_num_scheduled_tokens=1,
274274
scheduled_spec_decode_tokens={},
275275
scheduled_encoder_inputs={},
276-
num_common_prefix_blocks=0,
276+
num_common_prefix_blocks=[],
277277
finished_req_ids=set(),
278278
free_encoder_mm_hashes=[],
279-
structured_output_request_ids={},
279+
structured_output_request_ids=[],
280280
grammar_bitmask=None,
281281
)
282282

@@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
366366
total_num_scheduled_tokens=1,
367367
scheduled_spec_decode_tokens={},
368368
scheduled_encoder_inputs={},
369-
num_common_prefix_blocks=0,
369+
num_common_prefix_blocks=[],
370370
finished_req_ids=set(),
371371
free_encoder_mm_hashes=[],
372-
structured_output_request_ids={},
372+
structured_output_request_ids=[],
373373
grammar_bitmask=None,
374374
)
375375

@@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
403403
total_num_scheduled_tokens=1,
404404
scheduled_spec_decode_tokens={},
405405
scheduled_encoder_inputs={},
406-
num_common_prefix_blocks=0,
406+
num_common_prefix_blocks=[],
407407
finished_req_ids=set(),
408408
free_encoder_mm_hashes=[],
409-
structured_output_request_ids={},
409+
structured_output_request_ids=[],
410410
grammar_bitmask=None,
411411
)
412412

vllm/v1/core/sched/output.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,8 @@ class SchedulerOutput:
165165
# freed from the encoder cache.
166166
free_encoder_mm_hashes: list[str]
167167

168-
# Dict of request ids to their index within the batch
169-
# for filling the next token bitmask
170-
structured_output_request_ids: dict[str, int]
168+
# ids of structured outputs requests included in the bitmask, in order.
169+
structured_output_request_ids: list[str]
171170
# the bitmask for the whole batch
172171
grammar_bitmask: "npt.NDArray[np.int32] | None"
173172

vllm/v1/core/sched/scheduler.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from collections import defaultdict
77
from collections.abc import Iterable
8-
from typing import Any
8+
from typing import TYPE_CHECKING, Any
99

1010
from vllm.config import VllmConfig
1111
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@@ -34,6 +34,10 @@
3434
from vllm.v1.spec_decode.metrics import SpecDecodingStats
3535
from vllm.v1.structured_output import StructuredOutputManager
3636

37+
if TYPE_CHECKING:
38+
import numpy as np
39+
import numpy.typing as npt
40+
3741
logger = init_logger(__name__)
3842

3943

@@ -608,11 +612,8 @@ def schedule(self) -> SchedulerOutput:
608612
scheduled_spec_decode_tokens,
609613
req_to_new_blocks,
610614
)
611-
scheduled_requests = (
612-
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
613-
)
614615
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
615-
scheduled_requests, scheduled_spec_decode_tokens
616+
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
616617
)
617618
scheduler_output = SchedulerOutput(
618619
scheduled_new_reqs=new_reqs_data,
@@ -876,32 +877,28 @@ def _try_schedule_encoder_inputs(
876877

877878
def get_grammar_bitmask(
878879
self,
879-
requests: list[Request],
880+
scheduled_request_ids: Iterable[str],
880881
scheduled_spec_decode_tokens: dict[str, list[int]],
881-
):
882-
# NOTE: structured_output_request_ids maps
883-
# a request's (request that uses structured output)
884-
# request_id to its index in the batch.
885-
# This will help us determine to slice the grammar bitmask
886-
# and only applies valid mask for requests that
887-
# uses structured decoding.
888-
structured_output_request_ids: dict[str, int] = {}
889-
for i, req in enumerate(requests):
890-
if req.use_structured_output:
891-
# PERF: in case of chunked prefill,
892-
# request might not include any new tokens.
893-
# Therefore, we might introduce some additional
894-
# cycle to fill in the bitmask, which could be a big no-op.
895-
structured_output_request_ids[req.request_id] = i
896-
882+
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
883+
# Collect list of scheduled request ids that use structured output.
884+
# The corresponding rows of the bitmask will be in this order.
885+
# PERF: in case of chunked prefill,
886+
# request might not include any new tokens.
887+
# Therefore, we might introduce some additional
888+
# cycle to fill in the bitmask, which could be a big no-op.
889+
structured_output_request_ids = [
890+
req_id
891+
for req_id in scheduled_request_ids
892+
if (req := self.requests.get(req_id)) and req.use_structured_output
893+
]
897894
if not structured_output_request_ids:
898-
bitmask = None
899-
else:
900-
bitmask = self.structured_output_manager.grammar_bitmask(
901-
self.requests,
902-
structured_output_request_ids,
903-
scheduled_spec_decode_tokens,
904-
)
895+
return structured_output_request_ids, None
896+
897+
bitmask = self.structured_output_manager.grammar_bitmask(
898+
self.requests,
899+
structured_output_request_ids,
900+
scheduled_spec_decode_tokens,
901+
)
905902
return structured_output_request_ids, bitmask
906903

907904
def update_from_output(
@@ -1013,12 +1010,10 @@ def update_from_output(
10131010
new_logprobs = logprobs.slice(req_index, req_index + 1)
10141011

10151012
if new_token_ids and self.structured_output_manager.should_advance(request):
1016-
# NOTE: structured_output_request
1017-
# should not be None if use_structured_output, we have
1018-
# checked above, so safe to ignore type warning
1019-
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
1020-
req_id, new_token_ids
1021-
)
1013+
struct_output_request = request.structured_output_request
1014+
assert struct_output_request is not None
1015+
assert struct_output_request.grammar is not None
1016+
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
10221017

10231018
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
10241019
request.num_nans_in_logits = num_nans_in_logits[req_id]

0 commit comments

Comments
 (0)