[Bugfix] Use heartbeats instead of health checks #8583

Open · wants to merge 8 commits into base: main
15 changes: 4 additions & 11 deletions tests/mq_llm_engine/test_error_handling.py
@@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
await client.check_health()

# Trigger an abort on the client side.
async def bad_abort_after_2s():
await asyncio.sleep(2.0)
await client.abort(request_id="foo")
# This request ID does not exist, and will cause the engine to error
Comment (Contributor Author):
This test would always hang on my dev box: it looked like the generation finished before the bad abort, causing a failure, and the engine was not able to clean up properly for the test to finish.

This still tests that the KeyError is raised and stored so that a future request fails with it.

await client.abort(request_id="foo")

# Trigger an abort in 2s from now.
abort_task = asyncio.create_task(bad_abort_after_2s())

# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# Future generation requests will now fail
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored

await abort_task

# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
7 changes: 1 addition & 6 deletions vllm/engine/multiprocessing/__init__.py
@@ -43,10 +43,6 @@ class RPCAbortRequest:
request_id: str


class RPCHealthRequest:
pass


class RPCStartupRequest(Enum):
IS_SERVER_READY = 1

@@ -56,8 +52,7 @@ class RPCStartupResponse:
tracing_enabled: bool


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]

51 changes: 22 additions & 29 deletions vllm/engine/multiprocessing/client.py
@@ -20,9 +20,8 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
@@ -95,9 +94,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

# IPC path for ack of check_health requests.
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -124,34 +123,28 @@ def get_data_socket(self) -> Iterator[Socket]:
finally:
socket.close(linger=0)

async def run_check_health_loop(self, timeout: int):
"""Background loop that continually probes the RPCServer for health.
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.
The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for
heartbeats.
"""

try:
while True:
if await self.health_socket.poll(timeout=timeout) == 0:
# Wakeup every N seconds and do a health probe.
await self._send_one_way_rpc_request(
RPCHealthRequest(), self.input_socket)

# Wait for ack from the health socket.
await self._await_ack(error_message="Health check failed.",
socket=self.health_socket)
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# No heartbeat was received. Set error and exit the loop
self._set_errored(
TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
break

else:
# Server sent a health status message unprompted.
# Heartbeat received; check the message
await self._check_success(
error_message="Health check failed.",
socket=self.health_socket)
error_message="Heartbeat failed.",
socket=self.heartbeat_socket)

logger.debug("Health probe successful.")
logger.debug("Heartbeat successful.")

except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
@@ -234,7 +227,7 @@ async def setup(self):

# Start health_loop.
self.health_loop = asyncio.create_task(
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))

def close(self):
"""Destroy the ZeroMQ Context."""
67 changes: 52 additions & 15 deletions vllm/engine/multiprocessing/engine.py
@@ -1,5 +1,7 @@
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

@@ -15,10 +17,10 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
@@ -84,16 +86,25 @@ def __init__(self,
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

# Send health status back to client.
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

# Error state.
self._errored_with: Optional[BaseException] = None

# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for.
# VLLM_RPC_TIMEOUT is in ms; convert it to an interval in seconds.
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

self._last_alive_time = time.time()

@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
@@ -124,6 +135,8 @@ def start(self):
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
@@ -137,6 +150,7 @@ def start(self):
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0)
del self.engine

@@ -175,9 +189,11 @@ def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""

while True:
self._alive()
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")

@@ -193,7 +209,7 @@

def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""

self._alive()
try:
return self.engine.step()
except SystemExit:
@@ -210,6 +226,7 @@ def handle_new_input(self):
"""Handle new input from the socket"""
try:
while self.input_socket.poll(timeout=0) != 0:
self._alive()
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)

@@ -222,10 +239,8 @@
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError("Unknown RPCRequest Type: {request}")
raise ValueError(f"Unknown RPCRequest Type: {request}")

except Exception as e:
self._set_errored(e)
@@ -272,13 +287,32 @@ def _handle_abort_request(self, request: RPCAbortRequest):
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)

def _handle_health_request(self):
def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()

logger.debug("Exiting MQLLMEngine heartbeat thread")

def _heartbeat(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)

# Check for life of the main loop
if time.time() - self._last_alive_time > 30:
self._send_unhealthy(RuntimeError("Engine loop has died"))

# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
# TODO: set _errored_with here?
# Does it matter if we're just going to kill the engine?
# Assume requests will be failing anyway if the engine is unhealthy?
self._send_unhealthy(e)

def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
@@ -288,12 +322,12 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):

def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
error_bytes = pickle.dumps(error)
self.health_socket.send_multipart((error_bytes, ), copy=False)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
@@ -306,6 +340,9 @@ def _set_errored(self, e: BaseException):
if self._errored_with is None:
self._errored_with = e

def _alive(self):
self._last_alive_time = time.time()


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str):