
Commit 40a3cc2

kouroshHakha authored and elliot-barn committed
[Serve.llm][P/D] Fix health check in prefill disagg (#53937)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
1 parent a0158e1 commit 40a3cc2

8 files changed: +123 -46 lines changed

python/ray/llm/_internal/serve/deployments/llm/llm_server.py

Lines changed: 21 additions & 16 deletions
@@ -437,10 +437,14 @@ async def __init__(
         """
         await super().__init__(llm_config)
 
-        self._engine_cls = engine_cls or self._default_engine_cls
-        self.engine = self._get_engine_class(self._llm_config)
-        await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)
-
+        self._engine_cls = engine_cls or self._get_default_engine_class()
+        self.engine: Optional[LLMEngine] = None
+        if self._engine_cls is not None:
+            self.engine = self._engine_cls(self._llm_config)
+            await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)
+
+        # TODO (Kourosh): I think we can completely remove image retriever.
+        # It was missed to get removed.
         self.image_retriever = (
             image_retriever_cls()
             if image_retriever_cls
@@ -466,25 +470,20 @@ async def __init__(
 
         self.response_postprocessor = ResponsePostprocessor()
 
-    @property
-    def _get_engine_class(self) -> Type[LLMEngine]:
+    def _get_default_engine_class(self) -> Type[LLMEngine]:
         """Helper to load the engine class from the environment variable.
-
         This is used for testing or escape-hatch for patching purposes.
         If env variable is not set, it will fallback to the default engine class.
         """
         engine_cls_path = os.environ.get(RAYLLM_VLLM_ENGINE_CLS_ENV)
         if engine_cls_path:
-            try:
-                return import_attr(engine_cls_path)
-            except AttributeError:
-                logger.warning(
-                    f"Failed to import engine class {engine_cls_path}. "
-                    f"Using the default engine class {self._engine_cls}."
-                )
-        return self._engine_cls
+            return import_attr(engine_cls_path)
+        return self._default_engine_cls
 
     async def _start_engine(self):
+        if self.engine is None:
+            raise ValueError("Engine is not set")
+
         await self.engine.start()
 
         # Push telemetry reports for the model in the current deployment.
@@ -616,7 +615,13 @@ async def check_health(self) -> None:
         Check the health of the replica. Does not return anything. Raise error when
         the engine is dead and needs to be restarted.
         """
-        return await self.engine.check_health()
+        if self.engine is None:
+            return
+        try:
+            return await self.engine.check_health()
+        except Exception as e:
+            logger.error("Engine health check failed in LLMServer.check_health: %s", e)
+            raise e
 
     async def embeddings(self, request: EmbeddingRequest) -> LLMEmbeddingsResponse:
         """Runs an embeddings request to the vllm engine, and return the response.

python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py

Lines changed: 2 additions & 3 deletions
@@ -1,4 +1,3 @@
-import asyncio
 import os
 import re
 import time
@@ -816,9 +815,9 @@ async def check_health(self) -> None:
             raise RuntimeError(f"{type(self.engine)} does not support health check.")
 
         try:
-            return await asyncio.wait_for(self.engine.check_health(), timeout=15)
+            await self.engine.check_health()
         except BaseException as e:
-            logger.exception("Healthcheck failed. The replica will be restarted")
+            logger.error("Healthcheck failed. The replica will be restarted")
             raise e from None
 
     @staticmethod

python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py

Lines changed: 1 addition & 17 deletions
@@ -1,6 +1,5 @@
 """Using Ray Serve to deploy LLM models with P/D disaggregation.
 """
-import asyncio
 import logging
 import uuid
 from typing import Any, AsyncGenerator, Dict, Union
@@ -63,6 +62,7 @@ def parse_configs_and_cast_type(config: Union[str, LLMConfig]) -> LLMConfig:
 
 
 class PDProxyServer(LLMServer):
+    _default_engine_cls = None
     """
     Proxy between P/D LLM servers.
 
@@ -83,22 +83,13 @@ async def __init__(
         prefill_server: DeploymentHandle,
         decode_server: DeploymentHandle,
     ):
-        class FakeEngine:
-            """Provide a fake engine such that proxy don't really start any engine."""
-
-            def __init__(self, *args, **kwargs):
-                pass
-
-            async def start(self, *args, **kwargs):
-                pass
 
         # We pass `llm_config` here to let super() extract the model_id, such that /v1/models
         # endpoint can work correctly.
         # TODO(lk-chen): refactor LLMRouter <-> LLMServer such that router query model_id through
         # API, instead of passing it in as an argument.
         await super().__init__(
             llm_config,
-            engine_cls=FakeEngine,
         )
 
         self.prefill_server = prefill_server
@@ -160,13 +151,6 @@ async def _predict(
         ):
             yield chunk
 
-    async def check_health(self) -> None:
-        """Check the health of the llm engine."""
-        await asyncio.gather(
-            self.prefill_server.check_health.remote(),
-            self.decode_server.check_health.remote(),
-        )
-
     @classmethod
     def as_deployment(cls) -> serve.Deployment:
         """Turns PDProxyServer into a Ray Serve deployment."""

python/ray/llm/_internal/serve/deployments/routers/router.py

Lines changed: 0 additions & 6 deletions
@@ -232,12 +232,6 @@ async def _setup_handle_and_config_maps(
 
     async def check_health(self):
         await self._init_completed.wait()
-        await asyncio.gather(
-            *[
-                handle.check_health.remote()
-                for handle in self._default_serve_handles.values()
-            ]
-        )
 
     def _get_configured_serve_handle(self, model_id: str):
         """Gets a ServeHandle to a model deployment.

python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py

Lines changed: 0 additions & 2 deletions
@@ -170,8 +170,6 @@ async def test_check_health(self, llm_config: LLMConfig):
 
         await router.check_health()
 
-        assert server.check_health.remote.call_count == 1
-
 
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))

release/llm_tests/serve/test_llm_serve_fault_tolerance.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import time
+from typing import Literal, List, Generator
+
+import pytest
+import ray
+from ray import serve
+from ray.serve.llm import LLMConfig, ModelLoadingConfig, build_llm_deployment
+
+MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
+RAY_MODEL_ID = "qwen-0.5b"
+
+
+def get_llm_config(
+    tensor_parallel_size: int = 1,
+) -> LLMConfig:
+    """Create LLMConfig with specified parallelism parameters."""
+    return LLMConfig(
+        model_loading_config=ModelLoadingConfig(
+            model_id=RAY_MODEL_ID,
+            model_source=MODEL_ID,
+        ),
+        deployment_config=dict(
+            name="test",
+            num_replicas=2,
+        ),
+        engine_kwargs=dict(
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=True,
+        ),
+        runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
+    )
+
+
+def find_replica_ids(deployment_name: str) -> List[str]:
+    actors = ray.util.list_named_actors("serve")
+    found_replica_ids = []
+    for actor in actors:
+        if deployment_name in actor["name"]:
+            found_replica_ids.append(actor["name"])
+    return found_replica_ids
+
+
+def kill_replica(replica_id: str) -> None:
+    actor = ray.get_actor(replica_id, namespace="serve")
+    ray.kill(actor)
+
+
+@pytest.fixture(name="app", scope="function")
+def start_ray_serve(
+    tensor_parallel_size: int = 1,
+) -> Generator:
+    """Start Ray Serve with specified parallelism parameters."""
+    llm_config: LLMConfig = get_llm_config(tensor_parallel_size)
+    app = build_llm_deployment(llm_config, name_prefix="LLM:")
+    serve.run(app, blocking=False)
+    yield app
+    serve.shutdown()
+
+
+def wait_for_deployment_status(
+    deployment_name: str, status: Literal["HEALTHY", "UNHEALTHY"], timeout_s: int = 120
+) -> None:
+    s = time.time()
+    while time.time() - s < timeout_s:
+        print(f"Waiting for deployment {deployment_name} to become {status}")
+        state = serve.status()
+        if state.applications["default"].deployments[deployment_name].status == status:
+            return
+        time.sleep(1)
+    raise TimeoutError(
+        f"Deployment {deployment_name} did not become "
+        f"{status} within {timeout_s} seconds"
+    )
+
+
+def test_recovery_from_replica_failure(app) -> None:
+    """Tests that the deployment recovers from replica failure."""
+    dname = "LLM:test"
+    wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)
+
+    # Kill both replicas
+    replica_ids = find_replica_ids(dname)
+    for replica_id in replica_ids:
+        print(f"Killing replica {replica_id}")
+        kill_replica(replica_id)
+
+    # wait for deployment to get unhealthy
+    wait_for_deployment_status(dname, "UNHEALTHY", timeout_s=60)
+
+    # Wait again for deployment to get healthy
+    wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)
+
+
+if __name__ == "__main__":
+    pytest.main(["-xvs", __file__])

release/llm_tests/serve/test_llm_serve_integration.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@ async def test_engine_metrics():
         model="Qwen/Qwen2.5-0.5B-Instruct",
         dtype="auto",
         disable_log_stats=False,
+        enforce_eager=True,
     )
 
     engine = AsyncLLM.from_engine_args(
@@ -75,6 +76,7 @@ def remote_model_app(request):
             enable_chunked_prefill=True,
             enable_prefix_caching=True,
             trust_remote_code=remote_code,
+            enforce_eager=True,
         ),
     }
 

release/release_tests.yaml

Lines changed: 2 additions & 2 deletions
@@ -4288,7 +4288,7 @@
     long_running: false
     script: pytest -vs test_llm_serve_correctness.py
 
-- name: llm_serve_integration
+- name: llm_serve_vllm_integration_tests
   frequency: nightly
   python: "3.11"
   group: llm-serve
@@ -4307,7 +4307,7 @@
   run:
     timeout: 3600
     long_running: false
-    script: pytest -vs test_llm_serve_integration.py
+    script: pytest -vs test_llm_serve_integration.py test_llm_serve_fault_tolerance.py
 
 - name: llm_serve_llama_3dot1_8B_quantized_tp1_1p1d
   frequency: nightly
