diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml
index d5c6b8d43..45d5d8191 100644
--- a/.github/workflows/cleanup_pr_body.yml
+++ b/.github/workflows/cleanup_pr_body.yml
@@ -13,7 +13,7 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Set up Python
         uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml
index 2b1086b7f..b7e1b3395 100644
--- a/.github/workflows/lint-and-deploy.yaml
+++ b/.github/workflows/lint-and-deploy.yaml
@@ -14,7 +14,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
         with:
           fetch-depth: 0
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 195579f20..cd352fa00 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -16,7 +16,7 @@ jobs:
   pre-commit:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
       - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
         with:
           python-version: "3.12"
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index bfd028799..983e76e08 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -21,7 +21,7 @@ jobs:
       upload_url: ${{ steps.create_release.outputs.upload_url }}
     steps:
       - name: Checkout
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Extract branch info
         shell: bash
@@ -55,7 +55,7 @@
 #   steps:
 #     - name: Checkout
-#       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+#       uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

 #     - name: Setup ccache
 #       uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index bbe958351..fea416111 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -417,6 +417,14 @@ def _initialize_kv_caches(self) -> None:
         logger.info(("init engine (profile, create kv cache, "
                      "warmup model) took %.2f seconds"), elapsed)

+        self.max_kv_cache_size = num_gpu_blocks * self.cache_config.block_size
+
+    def get_max_kv_cache_size(self) -> int:
+        """Get the maximum size of the KV cache, in tokens."""
+        return self.max_kv_cache_size
+
     @classmethod
     def _get_executor_cls(cls,
                           engine_config: VllmConfig) -> Type[ExecutorBase]:
@@ -559,6 +567,7 @@ def _add_processed_request(
             return None

         self._validate_model_inputs(processed_inputs, lora_request)
+        self._compute_free_tokens(processed_inputs, params)
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -1844,6 +1853,32 @@ def _validate_model_input(
             # check that chunked prefill does not truncate them
             # max_batch_len = self.scheduler_config.max_num_batched_tokens

+    def _compute_free_tokens(self,
+                             inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
+                             params: Union[SamplingParams, PoolingParams]):
+        """Report how much of the model context remains unreserved after
+        accounting for this request's prompt and its output budget."""
+        if self.model_config.is_multimodal_model:
+            # For encoder-decoder multimodal models, the decoder prompt is
+            # what counts against the context length.
+            prompt_ids = inputs.get("prompt_token_ids")
+        elif self.is_encoder_decoder_model():
+            prompt_ids = inputs.get("encoder_prompt_token_ids")
+        else:
+            prompt_ids = inputs.get("prompt_token_ids")
+
+        input_token_count = len(prompt_ids)
+        max_token_count = self.model_config.max_model_len
+
+        max_output_tokens = 0
+        if isinstance(params, SamplingParams):
+            # max_tokens may be None (generate until EOS / max_model_len).
+            max_output_tokens = params.max_tokens or 0
+
+        free_tokens = max_token_count - max_output_tokens - input_token_count
+
+        logger.info("Free tokens available after tokenization: %d",
+                    free_tokens)
+
+        # The callback is only wired up when running under MQLLMEngine.
+        callback = getattr(self, "send_free_tokens_callback", None)
+        if free_tokens > 0 and callback is not None:
+            callback(free_tokens)
+
     def _build_logits_processors(
             self, sampling_params: SamplingParams,
             lora_request: Optional[LoRARequest]) -> SamplingParams:
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index ff0405d2f..87948f0bc 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -19,6 +19,7 @@
 IPC_OUTPUT_EXT = "_output_socket"
 IPC_HEALTH_EXT = "_health_socket"
 IPC_DATA_EXT = "_data_socket"
+IPC_TOKENS_EXT = "_tokens_socket"


 class MQEngineDeadError(RuntimeError):
@@ -72,6 +73,7 @@ class RPCStartupRequest(Enum):
 @dataclass
 class RPCStartupResponse:
     tracing_enabled: bool
+    max_kv_cache_size: int


 class RPCUProfileRequest(Enum):
@@ -116,6 +118,10 @@ class RPCLoadAdapterRequest:
     # Set the default value of request_id to a new UUID
     request_id: str = field(default_factory=lambda: str(uuid.uuid4()))

+
+@dataclass
+class FreeTokensRequest:
+    free_token_count: int
+

 @dataclass
 class RPCAdapterLoadedResponse:
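Taken together, the llm_engine.py and multiprocessing/__init__.py changes above reduce to one piece of arithmetic: free = max_model_len minus the prompt length minus the reserved output budget, reported once per admitted request. The following minimal sketch restates that calculation outside the engine so it can be checked in isolation; the FreeTokensRequest copy mirrors the dataclass added above, while compute_free_tokens and the report callback are illustrative stand-ins for _compute_free_tokens and send_free_tokens_callback, not code from this patch.

    from dataclasses import dataclass
    from typing import Callable, Optional


    @dataclass
    class FreeTokensRequest:
        # Mirrors the dataclass added in vllm/engine/multiprocessing/__init__.py.
        free_token_count: int


    def compute_free_tokens(max_model_len: int,
                            prompt_len: int,
                            max_tokens: Optional[int],
                            report: Callable[[FreeTokensRequest], None]) -> int:
        """Tokens of the context still unreserved after admitting one request."""
        reserved_output = max_tokens or 0          # None means no explicit budget
        free_tokens = max_model_len - reserved_output - prompt_len
        if free_tokens > 0:
            # In the patch this goes through send_free_tokens_callback.
            report(FreeTokensRequest(free_token_count=free_tokens))
        return free_tokens


    if __name__ == "__main__":
        # Example: 4096-token context, 1200-token prompt, 512 tokens reserved.
        compute_free_tokens(4096, 1200, 512, print)
        # -> FreeTokensRequest(free_token_count=2384)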
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index eca29af50..95f955917 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -23,6 +23,7 @@
 from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                          IPC_OUTPUT_EXT, RPC_REQUEST_T,
+                                         IPC_TOKENS_EXT,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCAdapterLoadedResponse, RPCError,
                                          RPCIsSleepingRequest,
@@ -120,6 +121,10 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
         self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
         self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

+        # IPC socket for receiving freed-token updates from the engine.
+        self.tokens_socket: Socket = self.context.socket(zmq.constants.PULL)
+        self.tokens_socket.connect(f"{ipc_path}{IPC_TOKENS_EXT}")
+
         # IPC path for the data socket.
         self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

@@ -131,6 +136,16 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
         # build the Client in an executor to enable clean shutdown.
         self.output_loop: Optional[asyncio.Task] = None

+        # Loop to handle freed-token updates from the LLMEngine.
+        # Started after the MQLLMEngine is ready so that we can
+        # build the Client in an executor to enable clean shutdown.
+        self.token_loop: Optional[asyncio.Task] = None
+
+        self.max_kv_cache_size: Optional[int] = None
+        self.current_kv_cache_size: Optional[int] = None
+
+        self.kv_cache_size_updated_event = asyncio.Event()
+
         # Loop to check health of the LLMEngine periodically.
         # Started after the MQLLMEngine is ready.
         self.health_loop: Optional[asyncio.Task] = None
@@ -185,6 +200,42 @@ async def run_heartbeat_loop(self, timeout: int):
         except Exception as e:
             self._set_errored(e)

+    async def run_token_handler_loop(self):
+        """As KV cache token slots become free, the engine pushes updates to
+        the tokens_socket. This loop listens on the tokens_socket and tops up
+        the tracked KV cache size."""
+        try:
+            while True:
+                # Poll, checking for ENGINE_DEAD
+                while await self.tokens_socket.poll(
+                        timeout=VLLM_RPC_TIMEOUT) == 0:
+                    logger.debug("Waiting for tokens from MQLLMEngine.")
+
+                    # If errored, alert all running requests.
+                    if self.errored:
+                        for queue_i in tuple(self.output_queues.values()):
+                            queue_i.put_nowait(
+                                ENGINE_DEAD_ERROR(self._errored_with))
+                        return
+
+                message: Frame = await self.tokens_socket.recv(copy=False)
+                tokens_response = pickle.loads(message.buffer)
+
+                if isinstance(tokens_response, BaseException):
+                    raise tokens_response
+
+                # Update the tracked KV cache size, capped at the maximum.
+                self.current_kv_cache_size += tokens_response.free_token_count
+                self.current_kv_cache_size = min(self.current_kv_cache_size,
+                                                 self.max_kv_cache_size)
+
+                # Notify waiters that the KV cache size has been updated.
+                self.kv_cache_size_updated_event.set()
+
+        except asyncio.CancelledError:
+            logger.debug("Shutting down MQLLMEngineClient token handler loop.")
+
     async def run_output_handler_loop(self):
         """Get RequestOutputs from Engine and stream to Request Queues"""

@@ -288,11 +339,19 @@ async def setup(self):

         self.tracing_flag = response.tracing_enabled

+        self.max_kv_cache_size = response.max_kv_cache_size
+        self.current_kv_cache_size = self.max_kv_cache_size
+
         # Start health_loop.
         if self.health_loop is None:
             self.health_loop = asyncio.create_task(
                 self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))

+        # Start token handler loop.
+        if self.token_loop is None:
+            self.token_loop = asyncio.create_task(
+                self.run_token_handler_loop())
+
     def close(self):
         """Destroy the ZeroMQ Context."""
         # Close all sockets and terminate the context.
@@ -301,6 +360,8 @@ def close(self):
         # Cancel background tasks.
         if self.health_loop is not None:
             self.health_loop.cancel()
+        if self.token_loop is not None:
+            self.token_loop.cancel()
         if self.output_loop is not None:
             self.output_loop.cancel()
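On the client side, the diff above only maintains the counters (max_kv_cache_size, current_kv_cache_size) and sets kv_cache_size_updated_event; it does not show a consumer of that state. The sketch below is one way a consumer could look, assuming the same counter-plus-event shape. CapacityTracker, on_free_tokens, and wait_for_capacity are hypothetical names used purely for illustration and are not defined anywhere in this patch.

    import asyncio


    class CapacityTracker:
        """Sketch of the client-side bookkeeping added above: a running
        free-slot counter plus an event set by the token handler loop."""

        def __init__(self, max_kv_cache_size: int) -> None:
            self.max_kv_cache_size = max_kv_cache_size
            self.current_kv_cache_size = max_kv_cache_size
            self.kv_cache_size_updated_event = asyncio.Event()

        def on_free_tokens(self, free_token_count: int) -> None:
            # What run_token_handler_loop does with each FreeTokensRequest.
            self.current_kv_cache_size = min(
                self.current_kv_cache_size + free_token_count,
                self.max_kv_cache_size)
            self.kv_cache_size_updated_event.set()

        async def wait_for_capacity(self, needed: int) -> None:
            # Hypothetical admission check; not part of the patch.
            while self.current_kv_cache_size < needed:
                self.kv_cache_size_updated_event.clear()
                await self.kv_cache_size_updated_event.wait()
            self.current_kv_cache_size -= needed


    async def _demo() -> None:
        tracker = CapacityTracker(max_kv_cache_size=8192)
        tracker.current_kv_cache_size = 100      # pretend the cache is nearly full
        waiter = asyncio.create_task(tracker.wait_for_capacity(1024))
        await asyncio.sleep(0)                   # let the waiter block on the event
        tracker.on_free_tokens(2048)             # engine reports freed slots
        await waiter
        print(tracker.current_kv_cache_size)     # 1124


    if __name__ == "__main__":
        asyncio.run(_demo())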
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index 903f3fd71..1a37e8a68 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -16,6 +16,7 @@
 # yapf: disable
 from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
+                                         IPC_TOKENS_EXT,
                                          IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCAdapterLoadedResponse, RPCError,
@@ -27,7 +28,8 @@
                                          RPCResetPrefixCacheRequest,
                                          RPCSleepRequest, RPCStartupRequest,
                                          RPCStartupResponse,
-                                         RPCUProfileRequest, RPCWakeUpRequest)
+                                         RPCUProfileRequest, RPCWakeUpRequest,
+                                         FreeTokensRequest)
 # yapf: enable
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
@@ -93,6 +95,9 @@ def __init__(self,
         self.engine.process_request_outputs_callback = \
             self._async_socket_engine_callback

+        self.engine.send_free_tokens_callback = \
+            self._send_free_tokens_callback
+
         self.ctx = zmq.Context()  # type: ignore[attr-defined]

         # Receive input from the client.
@@ -107,6 +112,10 @@ def __init__(self,
         self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
         self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

+        # Send freed-token updates back to the client.
+        self.tokens_socket = self.ctx.socket(zmq.constants.PUSH)
+        self.tokens_socket.bind(f"{ipc_path}{IPC_TOKENS_EXT}")
+
         # IPC path for the data socket.
         self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

@@ -209,7 +218,9 @@ def run_startup_loop(self) -> None:
                 if request == RPCStartupRequest.IS_SERVER_READY:
                     tracing_enabled = self.engine.is_tracing_enabled()
                     response = RPCStartupResponse(
-                        tracing_enabled=tracing_enabled)
+                        tracing_enabled=tracing_enabled,
+                        max_kv_cache_size=self.engine.get_max_kv_cache_size(),
+                    )

             except Exception as e:
                 response = e
@@ -411,9 +422,29 @@ def _send_unhealthy(self, error: BaseException):
     def _async_socket_engine_callback(self,
                                       request_outputs: REQUEST_OUTPUTS_T):
         """Callback used by engine to make socket handling async with GPU."""
+        self._send_free_tokens_for_outputs(request_outputs)
         self._send_outputs(request_outputs)
         self.handle_new_input()

+    def _send_free_tokens_for_outputs(self,
+                                      request_outputs: REQUEST_OUTPUTS_T):
+        """Send freed tokens for completed request outputs, if any."""
+        free_tokens = 0
+        for output in request_outputs:
+            if not isinstance(output, RequestOutput):
+                continue
+            # Only finished requests release their KV cache slots.
+            if not output.finished:
+                continue
+            # Return the prompt tokens plus the output budget (max_tokens)
+            # that was reserved for this request at admission time.
+            free_tokens += len(output.prompt_token_ids)
+            free_tokens += output.max_tokens
+
+        if free_tokens > 0:
+            self._send_free_tokens_callback(free_tokens)
+
+    def _send_free_tokens_callback(self, free_tokens: int):
+        """Callback used by engine to send free tokens to the client."""
+        if not self.tokens_socket.closed:
+            self.tokens_socket.send_multipart(
+                (pickle.dumps(FreeTokensRequest(free_token_count=free_tokens)), ),
+                copy=False)
+
     def _set_errored(self, e: BaseException):
         """Log and set errored status if this is the first issue."""
         if self._errored_with is None:
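The transport underneath the engine- and client-side changes is a second ZeroMQ PUSH/PULL pair next to the existing health and output sockets: MQLLMEngine binds a PUSH socket at ipc_path + IPC_TOKENS_EXT and pushes one pickled FreeTokensRequest per update, and MQLLMEngineClient connects a PULL socket to the same endpoint. The standalone sketch below demonstrates that pattern with both ends in one process; it uses an inproc endpoint and a local FreeTokensRequest copy purely for illustration, whereas the patch uses ipc endpoints across processes.

    import pickle
    from dataclasses import dataclass

    import zmq

    IPC_TOKENS_EXT = "_tokens_socket"   # same suffix the patch defines


    @dataclass
    class FreeTokensRequest:
        free_token_count: int


    def main() -> None:
        # The real code derives an ipc:// path; inproc keeps the demo portable.
        ipc_path = "inproc://demo"
        ctx = zmq.Context()

        # Engine side: PUSH socket bound at ipc_path + IPC_TOKENS_EXT.
        engine_side = ctx.socket(zmq.PUSH)
        engine_side.bind(f"{ipc_path}{IPC_TOKENS_EXT}")

        # Client side: PULL socket connected to the same endpoint.
        client_side = ctx.socket(zmq.PULL)
        client_side.connect(f"{ipc_path}{IPC_TOKENS_EXT}")

        # Engine publishes a freed-token count as a single pickled frame.
        engine_side.send_multipart((pickle.dumps(FreeTokensRequest(256)), ))

        # Client decodes it and would top up its running counter.
        frame = client_side.recv(copy=False)
        update = pickle.loads(frame.buffer)
        print(update.free_token_count)  # 256

        engine_side.close()
        client_side.close()
        ctx.term()


    if __name__ == "__main__":
        main()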