From 8d757389b6d12abbb67a3df2f74d7da4f65c45a1 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Thu, 18 Sep 2025 13:50:18 +0400 Subject: [PATCH 01/14] async guided decoding Signed-off-by: Vadim Gimpelson --- vllm/v1/core/sched/output.py | 5 +- vllm/v1/core/sched/scheduler.py | 34 +- vllm/v1/engine/core.py | 12 +- vllm/v1/structured_output/__init__.py | 769 +++++++++++++++++++++++--- vllm/v1/structured_output/request.py | 53 +- vllm/v1/worker/gpu_model_runner.py | 7 +- vllm/v1/worker/tpu_model_runner.py | 7 +- 7 files changed, 737 insertions(+), 150 deletions(-) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 3ec5b91bf286..b532dd0c8217 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -7,10 +7,9 @@ from typing import TYPE_CHECKING, Optional from vllm._bc_linter import bc_linter_include +from vllm.v1.structured_output import GrammarBitmaskPlaceholder if TYPE_CHECKING: - import numpy as np - import numpy.typing as npt from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata) @@ -148,7 +147,7 @@ class SchedulerOutput: # for filling the next token bitmask structured_output_request_ids: dict[str, int] # the bitmask for the whole batch - grammar_bitmask: Optional[npt.NDArray[np.int32]] + grammar_bitmask: GrammarBitmaskPlaceholder # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 85ca858ad7bd..be8a3303b179 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -354,7 +354,8 @@ def schedule(self) -> SchedulerOutput: # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: structured_output_req = request.structured_output_request - if structured_output_req and structured_output_req.grammar: + if structured_output_req and \ + structured_output_req.is_grammar_ready: request.status = RequestStatus.WAITING else: self.waiting.pop_request() @@ -851,7 +852,9 @@ def get_grammar_bitmask( if not structured_output_request_ids: bitmask = None else: - bitmask = self.structured_output_manager.grammar_bitmask( + # Submit async grammar bitmask computation, return the placeholder + # The actual result will be retrieved later in gpu_model_runner + bitmask = self.structured_output_manager.submit_grammar_bitmask( self.requests, structured_output_request_ids, scheduled_spec_decode_tokens, @@ -869,6 +872,7 @@ def update_from_output( num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits + structured_list = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None @@ -942,11 +946,7 @@ def update_from_output( if new_token_ids and self.structured_output_manager.should_advance( request): - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # checked above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + structured_list.append((req_id, new_token_ids)) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] @@ -982,6 +982,9 @@ def update_from_output( # This is a rare case and unlikely to impact performance. 
self.waiting.remove_requests(stopped_preempted_reqs) + self.structured_output_manager.submit_batch_accept_tokens( + structured_list) + # KV Connector: update state for finished KV Transfers. if model_runner_output.kv_connector_output: self._update_from_kv_xfer_finished( @@ -1066,6 +1069,8 @@ def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: + spec_structured_dict = {} + for req_id, spec_token_ids in zip( draft_token_ids.req_ids, draft_token_ids.draft_token_ids, @@ -1080,12 +1085,21 @@ def update_draft_token_ids( # NOTE(woosuk): request.spec_token_ids should be updated. request.spec_token_ids.clear() elif self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) + spec_structured_dict[req_id] = spec_token_ids else: request.spec_token_ids = spec_token_ids + # Batch validate tokens for structured output requests + spec_structured_dict = ( + self.structured_output_manager.submit_batch_validate_tokens( + spec_structured_dict)) + + # Update requests with validated tokens + for req_id in spec_structured_dict: + request = self.requests.get(req_id) + if request is not None: + request.spec_token_ids = spec_structured_dict[req_id] + def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" return len(self.running), len(self.waiting) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a022e9c0d705..add0c892437d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -356,7 +356,7 @@ def step_with_batch_queue( return engine_core_outputs, model_executed def shutdown(self): - self.structured_output_manager.clear_backend() + self.structured_output_manager.shutdown() if self.model_executor: self.model_executor.shutdown() if self.scheduler: @@ -446,11 +446,11 @@ def preprocess_add_request( self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. - # `grammar_init` is only invoked in input processing thread. For - # `structured_output_manager`, each request is independent and - # grammar compilation is async. Scheduler always checks grammar - # compilation status before scheduling request. - self.structured_output_manager.grammar_init(req) + # `submit_grammar_init` is only invoked in input processing thread. + # For `structured_output_manager`, each request is independent + # and grammar compilation is async. Scheduler always checks + # grammar compilation status before scheduling request. 
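+            # The child signals completion through
+            # grammar_init_notification_queue; the scheduler observes it via
+            # the FutureGrammar that submit_grammar_init attaches to the
+            # request (checked through is_grammar_ready).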
+ self.structured_output_manager.submit_grammar_init(req) return req, request.current_wave diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 1ab29dfecd9e..4d26e6998780 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -3,21 +3,32 @@ from __future__ import annotations import multiprocessing +import queue from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Optional +from dataclasses import dataclass +from enum import Enum +from functools import partial +from multiprocessing import shared_memory +from typing import TYPE_CHECKING, Any, Optional +from weakref import finalize + +import numpy as np from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs -from vllm.utils import LazyLoader +from vllm.utils import LazyLoader, get_mp_context from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar) from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend +from vllm.v1.structured_output.request import FutureGrammar + +GRAMMAR_BITMASK_SHM_NAME = "vllm_grammar_bitmask_shm" +GRAMMAR_READY_FLAG_SHM_NAME = "vllm_grammar_ready_flag" if TYPE_CHECKING: - import numpy as np import numpy.typing as npt import torch @@ -27,18 +38,613 @@ torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) +""" +## Main Classes + +This module contains 3 main classes for structured output processing: + +### 1. StructuredOutputManager +**Process**: Main +**Purpose**: Queue-based manager that coordinates structured output operations +**Responsibilities**: +- Submit tasks via multiprocessing queues +- Retrieve results from child process +- Create and manage single child process + +### 2. StructuredOutputGateway +**Process**: Child +**Purpose**: Background process that receives and executes tasks +**Responsibilities**: +- Receive tasks from task queue +- Execute tasks using StructuredOutputExecutor +- Send results back via result queues +- Isolate heavy computation from main process + +### 3. StructuredOutputExecutor +**Process**: Child (inside gateway) +**Purpose**: Performs actual structured output operations +**Responsibilities**: +- Grammar compilation +- Bitmask generation +- Token acceptance +- All core structured output work + + +## Communication Architecture + +The structured output system uses a combination of multiprocessing queues and +shared memory for communication between the main process +(StructuredOutputManager) and the child process (StructuredOutputGateway): + +### 1. task_queue +**Direction**: Main → Child +**Purpose**: Sends task requests from manager to gateway + +### 2. bitmask_shared_memory + grammar_ready_flag +**Direction**: Child → Main +**Purpose**: Returns grammar bitmask results for token generation via shared +memory +**Usage**: Runs asynchronously with submit_grammar_bitmask() and return +GrammarBitmaskPlaceholder. When GrammarBitmaskPlaceholder.result() is called, +polls shared memory flag until bitmask is ready +**Content**: Bitmask numpy array written to shared memory with ready flag +signaling +**Implementation**: Uses `GRAMMAR_BITMASK_SHM_NAME` and +`GRAMMAR_READY_FLAG_SHM_NAME` + +### 3. 
batch_validate_result_queue +**Direction**: Child → Main +**Purpose**: Validated token sequences for speculative decoding +**Usage**: Synchronous - manager blocks until validation completes +**Content**: StructuredOutputResult with validated token dictionaries + +### 4. grammar_init_notification_queue +**Direction**: Child → Main +**Purpose**: Notifies main process when grammar initialization completes +**Usage**: Asynchronous - manager polls queue to update initialization +status +**Content**: StructuredOutputResult with completed request_id + +### Communication Flow +1. **Task Submission**: Manager creates StructuredOutputTask + and puts in task_queue +2. **Task Execution**: Gateway retrieves task, + executes via StructuredOutputExecutor +3. **Result Routing**: Gateway routes results to appropriate result queue + based on task type +4. **Result Retrieval**: Manager retrieves results (synchronously or polls + asynchronously) +""" + + +@dataclass +class GrammarInitData: + """ + Lightweight data structure containing only the necessary fields + from Request for grammar initialization. + """ + request_id: str + guided_decoding_backend: str + structured_output_key: tuple + + @classmethod + def from_request(cls, request: Request) -> GrammarInitData: + assert request.structured_output_request is not None + assert request.sampling_params is not None + return cls(request_id=request.request_id, + guided_decoding_backend=request.sampling_params. + guided_decoding.backend, + structured_output_key=request.structured_output_request. + structured_output_key) + + +class TaskType(Enum): + GRAMMAR_INIT = 1 + GRAMMAR_DELETE = 2 + GRAMMAR_BITMASK = 3 + BATCH_ACCEPT_TOKENS = 4 + BATCH_VALIDATE_TOKENS = 5 + CLEAR_BACKEND = 6 + SHUTDOWN = 7 + + +class StructuredOutputTask: + + def __init__(self, task_type: TaskType, args: tuple, kwargs: dict): + self.task_type = task_type + self.args = args + self.kwargs = kwargs + + +class StructuredOutputResult: + + def __init__(self, + task_type: TaskType, + result: Any, + error: Optional[Exception] = None): + self.task_type = task_type + self.result = result + self.error = error + + +class StructuredOutputGateway: + """ + Runs on single CHILD process (created by StructuredOutputManager). + Background process that receives tasks from queue, executes them using + StructuredOutputExecutor, and sends results back via queues. Isolates heavy + computation from the main process. 
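+
+    Result routing (a summary of the dispatch in run() below):
+      - GRAMMAR_INIT          -> grammar_init_notification_queue
+      - GRAMMAR_BITMASK       -> shared-memory bitmask segment + ready flag
+      - BATCH_VALIDATE_TOKENS -> batch_validate_result_queue
+      - GRAMMAR_DELETE, BATCH_ACCEPT_TOKENS and CLEAR_BACKEND are
+        fire-and-forget; no result is sent back.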
+ """ + + def __init__(self, task_queue, batch_validate_result_queue, + grammar_init_notification_queue, vllm_config: VllmConfig): + self.task_queue = task_queue + self.batch_validate_result_queue = batch_validate_result_queue + self.grammar_init_notification_queue = grammar_init_notification_queue + self.vllm_config = vllm_config + self.structured_output_executor: Optional[ + StructuredOutputExecutor] = None + + @staticmethod + def run_gateway(task_queue, batch_validate_result_queue, + grammar_init_notification_queue, vllm_config: VllmConfig): + """Static method to run the gateway in a separate process.""" + gateway = StructuredOutputGateway(task_queue, + batch_validate_result_queue, + grammar_init_notification_queue, + vllm_config) + gateway.run() + + def run(self): + """Main processing loop for the child process.""" + logger.debug("StructuredOutputGateway starting - PID: %s", + multiprocessing.current_process().pid) + self.structured_output_executor = StructuredOutputExecutor( + self.vllm_config) + + # Attach to shared memory in child process + self.bitmask_shm = shared_memory.SharedMemory( + name=GRAMMAR_BITMASK_SHM_NAME) + + while True: + try: + task = self.task_queue.get() + if task.task_type == TaskType.SHUTDOWN: + logger.debug("StructuredOutputGateway shutting down") + self.bitmask_shm.close() + break + result = self._execute_task(task) + # Only put result in queue if it's needed + if task.task_type == TaskType.GRAMMAR_INIT: + # Notify main process that grammar init is complete + self.grammar_init_notification_queue.put(result) + elif task.task_type == TaskType.GRAMMAR_BITMASK: + # Write bitmask to shared memory and signal flag + if result.result is not None and result.error is None: + self._write_bitmask_to_shared_memory(result.result) + else: + # Set the flag even on error so result() doesn't hang + flag_shm = shared_memory.SharedMemory( + name=GRAMMAR_READY_FLAG_SHM_NAME) + flag_shm.buf[0] = 1 + flag_shm.close() + elif task.task_type == TaskType.BATCH_VALIDATE_TOKENS: + self.batch_validate_result_queue.put(result) + except Exception as e: + logger.debug("Error in StructuredOutputGateway: %s", e) + if task.task_type == TaskType.GRAMMAR_INIT: + error_result = StructuredOutputResult( + task.task_type, None, e) + self.grammar_init_notification_queue.put(error_result) + elif task.task_type == TaskType.GRAMMAR_BITMASK: + # Signal flag even on error so result() doesn't hang + flag_shm = shared_memory.SharedMemory( + name=GRAMMAR_READY_FLAG_SHM_NAME) + flag_shm.buf[0] = 1 + flag_shm.close() + elif task.task_type == TaskType.BATCH_VALIDATE_TOKENS: + error_result = StructuredOutputResult( + task.task_type, None, e) + self.batch_validate_result_queue.put(error_result) + + def _execute_task(self, + task: StructuredOutputTask) -> StructuredOutputResult: + assert self.structured_output_executor is not None + try: + if task.task_type == TaskType.GRAMMAR_INIT: + self.structured_output_executor.grammar_init( + *task.args, **task.kwargs) + # Return the request_id so Gateway can notify main process + grammar_init_data = task.args[0] + return StructuredOutputResult(task.task_type, + grammar_init_data.request_id) + elif task.task_type == TaskType.GRAMMAR_DELETE: + self.structured_output_executor.grammar_delete( + *task.args, **task.kwargs) + return StructuredOutputResult(task.task_type, None) + elif task.task_type == TaskType.GRAMMAR_BITMASK: + result = self.structured_output_executor.grammar_bitmask( + *task.args, **task.kwargs) + return StructuredOutputResult(task.task_type, result) + elif 
task.task_type == TaskType.BATCH_ACCEPT_TOKENS: + self.structured_output_executor.batch_accept_tokens( + *task.args, **task.kwargs) + return StructuredOutputResult(task.task_type, None) + elif task.task_type == TaskType.BATCH_VALIDATE_TOKENS: + result = self.structured_output_executor.batch_validate_tokens( + *task.args, **task.kwargs) + return StructuredOutputResult(task.task_type, result) + elif task.task_type == TaskType.CLEAR_BACKEND: + self.structured_output_executor.clear_backend() + return StructuredOutputResult(task.task_type, None) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + except Exception as e: + return StructuredOutputResult(task.task_type, None, e) + + def _write_bitmask_to_shared_memory(self, bitmask_array: np.ndarray): + """Write bitmask numpy array to shared memory.""" + # Write shape info first (2 int32 values) + shape_info = np.array(bitmask_array.shape, dtype=np.int32) + self.bitmask_shm.buf[:8] = shape_info.tobytes() + + # Write actual data + data_bytes = bitmask_array.tobytes() + self.bitmask_shm.buf[8:8 + len(data_bytes)] = data_bytes + + # Ensure all writes are visible before setting the flag + # Python's SharedMemory is backed by mmap, which provides coherent + # memory access across processes. On most architectures, the memory + # model ensures that when process B opens a SharedMemory object and + # sees flag=1, all previous writes by process A will be visible. + # This is because SharedMemory creation/attachment involves system + # calls that act as memory barriers. + + # Set the flag to indicate bitmask is ready + flag_shm = shared_memory.SharedMemory(name=GRAMMAR_READY_FLAG_SHM_NAME) + flag_shm.buf[0] = 1 + flag_shm.close() class StructuredOutputManager: - """Engine-level manager for structured output requests.""" + """ + Runs on MAIN process. Queue-based manager that coordinates structured + output operations by submitting tasks via multiprocessing queues and + retrieving results. Creates and manages a single child process. Methods + with `submit_` prefix are run on child process by passing the task via + queues. Other methods are executed on the current process. 
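+
+    Illustrative usage (a simplified sketch rather than the engine's exact
+    call sites; the methods are the ones defined on this class, while the
+    surrounding variables are assumed for the example):
+
+        manager = StructuredOutputManager(vllm_config)
+        manager.submit_grammar_init(request)       # child compiles grammar
+        ...
+        placeholder = manager.submit_grammar_bitmask(
+            requests, structured_output_request_ids,
+            scheduled_spec_decode_tokens)
+        # In the engine the placeholder travels in SchedulerOutput and the
+        # model runner calls result(), which blocks on the shared-memory
+        # ready flag.
+        bitmask = placeholder.result()
+        manager.submit_batch_accept_tokens([(req_id, new_token_ids)])
+        manager.shutdown()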
+ """ def __init__(self, vllm_config: VllmConfig): - self.backend: Optional[StructuredOutputBackend] = None + + self.vllm_config = vllm_config self.reasoner: Optional[ReasoningParser] = None + + if not self.vllm_config.model_config.skip_tokenizer_init: + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config, + ).get_lora_tokenizer(None) + + reasoning_backend = vllm_config.decoding_config.reasoning_backend + if reasoning_backend: + reasoner_cls = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + self.reasoner = reasoner_cls(tokenizer=self.tokenizer) + + # Set to track initialized grammars in main process + self.initialized_grammars: set[str] = set() + + # Start the child process using vLLM's multiprocessing context + mp_context = get_mp_context() + self.task_queue = mp_context.Queue() + self.batch_validate_result_queue = mp_context.Queue() + self.grammar_init_notification_queue = mp_context.Queue() + + # Create shared memory for bitmask results + # Calculate size based on max batch size and vocab size + max_batch_size = vllm_config.scheduler_config.max_num_seqs + max_spec_tokens = 0 + if vllm_config.speculative_config is not None: + max_spec_tokens = \ + vllm_config.speculative_config.num_speculative_tokens + vocab_size = vllm_config.model_config.get_vocab_size() + # Size: max_num_seqs * ((vocab_size + 31) // 32) * 4 bytes + # + 8 bytes for shape + max_num_seqs = max_batch_size * (1 + max_spec_tokens) + shm_size = max_num_seqs * ((vocab_size + 31) // 32) * 4 + 8 + + self._cleanup_existing_shared_memory() + self.bitmask_shm = shared_memory.SharedMemory( + create=True, size=shm_size, name=GRAMMAR_BITMASK_SHM_NAME) + + # Create shared memory flag for bitmask ready signaling + self.flag_shm = shared_memory.SharedMemory( + name=GRAMMAR_READY_FLAG_SHM_NAME, create=True, size=1) + self.flag_shm.buf[0] = 0 # Initialize flag to 0 (not ready) + + # Create a partial config with only the required fields + partial_config = VllmConfig( + scheduler_config=self.vllm_config.scheduler_config, + model_config=self.vllm_config.model_config, + speculative_config=self.vllm_config.speculative_config, + decoding_config=self.vllm_config.decoding_config, + lora_config=self.vllm_config.lora_config, + ) + + self.gateway_process = mp_context.Process( + target=StructuredOutputGateway.run_gateway, + name="StructuredOutputGateway", + args=(self.task_queue, self.batch_validate_result_queue, + self.grammar_init_notification_queue, partial_config), + daemon=True) + self.gateway_process.start() + + logger.debug( + "StructuredOutputManager started with child process PID: %s", + self.gateway_process.pid) + + def submit_grammar_bitmask(self, requests: dict[str, Request], + structured_output_request_ids: dict[str, int], + scheduled_spec_decode_tokens: dict[str, + list[int]]): + """Submit grammar_bitmask task asynchronously.""" + if not structured_output_request_ids: + return None + + # Clear the flag for new bitmask calculation + self.flag_shm.buf[0] = 0 + + self.update_reasoning_ended(requests, structured_output_request_ids) + req_reasoning_ended = {} + for request_id, _ in structured_output_request_ids.items(): + request = requests[request_id] + assert request.structured_output_request is not None + req_reasoning_ended[request_id] = ( + request.structured_output_request.reasoning_ended) + + task = StructuredOutputTask( + TaskType.GRAMMAR_BITMASK, + (structured_output_request_ids, 
req_reasoning_ended, + scheduled_spec_decode_tokens), {}) + self.task_queue.put(task) + + # Return a placeholder that consumer can check + return GrammarBitmaskPlaceholder() + + def submit_batch_accept_tokens( + self, request_id_to_new_token_ids: list[tuple[str, list[int]]]): + """Submit batch_accept_tokens task (fire-and-forget).""" + if len(request_id_to_new_token_ids) == 0: + return + task = StructuredOutputTask(TaskType.BATCH_ACCEPT_TOKENS, + (request_id_to_new_token_ids, ), {}) + self.task_queue.put(task) + + def submit_batch_validate_tokens( + self, + request_id_to_token_ids: dict[str, + list[int]]) -> dict[str, list[int]]: + """Validate tokens for multiple requests and return validated tokens.""" + if len(request_id_to_token_ids) == 0: + return {} + + task = StructuredOutputTask(TaskType.BATCH_VALIDATE_TOKENS, + (request_id_to_token_ids, ), {}) + self.task_queue.put(task) + result = self.batch_validate_result_queue.get() + if result.error: + raise Exception(f"Error in batch_validate_tokens: {result.error}") + return result.result + + def submit_grammar_init(self, request): + """Submit grammar_init task.""" + if request.structured_output_request is None: + return + + # Extract only the necessary fields from request + # to reduce data transfer overhead + grammar_init_data = GrammarInitData.from_request(request) + + task = StructuredOutputTask(TaskType.GRAMMAR_INIT, + (grammar_init_data, ), {}) + self.task_queue.put(task) + # Set up automatic cleanup when structured_output_request + # is garbage collected + finalize(request.structured_output_request, + partial(self.submit_grammar_delete, request.request_id)) + # Set the compiled_grammar AFTER putting the task in the queue + # to avoid pickling the callback + request.structured_output_request.compiled_grammar = FutureGrammar( + self._is_grammar_init_done, request.request_id) + + def _is_grammar_init_done(self, request_id: str) -> bool: + # Read all available notifications from the queue + # and add them to the set + while not self.grammar_init_notification_queue.empty(): + try: + result = self.grammar_init_notification_queue.get_nowait() + if result.error: + # Log error but don't add to initialized set + logger.debug("Error in grammar initialization: %s", + result.error) + else: + completed_request_id = result.result + self.initialized_grammars.add(completed_request_id) + except queue.Empty: + # Queue is empty + break + + # Check if this request_id is in our set of initialized grammars + return request_id in self.initialized_grammars + + def submit_grammar_delete(self, request_id: str): + """Submit grammar_delete task (fire-and-forget).""" + task = StructuredOutputTask(TaskType.GRAMMAR_DELETE, (request_id, ), + {}) + self.task_queue.put(task) + + # Remove from our set of initialized grammars + self.initialized_grammars.discard(request_id) + + def _submit_clear_backend(self): + """Submit clear_backend task (fire-and-forget).""" + task = StructuredOutputTask(TaskType.CLEAR_BACKEND, (), {}) + self.task_queue.put(task) + + def update_reasoning_ended(self, requests: dict[str, Request], + structured_output_request_ids: dict[str, int]): + """Update the reasoning_ended flag for the given requests.""" + if self.reasoner is not None: + for request_id, _ in structured_output_request_ids.items(): + request = requests[request_id] + structured_output = request.structured_output_request + assert structured_output is not None + if structured_output.reasoning_ended is None: + structured_output.reasoning_ended = \ + 
self.reasoner.is_reasoning_end(request.prompt_token_ids) + + def should_advance(self, request: Request) -> bool: + """Determine whether we can advance the FSM.""" + if not request.use_structured_output: + return False + + # To determine whether we can advance the FSM. + # Supports thinking usage where we skip the reasoning components. + if TYPE_CHECKING: + assert request.structured_output_request is not None + assert (request.structured_output_request.compiled_grammar + is not None) + # by default, we should always advance + # for cases that doesn't uses thinking mode. + if self.reasoner is not None: + structured_req = request.structured_output_request + + if structured_req.reasoning_ended: + return True + + # Check if reasoning ends in *this* step + if self.reasoner.is_reasoning_end(request.all_token_ids): + # Reasoning just ended, so we shouldn't advanced til + # next pass + structured_req.reasoning_ended = True + + return False + else: + return True + + def shutdown(self): + """Shutdown the manager and child process.""" + self._submit_clear_backend() + task = StructuredOutputTask(TaskType.SHUTDOWN, (), {}) + self.task_queue.put(task) + self.gateway_process.join(timeout=5) + if self.gateway_process.is_alive(): + logger.debug("Force terminating StructuredOutputGateway") + self.gateway_process.terminate() + self.gateway_process.join() + + # Clean up shared memory + if hasattr(self, 'bitmask_shm'): + self.bitmask_shm.close() + self.bitmask_shm.unlink() + + if hasattr(self, 'flag_shm'): + self.flag_shm.close() + self.flag_shm.unlink() + + def _cleanup_existing_shared_memory(self): + """Clean up any existing shared memory segments from previous runs. + + This handles cases where a previous process was killed abruptly + and didn't properly clean up its shared memory segments. + """ + # Try to unlink bitmask shared memory if it exists + try: + existing_bitmask_shm = shared_memory.SharedMemory( + name=GRAMMAR_BITMASK_SHM_NAME) + existing_bitmask_shm.close() + existing_bitmask_shm.unlink() + logger.info("Cleaned up existing bitmask shared memory from " + "previous run") + except FileNotFoundError: + # No existing shared memory, which is fine + pass + except Exception as e: + logger.warning("Error cleaning up bitmask shared memory: %s", e) + + # Try to unlink flag shared memory if it exists + try: + existing_flag_shm = shared_memory.SharedMemory( + name=GRAMMAR_READY_FLAG_SHM_NAME) + existing_flag_shm.close() + existing_flag_shm.unlink() + logger.info("Cleaned up existing flag shared memory from " + "previous run") + except FileNotFoundError: + # No existing shared memory, which is fine + pass + except Exception as e: + logger.warning("Error cleaning up flag shared memory: %s", e) + + +class GrammarBitmaskPlaceholder: + """ + Placeholder object that gpu_model_runner.py can check and get result from. + Uses shared memory flag to wait for result + and shared memory to retrieve bitmask data. 
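+
+    Shared-memory layout (mirroring what the gateway writes in
+    _write_bitmask_to_shared_memory, summarized here for reference):
+
+        bitmask segment: bytes [0:8)  two int32 values, the array shape
+                         bytes [8:..) row-major int32 bitmask data
+        flag segment:    a single byte, 0 = not ready, 1 = ready; the
+                         gateway sets it even on errors so result() never
+                         hangs.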
+ """ + + def __init__(self): + pass + + def result(self) -> np.ndarray: + import time + + import numpy as np + + # Poll the shared memory flag until it's set + flag_shm = shared_memory.SharedMemory(name=GRAMMAR_READY_FLAG_SHM_NAME) + while True: + flag_value = flag_shm.buf[0] + + if flag_value == 1: # Flag is set, bitmask is ready + break + + time.sleep(0.001) # Short sleep to avoid busy waiting + + flag_shm.close() + + bitmask_shm = shared_memory.SharedMemory(name=GRAMMAR_BITMASK_SHM_NAME) + # Read shape info first + shape_bytes = bytes(bitmask_shm.buf[:8]) # Create a copy of the bytes + shape = np.frombuffer(shape_bytes, dtype=np.int32).copy() + # Read bitmask data + data_size = shape[0] * shape[1] * 4 # int32 = 4 bytes + # Create a complete copy of the data before closing shared memory + data_bytes = bytes(bitmask_shm.buf[8:8 + data_size]) + bitmask_shm.close() + # Now create the numpy array from the copied bytes + bitmask = np.frombuffer(data_bytes, + dtype=np.int32).reshape(shape).copy() + return bitmask + + +class StructuredOutputExecutor: + """ + Runs on CHILD process (inside StructuredOutputGateway). + Executor that performs the actual structured output work + including grammar compilation, bitmask generation, and token acceptance. + All communication between processes happens via queues. + """ + + def __init__(self, vllm_config: VllmConfig): + self.backend: Optional[StructuredOutputBackend] = None self.vllm_config = vllm_config self._grammar_bitmask: Optional[torch.Tensor] = None self._full_mask = torch.tensor(-1, dtype=torch.int32) + self.request_id_to_grammar: dict[str, StructuredOutputGrammar] = {} max_batch_size = self.vllm_config.scheduler_config.max_num_seqs self.fill_bitmask_parallel_threshold = 128 @@ -52,37 +658,29 @@ def __init__(self, vllm_config: VllmConfig): max_workers=max_workers) if not self.vllm_config.model_config.skip_tokenizer_init: - # The default max_workers if not specified is the number of - # CPUs * 5, which is way too high since these tasks are CPU-bound, - # not I/O bound. We also know we would never dominate CPU usage - # with just grammar compilation, so we set it to half the number - # of CPUs. + # The default max_workers if not specified is the number of CPUs*5, + # which is way too high since these tasks are CPU-bound, + # not I/O bound. + # We also know we would never dominate CPU usage with just grammar + # compilation, so we set it to half the number of CPUs. max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) self.tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config) - reasoning_backend = \ - self.vllm_config.decoding_config.reasoning_backend - if reasoning_backend: - reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_backend) - self.reasoner = reasoner_cls(tokenizer=self.tokenizer) - - def grammar_init(self, request: Request) -> None: - if request.structured_output_request is None: - return + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) - if TYPE_CHECKING: - assert request.sampling_params is not None and \ - request.sampling_params.guided_decoding is not None + def _get_grammar(self, request_id: str) -> StructuredOutputGrammar: + return self.request_id_to_grammar[request_id] + def grammar_init(self, grammar_init_data: GrammarInitData) -> None: # Initialize the backend the first time it is needed. 
# # NOTE: We only support a single backend. We do NOT support different # backends on a per-request basis in V1 (for now, anyway...). if self.backend is None: - assert request.sampling_params is not None - backend = request.sampling_params.guided_decoding.backend + backend = grammar_init_data.guided_decoding_backend vocab_size = self.vllm_config.model_config.get_vocab_size() if backend == "xgrammar": self.backend = XgrammarBackend( @@ -117,14 +715,15 @@ def grammar_init(self, request: Request) -> None: raise ValueError( f"Unsupported structured output backend: {backend}") - grammar = self.executor.submit(self._async_create_grammar, request) - request.structured_output_request.grammar = grammar # type: ignore[assignment] + grammar = self.executor.submit(self._async_create_grammar, + grammar_init_data).result() + self.request_id_to_grammar[grammar_init_data.request_id] = grammar def _async_create_grammar( self, - request: Request, + grammar_init_data: GrammarInitData, ) -> StructuredOutputGrammar: - key = request.structured_output_request.structured_output_key # type: ignore[union-attr] + key = grammar_init_data.structured_output_key # Note that the request was validated in the engine core client, # so at this point we know it is a supported type of request. @@ -158,8 +757,8 @@ def _async_submit_fill_bitmask( def grammar_bitmask( self, - requests: dict[str, Request], structured_output_request_ids: dict[str, int], + req_reasoning_ended: dict[str, Optional[bool]], scheduled_spec_decode_tokens: dict[str, list[int]], ) -> Optional[npt.NDArray[np.int32]]: # Prepare the structured output bitmask for this batch. @@ -196,16 +795,12 @@ def grammar_bitmask( max_num_spec_tokens == 0: promises = [] batch = [] + for req_id, _ in ordered_seq: - request = requests[req_id] - structured_output_request = request.structured_output_request - if TYPE_CHECKING: - assert structured_output_request is not None - assert structured_output_request.grammar is not None - - apply_bitmask = self.should_fill_bitmask(request) - batch.append((structured_output_request.grammar, - cumulative_index, apply_bitmask)) + grammar = self._get_grammar(req_id) + apply_bitmask = self.should_fill_bitmask( + req_reasoning_ended, req_id) + batch.append((grammar, cumulative_index, apply_bitmask)) if len(batch) == self.fill_bitmask_parallel_batch_size: promises.append(self._async_submit_fill_bitmask(batch)) batch = [] @@ -220,74 +815,66 @@ def grammar_bitmask( else: # Fallback to serial filling of bitmasks for small-batch-size cases for req_id, _ in ordered_seq: - request = requests[req_id] - structured_output_request = request.structured_output_request - - if TYPE_CHECKING: - assert structured_output_request is not None - assert structured_output_request.grammar is not None - apply_bitmask = self.should_fill_bitmask(request) + grammar = self._get_grammar(req_id) + apply_bitmask = self.should_fill_bitmask( + req_reasoning_ended, req_id) state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) for i, token in enumerate(req_tokens + [None]): - self._fill_bitmasks([(structured_output_request.grammar, - cumulative_index, apply_bitmask)]) + self._fill_bitmasks([(grammar, cumulative_index, + apply_bitmask)]) if apply_bitmask and token is not None and \ - not structured_output_request.grammar.is_terminated(): - assert structured_output_request.grammar.accept_tokens( - req_id, [token]) + not grammar.is_terminated(): + assert grammar.accept_tokens(req_id, [token]) state_advancements += 1 cumulative_index += 1 if 
state_advancements > 0: - structured_output_request.grammar.rollback( - state_advancements) + grammar.rollback(state_advancements) bitmask_tensor = self._grammar_bitmask if cumulative_index < bitmask_tensor.shape[0]: bitmask_tensor = bitmask_tensor[:cumulative_index] - - # After finishing with the xgrammar operations, we convert to - # np.ndarray, because that is much more efficient for serialization - # and deserialization when sending this to the GPU workers. return bitmask_tensor.numpy() - def should_fill_bitmask(self, request: Request) -> bool: - if self.reasoner is not None: - assert request.structured_output_request is not None - if request.structured_output_request.reasoning_ended is None: - request.structured_output_request.reasoning_ended = \ - self.reasoner.is_reasoning_end(request.prompt_token_ids) - return request.structured_output_request.reasoning_ended + def should_fill_bitmask(self, req_reasoning_ended, req_id) -> bool: + if req_reasoning_ended[req_id] is not None: + return req_reasoning_ended[req_id] return True - def should_advance(self, request: Request) -> bool: - if not request.use_structured_output: - return False - - # To determine whether we can advance the FSM. - # Supports thinking usage where we skip the reasoning components. - if TYPE_CHECKING: - assert request.structured_output_request is not None - assert request.structured_output_request.grammar is not None - # by default, we should always advance - # for cases that don't use thinking mode. - if self.reasoner is not None: - structured_req = request.structured_output_request - - if structured_req.reasoning_ended: - return True - - # Check if reasoning ends in *this* step - if self.reasoner.is_reasoning_end(request.all_token_ids): - # Reasoning just ended, so we shouldn't advance til - # next pass - structured_req.reasoning_ended = True - - return False - else: - return True + def _accept_tokens(self, request_id: str, + new_token_ids: list[int]) -> None: + grammar = self._get_grammar(request_id) + grammar.accept_tokens(request_id, new_token_ids) + + def batch_accept_tokens( + self, request_id_to_new_token_ids: list[tuple[str, + list[int]]]) -> None: + for req_id, new_token_ids in request_id_to_new_token_ids: + self._accept_tokens(req_id, new_token_ids) + + def batch_validate_tokens( + self, + request_id_to_token_ids: dict[str, + list[int]]) -> dict[str, list[int]]: + """Validate tokens for multiple requests without advancing the FSM + state.""" + result = {} + for req_id, token_ids in request_id_to_token_ids.items(): + grammar = self._get_grammar(req_id) + validated_tokens = grammar.validate_tokens(token_ids) + result[req_id] = validated_tokens + return result + + def grammar_delete(self, request_id: str) -> None: + """Remove grammar for the given request_id to prevent memory leaks.""" + if request_id in self.request_id_to_grammar: + grammar = self.request_id_to_grammar[request_id] + # Reset the grammar state before deletion for clean cleanup + grammar.reset() + # Remove from dictionary to allow garbage collection + del self.request_id_to_grammar[request_id] def clear_backend(self) -> None: if self.backend is not None: diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index fc365f12573f..582c5634f819 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -5,53 +5,36 @@ import dataclasses import functools import json -from concurrent.futures import Future -from concurrent.futures._base import TimeoutError -from typing import 
Optional, Union, cast +from typing import Callable, Optional from vllm.sampling_params import SamplingParams -from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, - StructuredOutputKey, +from vllm.v1.structured_output.backend_types import (StructuredOutputKey, StructuredOutputOptions) +class FutureGrammar: + + def __init__(self, call_back: Callable[[str], bool], request_id: str): + self.call_back = call_back + self.request_id = request_id + + def done(self): + if self.call_back is None: + return False + return self.call_back(self.request_id) + + @dataclasses.dataclass class StructuredOutputRequest: - sampling_params: SamplingParams - _grammar: Optional[Union[Future[StructuredOutputGrammar], - StructuredOutputGrammar]] = None + compiled_grammar: Optional[FutureGrammar] = None reasoning_ended: Optional[bool] = None - def _check_grammar_completion(self) -> bool: - # NOTE: We have to lazy import to gate circular imports - from vllm.v1.request import RequestStatus - - if isinstance(self._grammar, Future): - try: - # We will check whether the future is ready within 100 us - self._grammar = self._grammar.result(timeout=0.0001) - self.status = RequestStatus.WAITING - except TimeoutError: - return False - return True - @property def is_grammar_ready(self) -> bool: - return self._check_grammar_completion() - - @property - def grammar(self) -> Optional[StructuredOutputGrammar]: - completed = self._check_grammar_completion() - return cast(Optional[StructuredOutputGrammar], - self._grammar) if completed else None - - @grammar.setter - def grammar( - self, grammar: Union[StructuredOutputGrammar, - Future[StructuredOutputGrammar]] - ) -> None: - self._grammar = grammar + if self.compiled_grammar is None: + return False + return self.compiled_grammar.done() @functools.cached_property def structured_output_key(self) -> StructuredOutputKey: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f256dc160a6b..77c3f532b570 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -105,6 +105,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.structured_output import GrammarBitmaskPlaceholder else: xgr = LazyLoader("xgr", globals(), "xgrammar") @@ -1603,9 +1604,9 @@ def apply_grammar_bitmask( scheduler_output: "SchedulerOutput", logits: torch.Tensor, ): - grammar_bitmask = scheduler_output.grammar_bitmask - if grammar_bitmask is None: - return + grammar_bitmask_placeholder: GrammarBitmaskPlaceholder = \ + scheduler_output.grammar_bitmask + grammar_bitmask: np.ndarray = grammar_bitmask_placeholder.result() # We receive the structured output bitmask from the scheduler, # compacted to contain bitmasks only for structured output requests. 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 43f12912707f..cfb2a9a68e96 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.structured_output import GrammarBitmaskPlaceholder logger = init_logger(__name__) @@ -1757,8 +1758,10 @@ def get_input_embeddings(self, *args, **kwargs): def prepare_structured_decoding_input( self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - grammar_bitmask = scheduler_output.grammar_bitmask - assert grammar_bitmask is not None + grammar_bitmask_placeholder: GrammarBitmaskPlaceholder = \ + scheduler_output.grammar_bitmask + grammar_bitmask: np.ndarray = grammar_bitmask_placeholder.result() + num_reqs, _ = logits.shape # Reset pre-allocated tensors From e4e716796f106207bb73b7bc88a1a04f71c47033 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Fri, 19 Sep 2025 04:00:15 +0400 Subject: [PATCH 02/14] resolve merge conflicts Signed-off-by: Vadim Gimpelson --- vllm/v1/structured_output/__init__.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index afd32b928a60..a08e114985af 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -133,7 +133,7 @@ def from_request(cls, request: Request) -> GrammarInitData: assert request.sampling_params is not None return cls(request_id=request.request_id, guided_decoding_backend=request.sampling_params. - guided_decoding.backend, + structured_outputs._backend, structured_output_key=request.structured_output_request. structured_output_key) @@ -321,15 +321,12 @@ def __init__(self, vllm_config: VllmConfig): if not self.vllm_config.model_config.skip_tokenizer_init: self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config, - ).get_lora_tokenizer(None) - - reasoning_backend = vllm_config.decoding_config.reasoning_backend - if reasoning_backend: + model_config=vllm_config.model_config) + reasoning_parser = \ + self.vllm_config.structured_outputs_config.reasoning_parser + if reasoning_parser: reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_backend) + reasoning_parser) self.reasoner = reasoner_cls(tokenizer=self.tokenizer) # Set to track initialized grammars in main process @@ -368,7 +365,8 @@ def __init__(self, vllm_config: VllmConfig): scheduler_config=self.vllm_config.scheduler_config, model_config=self.vllm_config.model_config, speculative_config=self.vllm_config.speculative_config, - decoding_config=self.vllm_config.decoding_config, + structured_outputs_config=self.vllm_config. 
+ structured_outputs_config, lora_config=self.vllm_config.lora_config, ) @@ -666,10 +664,7 @@ def __init__(self, vllm_config: VllmConfig): max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) self.tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config, - scheduler_config=self.vllm_config.scheduler_config, - lora_config=self.vllm_config.lora_config, - ).get_lora_tokenizer(None) + model_config=self.vllm_config.model_config) def _get_grammar(self, request_id: str) -> StructuredOutputGrammar: return self.request_id_to_grammar[request_id] From 5282ed1560f740276b3587b64881266384f0d15a Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Fri, 19 Sep 2025 04:32:48 +0400 Subject: [PATCH 03/14] fix Signed-off-by: Vadim Gimpelson --- vllm/v1/worker/gpu_model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 21bd14595a39..3ee2160a42ff 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -107,8 +107,6 @@ if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") logger = init_logger(__name__) From 59004f00d448b66d86a376b1ffe8b68e1be20678 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Wed, 24 Sep 2025 15:16:40 +0400 Subject: [PATCH 04/14] #suppress-api-compatibility-check Signed-off-by: Vadim Gimpelson From aba59ed9ae236a4fb442f4d34fe59ec0d577b529 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Wed, 24 Sep 2025 15:23:58 +0400 Subject: [PATCH 05/14] pre-commit Signed-off-by: Vadim Gimpelson --- vllm/v1/core/sched/output.py | 2 -- vllm/v1/structured_output/utils.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 0551ef2c8bd7..934321c63311 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -10,8 +10,6 @@ from vllm.v1.structured_output import GrammarBitmaskPlaceholder if TYPE_CHECKING: - import numpy as np - import numpy.typing as npt import torch from vllm.distributed.kv_transfer.kv_connector.v1.base import ( diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 1edb781e1229..7af376861712 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -26,8 +26,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.structured_output import GrammarBitmaskPlaceholder + from vllm.v1.worker.gpu_input_batch import InputBatch else: xgr = LazyLoader("xgr", globals(), "xgrammar") oc = LazyLoader("oc", globals(), "outlines_core") From 8654924f852133f80d61139efa2eb521d78a5b37 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Wed, 24 Sep 2025 15:25:59 +0400 Subject: [PATCH 06/14] #suppress-bc-linter Signed-off-by: Vadim Gimpelson From af176b9dac7cf4e211e310a5de84b3095067b94e Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Wed, 24 Sep 2025 17:44:27 +0400 Subject: [PATCH 07/14] fix doc build Signed-off-by: Vadim Gimpelson --- requirements/docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/docs.txt b/requirements/docs.txt index d1c546398780..2089d297a604 100644 --- a/requirements/docs.txt +++ 
b/requirements/docs.txt @@ -16,3 +16,4 @@ cachetools msgspec pydantic torch +diskcache == 5.6.3 From 8fd308f1f2b52f8f6b5f72cc3eaf0ea0f5c26198 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Thu, 25 Sep 2025 03:03:41 +0400 Subject: [PATCH 08/14] #suppress-bc-linter Signed-off-by: Vadim Gimpelson From e8818eadd01ae48e883c247041638dfc18ed4b4b Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Thu, 25 Sep 2025 03:27:54 +0400 Subject: [PATCH 09/14] use sched_yield \n #suppress-bc-linter Signed-off-by: Vadim Gimpelson --- vllm/v1/structured_output/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index a08e114985af..08f91ac10943 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -15,6 +15,7 @@ import numpy as np from vllm.config import VllmConfig +from vllm.distributed.utils import sched_yield from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs @@ -597,8 +598,6 @@ def __init__(self): pass def result(self) -> np.ndarray: - import time - import numpy as np # Poll the shared memory flag until it's set @@ -609,7 +608,7 @@ def result(self) -> np.ndarray: if flag_value == 1: # Flag is set, bitmask is ready break - time.sleep(0.001) # Short sleep to avoid busy waiting + sched_yield() # Short sleep to avoid busy waiting flag_shm.close() From 03d43cdf6a9521038cf41c995c7fc1ec6915dc4f Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Wed, 1 Oct 2025 22:40:43 +0400 Subject: [PATCH 10/14] - fix doc build - fix a bug with dp Signed-off-by: Vadim Gimpelson #suppress-bc-linter --- requirements/docs.txt | 3 +- vllm/v1/structured_output/__init__.py | 66 +++++++++++++++------------ 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/requirements/docs.txt b/requirements/docs.txt index 2089d297a604..49f8be5c3667 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -15,5 +15,4 @@ ruff cachetools msgspec pydantic -torch -diskcache == 5.6.3 +torch \ No newline at end of file diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 08f91ac10943..1d62ae01fa05 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -177,22 +177,28 @@ class StructuredOutputGateway: """ def __init__(self, task_queue, batch_validate_result_queue, - grammar_init_notification_queue, vllm_config: VllmConfig): + grammar_init_notification_queue, vllm_config: VllmConfig, + bitmask_shm_name: str, ready_flag_shm_name: str): self.task_queue = task_queue self.batch_validate_result_queue = batch_validate_result_queue self.grammar_init_notification_queue = grammar_init_notification_queue self.vllm_config = vllm_config + self.bitmask_shm_name = bitmask_shm_name + self.ready_flag_shm_name = ready_flag_shm_name self.structured_output_executor: Optional[ StructuredOutputExecutor] = None @staticmethod def run_gateway(task_queue, batch_validate_result_queue, - grammar_init_notification_queue, vllm_config: VllmConfig): + grammar_init_notification_queue, vllm_config: VllmConfig, + bitmask_shm_name: str, ready_flag_shm_name: str): """Static method to run the gateway in a separate process.""" gateway = StructuredOutputGateway(task_queue, batch_validate_result_queue, grammar_init_notification_queue, - vllm_config) + vllm_config, + bitmask_shm_name, + ready_flag_shm_name) gateway.run() def 
run(self): @@ -204,7 +210,7 @@ def run(self): # Attach to shared memory in child process self.bitmask_shm = shared_memory.SharedMemory( - name=GRAMMAR_BITMASK_SHM_NAME) + name=self.bitmask_shm_name) while True: try: @@ -225,7 +231,7 @@ def run(self): else: # Set the flag even on error so result() doesn't hang flag_shm = shared_memory.SharedMemory( - name=GRAMMAR_READY_FLAG_SHM_NAME) + name=self.ready_flag_shm_name) flag_shm.buf[0] = 1 flag_shm.close() elif task.task_type == TaskType.BATCH_VALIDATE_TOKENS: @@ -239,7 +245,7 @@ def run(self): elif task.task_type == TaskType.GRAMMAR_BITMASK: # Signal flag even on error so result() doesn't hang flag_shm = shared_memory.SharedMemory( - name=GRAMMAR_READY_FLAG_SHM_NAME) + name=self.ready_flag_shm_name) flag_shm.buf[0] = 1 flag_shm.close() elif task.task_type == TaskType.BATCH_VALIDATE_TOKENS: @@ -301,7 +307,7 @@ def _write_bitmask_to_shared_memory(self, bitmask_array: np.ndarray): # calls that act as memory barriers. # Set the flag to indicate bitmask is ready - flag_shm = shared_memory.SharedMemory(name=GRAMMAR_READY_FLAG_SHM_NAME) + flag_shm = shared_memory.SharedMemory(name=self.ready_flag_shm_name) flag_shm.buf[0] = 1 flag_shm.close() @@ -319,6 +325,11 @@ def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self.reasoner: Optional[ReasoningParser] = None + + # Create shared memory names with data parallel rank + data_parallel_rank = str(self.vllm_config.parallel_config.data_parallel_rank) + self.bitmask_shm_name = GRAMMAR_BITMASK_SHM_NAME + data_parallel_rank + self.ready_flag_shm_name = GRAMMAR_READY_FLAG_SHM_NAME + data_parallel_rank if not self.vllm_config.model_config.skip_tokenizer_init: self.tokenizer = init_tokenizer_from_configs( @@ -354,11 +365,11 @@ def __init__(self, vllm_config: VllmConfig): self._cleanup_existing_shared_memory() self.bitmask_shm = shared_memory.SharedMemory( - create=True, size=shm_size, name=GRAMMAR_BITMASK_SHM_NAME) + create=True, size=shm_size, name=self.bitmask_shm_name) # Create shared memory flag for bitmask ready signaling self.flag_shm = shared_memory.SharedMemory( - name=GRAMMAR_READY_FLAG_SHM_NAME, create=True, size=1) + name=self.ready_flag_shm_name, create=True, size=1) self.flag_shm.buf[0] = 0 # Initialize flag to 0 (not ready) # Create a partial config with only the required fields @@ -375,7 +386,8 @@ def __init__(self, vllm_config: VllmConfig): target=StructuredOutputGateway.run_gateway, name="StructuredOutputGateway", args=(self.task_queue, self.batch_validate_result_queue, - self.grammar_init_notification_queue, partial_config), + self.grammar_init_notification_queue, partial_config, + self.bitmask_shm_name, self.ready_flag_shm_name), daemon=True) self.gateway_process.start() @@ -409,7 +421,7 @@ def submit_grammar_bitmask(self, requests: dict[str, Request], self.task_queue.put(task) # Return a placeholder that consumer can check - return GrammarBitmaskPlaceholder() + return GrammarBitmaskPlaceholder(self.bitmask_shm_name, self.ready_flag_shm_name) def submit_batch_accept_tokens( self, request_id_to_new_token_ids: list[tuple[str, list[int]]]): @@ -544,13 +556,10 @@ def shutdown(self): self.gateway_process.join() # Clean up shared memory - if hasattr(self, 'bitmask_shm'): - self.bitmask_shm.close() - self.bitmask_shm.unlink() - - if hasattr(self, 'flag_shm'): - self.flag_shm.close() - self.flag_shm.unlink() + self.bitmask_shm.close() + self.bitmask_shm.unlink() + self.flag_shm.close() + self.flag_shm.unlink() def _cleanup_existing_shared_memory(self): """Clean up any existing 
shared memory segments from previous runs. @@ -561,11 +570,11 @@ def _cleanup_existing_shared_memory(self): # Try to unlink bitmask shared memory if it exists try: existing_bitmask_shm = shared_memory.SharedMemory( - name=GRAMMAR_BITMASK_SHM_NAME) + name=self.bitmask_shm_name) existing_bitmask_shm.close() existing_bitmask_shm.unlink() - logger.info("Cleaned up existing bitmask shared memory from " - "previous run") + logger.debug("Cleaned up existing bitmask shared memory from " + "previous run") except FileNotFoundError: # No existing shared memory, which is fine pass @@ -575,11 +584,11 @@ def _cleanup_existing_shared_memory(self): # Try to unlink flag shared memory if it exists try: existing_flag_shm = shared_memory.SharedMemory( - name=GRAMMAR_READY_FLAG_SHM_NAME) + name=self.ready_flag_shm_name) existing_flag_shm.close() existing_flag_shm.unlink() - logger.info("Cleaned up existing flag shared memory from " - "previous run") + logger.debug("Cleaned up existing flag shared memory from " + "previous run") except FileNotFoundError: # No existing shared memory, which is fine pass @@ -594,14 +603,15 @@ class GrammarBitmaskPlaceholder: and shared memory to retrieve bitmask data. """ - def __init__(self): - pass + def __init__(self, bitmask_shm_name: str, ready_flag_shm_name: str): + self.bitmask_shm_name = bitmask_shm_name + self.ready_flag_shm_name = ready_flag_shm_name def result(self) -> np.ndarray: import numpy as np # Poll the shared memory flag until it's set - flag_shm = shared_memory.SharedMemory(name=GRAMMAR_READY_FLAG_SHM_NAME) + flag_shm = shared_memory.SharedMemory(name=self.ready_flag_shm_name) while True: flag_value = flag_shm.buf[0] @@ -612,7 +622,7 @@ def result(self) -> np.ndarray: flag_shm.close() - bitmask_shm = shared_memory.SharedMemory(name=GRAMMAR_BITMASK_SHM_NAME) + bitmask_shm = shared_memory.SharedMemory(name=self.bitmask_shm_name) # Read shape info first shape_bytes = bytes(bitmask_shm.buf[:8]) # Create a copy of the bytes shape = np.frombuffer(shape_bytes, dtype=np.int32).copy() From 97e89225f560c47db972c536a95a497ae40f54d4 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Wed, 1 Oct 2025 23:21:49 +0400 Subject: [PATCH 11/14] - pre-commit fix doc build Signed-off-by: Vadim Gimpelson #suppress-bc-linter --- vllm/v1/core/sched/output.py | 2 +- vllm/v1/structured_output/__init__.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 934321c63311..bcd88b14d2f7 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional from vllm._bc_linter import bc_linter_include -from vllm.v1.structured_output import GrammarBitmaskPlaceholder if TYPE_CHECKING: import torch @@ -19,6 +18,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.request import Request + from vllm.v1.structured_output import GrammarBitmaskPlaceholder @bc_linter_include diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 1d62ae01fa05..6b8a64952217 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -196,8 +196,7 @@ def run_gateway(task_queue, batch_validate_result_queue, gateway = StructuredOutputGateway(task_queue, batch_validate_result_queue, grammar_init_notification_queue, - vllm_config, - bitmask_shm_name, + vllm_config, bitmask_shm_name, ready_flag_shm_name) gateway.run() @@ -325,11 
+324,13 @@ def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self.reasoner: Optional[ReasoningParser] = None - + # Create shared memory names with data parallel rank - data_parallel_rank = str(self.vllm_config.parallel_config.data_parallel_rank) + data_parallel_rank = str( + self.vllm_config.parallel_config.data_parallel_rank) self.bitmask_shm_name = GRAMMAR_BITMASK_SHM_NAME + data_parallel_rank - self.ready_flag_shm_name = GRAMMAR_READY_FLAG_SHM_NAME + data_parallel_rank + self.ready_flag_shm_name = GRAMMAR_READY_FLAG_SHM_NAME \ + + data_parallel_rank if not self.vllm_config.model_config.skip_tokenizer_init: self.tokenizer = init_tokenizer_from_configs( @@ -421,7 +422,8 @@ def submit_grammar_bitmask(self, requests: dict[str, Request], self.task_queue.put(task) # Return a placeholder that consumer can check - return GrammarBitmaskPlaceholder(self.bitmask_shm_name, self.ready_flag_shm_name) + return GrammarBitmaskPlaceholder(self.bitmask_shm_name, + self.ready_flag_shm_name) def submit_batch_accept_tokens( self, request_id_to_new_token_ids: list[tuple[str, list[int]]]): From c1c05bc36a45ebcd9751789a673a17a99d0739e7 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Thu, 2 Oct 2025 01:44:14 +0400 Subject: [PATCH 12/14] rerun CI #suppress-bc-linter Signed-off-by: Vadim Gimpelson From 057ea6ca308fbe573930541954b07a8ff1c9ac1a Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sat, 4 Oct 2025 06:02:37 +0400 Subject: [PATCH 13/14] rerun CI #suppress-bc-linter Signed-off-by: Vadim Gimpelson From a0924637f2abdd1a7d223ce9675e9c9cef90e82c Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sat, 4 Oct 2025 23:10:11 +0400 Subject: [PATCH 14/14] rerun CI #suppress-bc-linter Signed-off-by: Vadim Gimpelson