| 
5 | 5 | import time  | 
6 | 6 | from collections import defaultdict  | 
7 | 7 | from collections.abc import Iterable  | 
8 |  | -from typing import Any  | 
 | 8 | +from typing import TYPE_CHECKING, Any  | 
9 | 9 | 
 
  | 
10 | 10 | from vllm.config import VllmConfig  | 
11 | 11 | from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch  | 
 | 
34 | 34 | from vllm.v1.spec_decode.metrics import SpecDecodingStats  | 
35 | 35 | from vllm.v1.structured_output import StructuredOutputManager  | 
36 | 36 | 
 
  | 
 | 37 | +if TYPE_CHECKING:  | 
 | 38 | +    import numpy as np  | 
 | 39 | +    import numpy.typing as npt  | 
 | 40 | + | 
37 | 41 | logger = init_logger(__name__)  | 
38 | 42 | 
 
  | 
39 | 43 | 
 
  | 
@@ -608,11 +612,8 @@ def schedule(self) -> SchedulerOutput:  | 
608 | 612 |             scheduled_spec_decode_tokens,  | 
609 | 613 |             req_to_new_blocks,  | 
610 | 614 |         )  | 
611 |  | -        scheduled_requests = (  | 
612 |  | -            scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs  | 
613 |  | -        )  | 
614 | 615 |         structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(  | 
615 |  | -            scheduled_requests, scheduled_spec_decode_tokens  | 
 | 616 | +            num_scheduled_tokens.keys(), scheduled_spec_decode_tokens  | 
616 | 617 |         )  | 
617 | 618 |         scheduler_output = SchedulerOutput(  | 
618 | 619 |             scheduled_new_reqs=new_reqs_data,  | 
@@ -876,32 +877,28 @@ def _try_schedule_encoder_inputs(  | 
876 | 877 | 
 
  | 
877 | 878 |     def get_grammar_bitmask(  | 
878 | 879 |         self,  | 
879 |  | -        requests: list[Request],  | 
 | 880 | +        scheduled_request_ids: Iterable[str],  | 
880 | 881 |         scheduled_spec_decode_tokens: dict[str, list[int]],  | 
881 |  | -    ):  | 
882 |  | -        # NOTE: structured_output_request_ids maps  | 
883 |  | -        # a request's (request that uses structured output)  | 
884 |  | -        # request_id to its index in the batch.  | 
885 |  | -        # This will help us determine to slice the grammar bitmask  | 
886 |  | -        # and only applies valid mask for requests that  | 
887 |  | -        # uses structured decoding.  | 
888 |  | -        structured_output_request_ids: dict[str, int] = {}  | 
889 |  | -        for i, req in enumerate(requests):  | 
890 |  | -            if req.use_structured_output:  | 
891 |  | -                # PERF: in case of chunked prefill,  | 
892 |  | -                # request might not include any new tokens.  | 
893 |  | -                # Therefore, we might introduce some additional  | 
894 |  | -                # cycle to fill in the bitmask, which could be a big no-op.  | 
895 |  | -                structured_output_request_ids[req.request_id] = i  | 
896 |  | - | 
 | 882 | +    ) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:  | 
 | 883 | +        # Collect list of scheduled request ids that use structured output.  | 
 | 884 | +        # The corresponding rows of the bitmask will be in this order.  | 
 | 885 | +        # PERF: in case of chunked prefill,  | 
 | 886 | +        # a request might not include any new tokens.  | 
 | 887 | +        # Therefore, we might spend some additional  | 
 | 888 | +        # cycles filling in the bitmask, which could be a big no-op.  | 
 | 889 | +        structured_output_request_ids = [  | 
 | 890 | +            req_id  | 
 | 891 | +            for req_id in scheduled_request_ids  | 
 | 892 | +            if (req := self.requests.get(req_id)) and req.use_structured_output  | 
 | 893 | +        ]  | 
897 | 894 |         if not structured_output_request_ids:  | 
898 |  | -            bitmask = None  | 
899 |  | -        else:  | 
900 |  | -            bitmask = self.structured_output_manager.grammar_bitmask(  | 
901 |  | -                self.requests,  | 
902 |  | -                structured_output_request_ids,  | 
903 |  | -                scheduled_spec_decode_tokens,  | 
904 |  | -            )  | 
 | 895 | +            return structured_output_request_ids, None  | 
 | 896 | + | 
 | 897 | +        bitmask = self.structured_output_manager.grammar_bitmask(  | 
 | 898 | +            self.requests,  | 
 | 899 | +            structured_output_request_ids,  | 
 | 900 | +            scheduled_spec_decode_tokens,  | 
 | 901 | +        )  | 
905 | 902 |         return structured_output_request_ids, bitmask  | 
906 | 903 | 
 
  | 
907 | 904 |     def update_from_output(  | 
@@ -1013,12 +1010,10 @@ def update_from_output(  | 
1013 | 1010 |                 new_logprobs = logprobs.slice(req_index, req_index + 1)  | 
1014 | 1011 | 
 
  | 
1015 | 1012 |             if new_token_ids and self.structured_output_manager.should_advance(request):  | 
1016 |  | -                # NOTE: structured_output_request  | 
1017 |  | -                # should not be None if use_structured_output, we have  | 
1018 |  | -                # checked above, so safe to ignore type warning  | 
1019 |  | -                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]  | 
1020 |  | -                    req_id, new_token_ids  | 
1021 |  | -                )  | 
 | 1013 | +                struct_output_request = request.structured_output_request  | 
 | 1014 | +                assert struct_output_request is not None  | 
 | 1015 | +                assert struct_output_request.grammar is not None  | 
 | 1016 | +                struct_output_request.grammar.accept_tokens(req_id, new_token_ids)  | 
1022 | 1017 | 
 
  | 
1023 | 1018 |             if num_nans_in_logits is not None and req_id in num_nans_in_logits:  | 
1024 | 1019 |                 request.num_nans_in_logits = num_nans_in_logits[req_id]  | 
 | 
0 commit comments