diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index df9832fc4e48..b55018ae8ef0 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -32,6 +32,7 @@ class MyType: large_f_contig_tensor: torch.Tensor small_non_contig_tensor: torch.Tensor large_non_contig_tensor: torch.Tensor + empty_tensor: torch.Tensor def test_encode_decode(): @@ -58,6 +59,7 @@ def test_encode_decode(): large_f_contig_tensor=torch.rand(1024, 4).t(), small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], + empty_tensor=torch.empty(0), ) encoder = MsgpackEncoder(size_threshold=256) @@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType): obj2.small_non_contig_tensor) assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor) + assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9590a9aadbec..80807665e779 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -5,6 +5,7 @@ import sys import threading import time +from collections import deque from concurrent.futures import Future from inspect import isclass, signature from logging import DEBUG @@ -527,8 +528,12 @@ def process_output_socket(self, output_path: str, engine_index: int): # Msgpack serialization encoding. encoder = MsgpackEncoder() - # Reuse send buffer. - buffer = bytearray() + # Send buffers to reuse. + reuse_buffers: list[bytearray] = [] + # Keep references to outputs and buffers until zmq is finished + # with them (outputs may contain tensors/np arrays whose + # backing buffers were extracted for zero-copy send). + pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]() # We must set linger to ensure the ENGINE_CORE_DEAD # message is sent prior to closing the socket. @@ -541,8 +546,22 @@ def process_output_socket(self, output_path: str, engine_index: int): break assert not isinstance(outputs, bytes) outputs.engine_index = engine_index + + # Reclaim buffers that zmq is finished with. + while pending and pending[-1][0].done: + reuse_buffers.append(pending.pop()[2]) + + buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - socket.send_multipart(buffers, copy=False) + tracker = socket.send_multipart(buffers, + copy=False, + track=True) + if not tracker.done: + ref = outputs if len(buffers) > 1 else None + pending.appendleft((tracker, ref, buffer)) + elif len(reuse_buffers) < 2: + # Keep at most 2 buffers to reuse. + reuse_buffers.append(buffer) class DPEngineCoreProc(EngineCoreProc): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 21ad446172a3..bf8473858088 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import contextlib import queue import uuid import weakref from abc import ABC, abstractmethod +from collections import deque from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass, field @@ -396,6 +398,12 @@ def __init__( self._wait_for_engine_startup() self.utility_results: dict[int, AnyFuture] = {} + + # Request objects which may contain pytorch-allocated tensors + # that we need to keep references to until zmq is done with the + # underlying data. + self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]() + success = True finally: if not success: @@ -459,6 +467,14 @@ def ensure_alive(self): if self.resources.engine_dead: raise EngineDeadError() + def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any): + if not tracker.done: + self.pending_messages.appendleft((tracker, msg)) + + def free_pending_messages(self): + while self.pending_messages and self.pending_messages[-1][0].done: + self.pending_messages.pop() + def _process_utility_output(output: UtilityOutput, utility_results: dict[int, AnyFuture]): @@ -544,10 +560,18 @@ def get_output(self) -> EngineCoreOutputs: def _send_input(self, request_type: EngineCoreRequestType, request: Any): self.ensure_alive() + self.free_pending_messages() # (Identity, RequestType, SerializedRequest) msg = (self.core_engine.identity, request_type.value, *self.encoder.encode(request)) - self.input_socket.send_multipart(msg, copy=False) + + if len(msg) <= 3: + # No auxiliary buffers => no tensor backing buffers in request. + self.input_socket.send_multipart(msg, copy=False) + return + + tracker = self.input_socket.send_multipart(msg, copy=False, track=True) + self.add_pending_message(tracker, request) def call_utility(self, method: str, *args) -> Any: call_id = uuid.uuid1().int >> 64 @@ -698,19 +722,38 @@ async def get_output_async(self) -> EngineCoreOutputs: def _send_input(self, request_type: EngineCoreRequestType, request: Any, - engine: Optional[CoreEngine] = None) -> Awaitable[None]: + engine: Optional[CoreEngine] = None) -> Awaitable[Any]: self.ensure_alive() if engine is None: engine = self.core_engine message = (request_type.value, *self.encoder.encode(request)) - return self._send_input_message(message, engine) - - def _send_input_message(self, message: tuple[bytestr, ...], - engine: CoreEngine) -> Awaitable[None]: + return self._send_input_message(message, engine, request) + + def _send_input_message(self, message: tuple[bytestr, + ...], engine: CoreEngine, + objects: Any) -> Awaitable[Any]: + """ + objects is a reference to retain until zmq is finished with the + buffers, in case they were extracted from tensors in the request. + """ self.ensure_alive() - message = (engine.identity, ) + message - return self.input_socket.send_multipart(message, copy=False) + self.free_pending_messages() + + msg = (engine.identity, ) + message + if not objects or len(msg) <= 3: + # No auxiliary buffers => no tensor backing buffers in request. + return self.input_socket.send_multipart(msg, copy=False) + + future: asyncio.Future[zmq.MessageTracker] + future = self.input_socket.send_multipart(msg, copy=False, track=True) + + def add_pending(f: asyncio.Future[zmq.MessageTracker]): + with contextlib.suppress(BaseException): + self.add_pending_message(f.result(), objects) + + future.add_done_callback(add_pending) + return future async def call_utility_async(self, method: str, *args) -> Any: return await self._call_utility_async(method, @@ -724,7 +767,7 @@ async def _call_utility_async(self, method: str, *args, self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( (call_id, method, args))) - await self._send_input_message(message, engine) + await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future