|
7 | 7 | from collections.abc import Iterable |
8 | 8 | from typing import Any |
9 | 9 |
|
| 10 | +import numpy as np |
| 11 | +from pandas._typing import npt |
| 12 | + |
10 | 13 | from vllm.config import VllmConfig |
11 | 14 | from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch |
12 | 15 | from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory |
@@ -610,11 +613,8 @@ def schedule(self) -> SchedulerOutput: |
610 | 613 | scheduled_spec_decode_tokens, |
611 | 614 | req_to_new_blocks, |
612 | 615 | ) |
613 | | - scheduled_requests = ( |
614 | | - scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs |
615 | | - ) |
616 | 616 | structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( |
617 | | - scheduled_requests, scheduled_spec_decode_tokens |
| 617 | + num_scheduled_tokens.keys(), scheduled_spec_decode_tokens |
618 | 618 | ) |
619 | 619 | scheduler_output = SchedulerOutput( |
620 | 620 | scheduled_new_reqs=new_reqs_data, |
@@ -878,32 +878,28 @@ def _try_schedule_encoder_inputs( |
878 | 878 |
|
879 | 879 | def get_grammar_bitmask( |
880 | 880 | self, |
881 | | - requests: list[Request], |
| 881 | + scheduled_request_ids: Iterable[str], |
882 | 882 | scheduled_spec_decode_tokens: dict[str, list[int]], |
883 | | - ): |
884 | | - # NOTE: structured_output_request_ids maps |
885 | | - # a request's (request that uses structured output) |
886 | | - # request_id to its index in the batch. |
887 | | - # This will help us determine to slice the grammar bitmask |
888 | | - # and only applies valid mask for requests that |
889 | | - # uses structured decoding. |
890 | | - structured_output_request_ids: dict[str, int] = {} |
891 | | - for i, req in enumerate(requests): |
892 | | - if req.use_structured_output: |
893 | | - # PERF: in case of chunked prefill, |
894 | | - # request might not include any new tokens. |
895 | | - # Therefore, we might introduce some additional |
896 | | - # cycle to fill in the bitmask, which could be a big no-op. |
897 | | - structured_output_request_ids[req.request_id] = i |
898 | | - |
| 883 | + ) -> tuple[list[str], npt.NDArray[np.int32] | None]: |
| 884 | + # Collect list of scheduled request ids that use structured output. |
| 885 | + # The corresponding rows of the bitmask will be in this order. |
| 886 | + # PERF: in case of chunked prefill, |
| 887 | + # request might not include any new tokens. |
| 888 | + # Therefore, we might introduce some additional |
| 889 | + # cycle to fill in the bitmask, which could be a big no-op. |
| 890 | + structured_output_request_ids = [ |
| 891 | + req_id |
| 892 | + for req_id in scheduled_request_ids |
| 893 | + if (req := self.requests.get(req_id)) and req.use_structured_output |
| 894 | + ] |
899 | 895 | if not structured_output_request_ids: |
900 | | - bitmask = None |
901 | | - else: |
902 | | - bitmask = self.structured_output_manager.grammar_bitmask( |
903 | | - self.requests, |
904 | | - structured_output_request_ids, |
905 | | - scheduled_spec_decode_tokens, |
906 | | - ) |
| 896 | + return structured_output_request_ids, None |
| 897 | + |
| 898 | + bitmask = self.structured_output_manager.grammar_bitmask( |
| 899 | + self.requests, |
| 900 | + structured_output_request_ids, |
| 901 | + scheduled_spec_decode_tokens, |
| 902 | + ) |
907 | 903 | return structured_output_request_ids, bitmask |
908 | 904 |
|
909 | 905 | def update_from_output( |
@@ -1011,12 +1007,10 @@ def update_from_output( |
1011 | 1007 | new_logprobs = logprobs.slice(req_index, req_index + 1) |
1012 | 1008 |
|
1013 | 1009 | if new_token_ids and self.structured_output_manager.should_advance(request): |
1014 | | - # NOTE: structured_output_request |
1015 | | - # should not be None if use_structured_output, we have |
1016 | | - # checked above, so safe to ignore type warning |
1017 | | - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] |
1018 | | - req_id, new_token_ids |
1019 | | - ) |
| 1010 | + struct_output_request = request.structured_output_request |
| 1011 | + assert struct_output_request is not None |
| 1012 | + assert struct_output_request.grammar is not None |
| 1013 | + struct_output_request.grammar.accept_tokens(req_id, new_token_ids) |
1020 | 1014 |
|
1021 | 1015 | if num_nans_in_logits is not None and req_id in num_nans_in_logits: |
1022 | 1016 | request.num_nans_in_logits = num_nans_in_logits[req_id] |
|
0 commit comments