Skip to content

Commit

Permalink
[Bugfix] Use heartbeats instead of health checks (vllm-project#8583)
Browse files Browse the repository at this point in the history
  • Loading branch information
joerunde authored and siddharth9820 committed Sep 30, 2024
1 parent beb670a commit 7c8d34f
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 63 deletions.
15 changes: 4 additions & 11 deletions tests/mq_llm_engine/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
await client.check_health()

# Trigger an abort on the client side.
async def bad_abort_after_2s():
await asyncio.sleep(2.0)
await client.abort(request_id="foo")
# This request ID does not exist, and will cause the engine to error
await client.abort(request_id="foo")

# Trigger an abort in 2s from now.
abort_task = asyncio.create_task(bad_abort_after_2s())

# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# Future generation requests will now fail
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored

await abort_task

# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
Expand Down
7 changes: 1 addition & 6 deletions vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ class RPCAbortRequest:
request_id: str


class RPCHealthRequest:
pass


class RPCStartupRequest(Enum):
IS_SERVER_READY = 1

Expand All @@ -56,8 +52,7 @@ class RPCStartupResponse:
tracing_enabled: bool


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]

Expand Down
51 changes: 22 additions & 29 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
Expand Down Expand Up @@ -95,9 +94,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

# IPC path for ack of check_health requests.
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
Expand All @@ -124,34 +123,28 @@ def get_data_socket(self) -> Iterator[Socket]:
finally:
socket.close(linger=0)

async def run_check_health_loop(self, timeout: int):
"""Background loop that continually probes the RPCServer for health.
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.
The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for
heartbeats.
"""

try:
while True:
if await self.health_socket.poll(timeout=timeout) == 0:
# Wakeup every N seconds and do a health probe.
await self._send_one_way_rpc_request(
RPCHealthRequest(), self.input_socket)

# Wait for ack from the health socket.
await self._await_ack(error_message="Health check failed.",
socket=self.health_socket)
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# No heartbeat was received. Set error and exit the loop
self._set_errored(
TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
break

else:
# Server sent a health status message unprompted.
# Heartbeat received- check the message
await self._check_success(
error_message="Health check failed.",
socket=self.health_socket)
error_message="Heartbeat failed.",
socket=self.heartbeat_socket)

logger.debug("Health probe successful.")
logger.debug("Heartbeat successful.")

except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
Expand Down Expand Up @@ -234,7 +227,7 @@ async def setup(self):

# Start health_loop.
self.health_loop = asyncio.create_task(
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))

def close(self):
"""Destroy the ZeroMQ Context."""
Expand Down
77 changes: 60 additions & 17 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

Expand All @@ -15,10 +17,10 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
Expand Down Expand Up @@ -91,16 +93,30 @@ def __init__(self,
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

# Send health status back to client.
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

# Error state.
self._errored_with: Optional[BaseException] = None

# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0

@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
Expand Down Expand Up @@ -131,6 +147,8 @@ def start(self):
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
Expand All @@ -144,6 +162,7 @@ def start(self):
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0)
del self.engine

Expand Down Expand Up @@ -182,9 +201,11 @@ def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""

while True:
self._alive()
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")

Expand All @@ -200,7 +221,6 @@ def run_engine_loop(self):

def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""

try:
return self.engine.step()
except SystemExit:
Expand Down Expand Up @@ -229,10 +249,9 @@ def handle_new_input(self):
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError("Unknown RPCRequest Type: {request}")
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")

except Exception as e:
self._set_errored(e)
Expand Down Expand Up @@ -279,13 +298,32 @@ def _handle_abort_request(self, request: RPCAbortRequest):
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)

def _handle_health_request(self):
def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()

logger.debug("Exiting MQLLMEngine heartbeat thread")

def _heartbeat(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)

# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))

else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)

def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
Expand All @@ -295,12 +333,14 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):

def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
error_bytes = pickle.dumps(error)
self.health_socket.send_multipart((error_bytes, ), copy=False)
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
Expand All @@ -313,6 +353,9 @@ def _set_errored(self, e: BaseException):
if self._errored_with is None:
self._errored_with = e

def _alive(self):
self._last_alive_time = time.time()


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str):
Expand Down

0 comments on commit 7c8d34f

Please sign in to comment.