diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index bd1dd01f9063..63604a335d9f 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import multiprocessing
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
@@ -40,6 +40,17 @@ def __init__(self, vllm_config: VllmConfig):
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
+        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+        self.fill_bitmask_parallel_threshold = 128
+        if self.fill_bitmask_parallel_threshold < max_batch_size:
+            self.fill_bitmask_parallel_batch_size = 16
+            # Use:
+            # - at least 1 CPU
+            # - at most half the number of CPUs or 8, whichever is less
+            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
+            self.executor_for_fillmask = ThreadPoolExecutor(
+                max_workers=max_workers)
+
         if not self.vllm_config.model_config.skip_tokenizer_init:
             # The default max_workers if not specified is the number of
             # CPUs * 5, which is way too high since these tasks are CPU-bound,
@@ -120,6 +131,26 @@ def _async_create_grammar(
         assert self.backend is not None
         return self.backend.compile_grammar(request_type, grammar_spec)
 
+    def _fill_bitmasks(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> None:
+        assert self._grammar_bitmask is not None
+        for grammar, index, apply_bitmask in batch:
+            if apply_bitmask and not grammar.is_terminated():
+                grammar.fill_bitmask(self._grammar_bitmask, index)
+            else:
+                # Note that for thinking support, we will need to
+                # reset the relevant part of the bitmask for subsequent
+                # requests here.
+                self._grammar_bitmask[index].fill_(self._full_mask)
+
+    def _async_submit_fill_bitmask(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> Future:
+        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
+
     def grammar_bitmask(
         self,
         requests: dict[str, Request],
@@ -146,7 +177,6 @@ def grammar_bitmask(
             self.backend.allocate_token_bitmask(
                 max_batch_size * (1 + max_num_spec_tokens))
 
-        bitmask_tensor = self._grammar_bitmask
         # Generate a batched bitmask for all structured output requests.
         # When speculative decoding is enabled, we need to include multiple
         # masks for each request, one for each possible bonus token position.
@@ -155,47 +185,61 @@ def grammar_bitmask(
         ordered_seq = sorted(structured_output_request_ids.items(),
                              key=lambda x: x[1])
 
-        # Note that for thinking support, we will need to
-        # reset the relevant part of the bitmask for consequent
-        # request here.
-        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
-            self._full_mask)
-
-        # NOTE: This outer loop can likely be parallelized to improve
-        # performance of bitmask generation for large batches.
-        for req_id, _ in ordered_seq:
-            request = requests[req_id]
-            structured_output_request = request.structured_output_request
-
-            if TYPE_CHECKING:
-                assert structured_output_request is not None
-                assert structured_output_request.grammar is not None
-            apply_bitmask: bool = True
-            if self.reasoner is not None:
-                if structured_output_request.reasoning_ended is None:
-                    structured_output_request.reasoning_ended = \
-                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
-                apply_bitmask = structured_output_request.reasoning_ended
-
-            state_advancements = 0
-            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
-            for i, token in enumerate(req_tokens):
-                if apply_bitmask and not \
-                        structured_output_request.grammar.is_terminated():
-                    structured_output_request.grammar.fill_bitmask(
-                        bitmask_tensor, cumulative_index)
-                if token is not None:
-                    # In order to generate the correct bitmask for each
-                    # position in the speculative sequence, we advance
-                    # the FSM state for each speculative token and rollback
-                    # to restore the previous state when we are finished.
+        # Optimized parallel filling of bitmasks for
+        # non-spec, large-batch-size cases
+        if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
+                max_num_spec_tokens == 0:
+            promises = []
+            batch = []
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+
+                apply_bitmask = self.should_fill_bitmask(request)
+                batch.append((structured_output_request.grammar,
+                              cumulative_index, apply_bitmask))
+                if len(batch) == self.fill_bitmask_parallel_batch_size:
+                    promises.append(self._async_submit_fill_bitmask(batch))
+                    batch = []
+
+                cumulative_index += 1
+            if batch:
+                promises.append(self._async_submit_fill_bitmask(batch))
+
+            # Wait for all bitmask filling tasks to complete.
+            for promise in promises:
+                promise.result()
+        else:
+            # Fall back to serial filling of bitmasks for small-batch-size cases
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+                apply_bitmask = self.should_fill_bitmask(request)
+
+                state_advancements = 0
+                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
+                for i, token in enumerate(req_tokens + [None]):
+                    self._fill_bitmasks([(structured_output_request.grammar,
+                                          cumulative_index, apply_bitmask)])
+
+                    if apply_bitmask and token is not None and \
+                            not structured_output_request.grammar.is_terminated():
                         assert structured_output_request.grammar.accept_tokens(
                             req_id, [token])
                         state_advancements += 1
-                cumulative_index += 1
-            if state_advancements > 0:
-                structured_output_request.grammar.rollback(state_advancements)
+                    cumulative_index += 1
+                if state_advancements > 0:
+                    structured_output_request.grammar.rollback(
+                        state_advancements)
 
+        bitmask_tensor = self._grammar_bitmask
         if cumulative_index < bitmask_tensor.shape[0]:
             bitmask_tensor = bitmask_tensor[:cumulative_index]
@@ -204,6 +248,15 @@ def grammar_bitmask(
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()
 
+    def should_fill_bitmask(self, request: Request) -> bool:
+        if self.reasoner is not None:
+            assert request.structured_output_request is not None
+            if request.structured_output_request.reasoning_ended is None:
+                request.structured_output_request.reasoning_ended = \
+                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
+            return request.structured_output_request.reasoning_ended
+        return True
+
     def should_advance(self, request: Request) -> bool:
         if not request.use_structured_output:
             return False
diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py
index 88544565e544..5e00f6380416 100644
--- a/vllm/v1/structured_output/backend_xgrammar.py
+++ b/vllm/v1/structured_output/backend_xgrammar.py
@@ -148,6 +148,7 @@ class XgrammarGrammar(StructuredOutputGrammar):
                                           repr=False,
                                           hash=False,
                                           init=False)
+    _is_terminated: bool = field(default=False, repr=False, hash=False)
 
     def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
         """Accepts a list of tokens and advances the FSM.
@@ -155,6 +156,8 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
         Returns True if the FSM was advanced successfully.
         Returns False if the FSM failed to advance.
         """
+        if self._is_terminated:
+            return False
         for token in tokens:
             if not self.matcher.accept_token(token):
                 logger.error(
@@ -162,6 +165,7 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
                     "for tokens %s. Please file an issue.", request_id, token)
                 return False
             self.num_processed_tokens += 1
+        self._is_terminated = self.matcher.is_terminated()
         return True
 
     def validate_tokens(self, tokens: list[int]) -> list[int]:
@@ -184,12 +188,13 @@ def validate_tokens(self, tokens: list[int]) -> list[int]:
     def rollback(self, num_tokens: int) -> None:
         self.matcher.rollback(num_tokens)
         self.num_processed_tokens -= num_tokens
+        self._is_terminated = self.matcher.is_terminated()
 
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(bitmask, idx)
 
     def is_terminated(self) -> bool:
-        return self.matcher.is_terminated()
+        return self._is_terminated
 
     def reset(self):
         self.num_processed_tokens = 0
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 84ad582c9c9d..89774bc661b7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1278,9 +1278,14 @@ def apply_grammar_bitmask(
             cumulative_index += 1 + num_spec_tokens
         grammar_bitmask = sorted_bitmask
 
+        # If the grammar bitmask and the logits have the same shape
+        # we don't need to pass indices to the kernel,
+        # since the bitmask is already aligned with the logits.
+        skip_out_indices = grammar_bitmask.shape[0] == logits.shape[0]
+
         # Serialization of np.ndarray is much more efficient than a tensor,
         # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask)
+        grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
 
         # Force use of the torch.compile implementation from xgrammar to work
         # around issues with the Triton kernel in concurrent structured output
@@ -1288,7 +1293,7 @@ def apply_grammar_bitmask(
         xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
             logits,
             grammar_bitmask.to(self.device, non_blocking=True),
-            indices=out_indices,
+            indices=out_indices if not skip_out_indices else None,
         )
 
     def sync_and_slice_intermediate_tensors(
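
Note: below is a minimal, self-contained sketch of the chunked ThreadPoolExecutor pattern the first hunk introduces, useful for reasoning about the threshold and batch-size constants in isolation. The names (`make_executor`, `fill_rows`, `fill_bitmask`) and the list-of-lists bitmask are hypothetical stand-ins for vLLM's `_fill_bitmasks` and the shared torch tensor; the constants mirror the values hardcoded in the diff (threshold 128, chunks of 16, at most min(cpu_count // 2, 8) workers). Each chunk owns a disjoint set of row indices, so the concurrent writes into the shared buffer need no locking; threads (rather than processes) only pay off if the underlying fill call releases the GIL, as native XGrammar bindings are expected to do.

```python
from concurrent.futures import Future, ThreadPoolExecutor
import multiprocessing

# Constants mirroring the hardcoded values in the diff.
FILL_PARALLEL_THRESHOLD = 128
FILL_PARALLEL_BATCH_SIZE = 16


def make_executor() -> ThreadPoolExecutor:
    # Same sizing rule as the diff: at least 1 worker,
    # at most min(cpu_count // 2, 8).
    max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
    return ThreadPoolExecutor(max_workers=max_workers)


def fill_rows(bitmask: list[list[int]],
              batch: list[tuple[int, bool]]) -> None:
    # Serial helper, analogous to _fill_bitmasks: each work item either
    # fills its row (here: a dummy pattern) or resets it to the full mask.
    for index, apply_bitmask in batch:
        if apply_bitmask:
            bitmask[index] = [index] * len(bitmask[index])  # fake "fill"
        else:
            bitmask[index] = [-1] * len(bitmask[index])     # full mask


def fill_bitmask(bitmask: list[list[int]],
                 work: list[tuple[int, bool]],
                 executor: ThreadPoolExecutor) -> None:
    if len(work) <= FILL_PARALLEL_THRESHOLD:
        # Small batch: thread dispatch overhead dominates, stay serial.
        fill_rows(bitmask, work)
        return
    # Large batch: split the work into fixed-size chunks and fan out.
    promises: list[Future] = []
    for start in range(0, len(work), FILL_PARALLEL_BATCH_SIZE):
        chunk = work[start:start + FILL_PARALLEL_BATCH_SIZE]
        promises.append(executor.submit(fill_rows, bitmask, chunk))
    # Wait for completion; result() also propagates worker exceptions.
    for promise in promises:
        promise.result()


if __name__ == "__main__":
    executor = make_executor()
    n = 256  # above the threshold, so the parallel path is taken
    bitmask = [[0] * 4 for _ in range(n)]
    fill_bitmask(bitmask, [(i, i % 2 == 0) for i in range(n)], executor)
    assert bitmask[0] == [0, 0, 0, 0] and bitmask[1] == [-1, -1, -1, -1]
    executor.shutdown()
```

The same shape also explains why the diff restricts the parallel path to `max_num_spec_tokens == 0`: without speculative tokens there is exactly one bitmask row per request and no FSM advance/rollback between rows, so every row can be filled independently.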