Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest

from .utils import EOS_TOKEN_ID, create_requests, create_scheduler

Expand Down Expand Up @@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [],
requests[1].request_id: [10],
},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [10, 42],
requests[1].request_id: [13],
},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [10, 11],
requests[1].request_id: [],
},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
structured_output_request=StructuredOutputRequest(sampling_params),
)
scheduler.add_request(request)
output = scheduler.schedule()
Expand Down
2 changes: 1 addition & 1 deletion tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)
Expand Down
24 changes: 12 additions & 12 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand All @@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand All @@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down
24 changes: 12 additions & 12 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand All @@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down Expand Up @@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

Expand Down
5 changes: 2 additions & 3 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,8 @@
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]

# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# ids of structured outputs requests included in the bitmask, in order.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# ids of structured outputs requests included in the bitmask, in order.
# ids of structured outputs requests included in the bitmask.
# The index of the id in this list corresponds to the index into
# grammar_bitmask.

is this right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, not necessarily in the spec decoding case. The requests will be in the order of the list, but there may be more than one bitmask row per request, so the index in the list won't match the index into the bitmask. It would be true in the non-spec-decode case, though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right -- clarifying the order definition here would be nice in any case

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do this in a follow-on PR to avoid running the whole CI again!

structured_output_request_ids: list[str]

Check notice on line 169 in vllm/v1/core/sched/output.py

View workflow job for this annotation

GitHub Actions / bc_lint

Function SchedulerOutput: structured_output_request_ids changed from dict[str, int] to list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"

Expand Down
65 changes: 30 additions & 35 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import Any
from typing import TYPE_CHECKING, Any

from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
Expand Down Expand Up @@ -34,6 +34,10 @@
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager

if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt

logger = init_logger(__name__)


Expand Down Expand Up @@ -610,11 +614,8 @@ def schedule(self) -> SchedulerOutput:
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
scheduled_requests = (
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
)
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
scheduled_requests, scheduled_spec_decode_tokens
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
Expand Down Expand Up @@ -878,32 +879,28 @@ def _try_schedule_encoder_inputs(

def get_grammar_bitmask(
self,
requests: list[Request],
scheduled_request_ids: Iterable[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
):
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to its index in the batch.
# This will help us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
for i, req in enumerate(requests):
if req.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[req.request_id] = i

) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [
req_id
for req_id in scheduled_request_ids
if (req := self.requests.get(req_id)) and req.use_structured_output
]
if not structured_output_request_ids:
bitmask = None
else:
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, None

bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask

def update_from_output(
Expand Down Expand Up @@ -1011,12 +1008,10 @@ def update_from_output(
new_logprobs = logprobs.slice(req_index, req_index + 1)

if new_token_ids and self.structured_output_manager.should_advance(request):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# checked above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids
)
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)

if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
Expand Down
Loading
Loading