From 7c8d34faf6fd73ea069bf7f094b1121f3c246c45 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 24 Sep 2024 21:37:38 -0600 Subject: [PATCH] [Bugfix] Use heartbeats instead of health checks (#8583) --- tests/mq_llm_engine/test_error_handling.py | 15 ++--- vllm/engine/multiprocessing/__init__.py | 7 +- vllm/engine/multiprocessing/client.py | 51 +++++++------- vllm/engine/multiprocessing/engine.py | 77 +++++++++++++++++----- 4 files changed, 87 insertions(+), 63 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 49cfc5aa04c36..76b2f494d5b25 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket): await client.check_health() # Trigger an abort on the client side. - async def bad_abort_after_2s(): - await asyncio.sleep(2.0) - await client.abort(request_id="foo") + # This request ID does not exist, and will cause the engine to error + await client.abort(request_id="foo") - # Trigger an abort in 2s from now. - abort_task = asyncio.create_task(bad_abort_after_2s()) - - # Exception in abort() will happen during this generation. - # This will kill the engine and should return ENGINE_DEAD_ERROR + # Future generation requests will now fail # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( inputs="Hello my name is", - sampling_params=SamplingParams(max_tokens=2000), + sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass assert "KeyError" in repr(execinfo.value) assert client.errored - await abort_task - # This should raise the original error. with pytest.raises(RAISED_ERROR): await client.check_health() diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 700332864d17a..165e6cc2146c3 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -43,10 +43,6 @@ class RPCAbortRequest: request_id: str -class RPCHealthRequest: - pass - - class RPCStartupRequest(Enum): IS_SERVER_READY = 1 @@ -56,8 +52,7 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, - RPCStartupRequest] +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index aa9dbbd448af2..7e397cf408fba 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -20,9 +20,8 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptInputs @@ -95,9 +94,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - # IPC path for ack of check_health requests. - self.health_socket: Socket = self.context.socket(zmq.constants.PULL) - self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # IPC path for acking heartbeats. 
+ self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -124,34 +123,28 @@ def get_data_socket(self) -> Iterator[Socket]: finally: socket.close(linger=0) - async def run_check_health_loop(self, timeout: int): - """Background loop that continually probes the RPCServer for health. - - The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which - the MQLLMEngine server is blocking on. - - The Server replies on the HEALTH_SOCKET (rather than on the - OUTPUT_SOCKET such that the messages are not intermingled with - output streaming). + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually listens to the RPCServer for + heartbeats. """ - try: while True: - if await self.health_socket.poll(timeout=timeout) == 0: - # Wakeup every N seconds and do a health probe. - await self._send_one_way_rpc_request( - RPCHealthRequest(), self.input_socket) - - # Wait for ack from the health socket. - await self._await_ack(error_message="Health check failed.", - socket=self.health_socket) + if await self.heartbeat_socket.poll(timeout=timeout) == 0: + # No heartbeat was received. Set error and exit the loop + self._set_errored( + TimeoutError("No heartbeat received " + "from MQLLMEngine")) + logger.debug("Shutting down MQLLMEngineClient check " + "health loop due to timeout") + break + else: - # Server sent a health status message unprompted. + # Heartbeat received- check the message await self._check_success( - error_message="Health check failed.", - socket=self.health_socket) + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) - logger.debug("Health probe successful.") + logger.debug("Heartbeat successful.") except asyncio.CancelledError: logger.debug("Shutting down MQLLMEngineClient check health loop.") @@ -234,7 +227,7 @@ async def setup(self): # Start health_loop. self.health_loop = asyncio.create_task( - self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) def close(self): """Destroy the ZeroMQ Context.""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 485db0bab1297..b1dd9915cbbf5 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,5 +1,7 @@ import pickle import signal +import threading +import time from contextlib import contextmanager from typing import Iterator, List, Optional, Union @@ -15,10 +17,10 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse) # yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -91,9 +93,9 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - # Send health status back to client. - self.health_socket = self.ctx.socket(zmq.constants.PUSH) - self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. 
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -101,6 +103,20 @@ def __init__(self, # Error state. self._errored_with: Optional[BaseException] = None + # Heartbeat thread + self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, + daemon=True) + self._heartbeat_stop_event = threading.Event() + # The heartbeat needs to be faster than what the client will wait for + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0 + + self._last_alive_time = time.time() + # The heartbeats can tolerate a long period of the engine chugging + # away at a generation request. + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0 + @property def dead_error(self) -> BaseException: if self._errored_with is not None: @@ -131,6 +147,8 @@ def start(self): try: logger.debug("Starting Startup Loop.") self.run_startup_loop() + logger.debug("Starting heartbeat thread") + self.heartbeat_thread.start() logger.debug("Starting Engine Loop.") self.run_engine_loop() except Exception as e: @@ -144,6 +162,7 @@ def start(self): def cleanup(self): """Cleanup zeromq state on shutdown.""" # Closes all sockets and destroys context. + self._heartbeat_stop_event.set() self.ctx.destroy(linger=0) del self.engine @@ -182,9 +201,11 @@ def run_engine_loop(self): """Core busy loop of the LLMEngine.""" while True: + self._alive() if not self.engine.has_unfinished_requests(): # Poll until there is work to do. while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self._alive() self.engine.do_log_stats() logger.debug("Waiting for new requests in engine loop.") @@ -200,7 +221,6 @@ def run_engine_loop(self): def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" - try: return self.engine.step() except SystemExit: @@ -229,10 +249,9 @@ def handle_new_input(self): self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) - elif isinstance(request, RPCHealthRequest): - self._handle_health_request() else: - raise ValueError("Unknown RPCRequest Type: {request}") + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") except Exception as e: self._set_errored(e) @@ -279,13 +298,32 @@ def _handle_abort_request(self, request: RPCAbortRequest): if self.log_requests: logger.info("Aborted request %s.", request.request_id) - def _handle_health_request(self): + def _heartbeat_loop(self): + while not self._heartbeat_stop_event.wait( + timeout=self.heartbeat_interval_seconds): + # Loops until the stop event is set + self._heartbeat() + + logger.debug("Exiting MQLLMEngine heartbeat thread") + + def _heartbeat(self): + # Send unhealthy if engine has already errored if self._errored_with is not None: self._send_unhealthy(self._errored_with) - # Raises error if unhealthy. 
- self.engine.check_health() - self._send_healthy() + # Check for life of the main loop + elif time.time() - self._last_alive_time > self.last_alive_threshold: + self._send_unhealthy(RuntimeError("Engine loop has died")) + + else: + # Otherwise- check health of the engine + # self.engine.check_health() raises on unhealthy + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" @@ -295,12 +333,14 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_healthy(self): """Send HEALTHY message to RPCClient.""" - self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" - error_bytes = pickle.dumps(error) - self.health_socket.send_multipart((error_bytes, ), copy=False) + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): @@ -313,6 +353,9 @@ def _set_errored(self, e: BaseException): if self._errored_with is None: self._errored_with = e + def _alive(self): + self._last_alive_time = time.time() + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str):
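
Note (illustrative, not part of the patch): the change above replaces the client-driven
RPCHealthRequest probe with engine-driven heartbeats: the engine pushes a periodic message
on the health socket and the client treats a poll timeout as engine death. A minimal
standalone sketch of that pattern follows, assuming pyzmq; the socket path, timeout value,
and function names are invented for illustration and are not the vLLM ones.

    import threading

    import zmq

    IPC_PATH = "ipc:///tmp/heartbeat_demo"  # hypothetical path, not vLLM's
    TIMEOUT_MS = 10_000                     # stand-in for VLLM_RPC_TIMEOUT (ms)

    def engine_side(stop: threading.Event) -> None:
        """Daemon-thread role: push a heartbeat several times per client timeout."""
        ctx = zmq.Context()
        sock = ctx.socket(zmq.PUSH)
        sock.bind(IPC_PATH)
        interval = TIMEOUT_MS / 5000.0      # beat faster than the client will wait
        while not stop.wait(timeout=interval):
            sock.send(b"HEALTHY")
        sock.close(linger=0)
        ctx.term()

    def client_side() -> None:
        """Client role: a missed heartbeat within the timeout means the engine is dead."""
        ctx = zmq.Context()
        sock = ctx.socket(zmq.PULL)
        sock.connect(IPC_PATH)
        try:
            for _ in range(3):
                if sock.poll(timeout=TIMEOUT_MS) == 0:
                    # No heartbeat arrived in time; assume the engine died.
                    raise TimeoutError("No heartbeat received from engine")
                print("heartbeat:", sock.recv())
        finally:
            sock.close(linger=0)
            ctx.term()

    if __name__ == "__main__":
        stop = threading.Event()
        threading.Thread(target=engine_side, args=(stop,), daemon=True).start()
        client_side()
        stop.set()

The intervals in the real patch follow the same shape: the engine beats every
VLLM_RPC_TIMEOUT / 5000.0 seconds so several heartbeats fit inside one client poll window,
and the separate _last_alive_time threshold (three times the timeout) lets the heartbeat
thread flag a stalled engine loop even when engine.check_health() itself would still pass.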