37 changes: 21 additions & 16 deletions python/ray/llm/_internal/serve/deployments/llm/llm_server.py
@@ -437,10 +437,14 @@ async def __init__(
"""
await super().__init__(llm_config)

self._engine_cls = engine_cls or self._default_engine_cls
self.engine = self._get_engine_class(self._llm_config)
await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)

self._engine_cls = engine_cls or self._get_default_engine_class()
self.engine: Optional[LLMEngine] = None
if self._engine_cls is not None:
self.engine = self._engine_cls(self._llm_config)
await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)

# TODO (Kourosh): I think we can completely remove image retriever.
# It should have been removed earlier but was missed.
self.image_retriever = (
image_retriever_cls()
if image_retriever_cls
@@ -466,25 +470,20 @@ async def __init__(

self.response_postprocessor = ResponsePostprocessor()

@property
Contributor Author: This is useless. I don't know why we had it. Removing.
def _get_engine_class(self) -> Type[LLMEngine]:
def _get_default_engine_class(self) -> Type[LLMEngine]:
"""Helper to load the engine class from the environment variable.

This is used for testing or escape-hatch for patching purposes.
If the env variable is not set, it falls back to the default engine class.
"""
engine_cls_path = os.environ.get(RAYLLM_VLLM_ENGINE_CLS_ENV)
if engine_cls_path:
try:
return import_attr(engine_cls_path)
except AttributeError:
logger.warning(
f"Failed to import engine class {engine_cls_path}. "
f"Using the default engine class {self._engine_cls}."
)
return self._engine_cls
return import_attr(engine_cls_path)
return self._default_engine_cls

async def _start_engine(self):
if self.engine is None:
raise ValueError("Engine is not set")

await self.engine.start()

# Push telemetry reports for the model in the current deployment.
@@ -616,7 +615,13 @@ async def check_health(self) -> None:
Check the health of the replica. Does not return anything. Raise error when
the engine is dead and needs to be restarted.
"""
return await self.engine.check_health()
if self.engine is None:
return
try:
return await self.engine.check_health()
except Exception as e:
logger.error("Engine health check failed in LLMServer.check_health: %s", e)
raise e

async def embeddings(self, request: EmbeddingRequest) -> LLMEmbeddingsResponse:
"""Runs an embeddings request to the vllm engine, and return the response.
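A hedged illustration (not part of the diff) of how the escape hatch in _get_default_engine_class could be exercised. The engine below only implements the surface visible in this diff (constructor, start, check_health), and the literal environment-variable name is an assumption, since the diff only references it through the constant RAYLLM_VLLM_ENGINE_CLS_ENV:

import os

# my_engines.py -- a hypothetical stand-in engine used for testing or patching.
class PatchedEngine:
    def __init__(self, llm_config):
        self.llm_config = llm_config

    async def start(self):
        pass  # e.g. wire up a mocked backend instead of starting vLLM

    async def check_health(self):
        pass  # always report healthy

# Point the escape hatch at the class path before the deployment is created.
# With the new code, import errors propagate instead of being swallowed by a
# warning-and-fallback, which surfaces misconfiguration immediately.
os.environ["RAYLLM_VLLM_ENGINE_CLS"] = "my_engines.PatchedEngine"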
@@ -1,4 +1,3 @@
import asyncio
import os
import re
import time
@@ -816,9 +815,9 @@ async def check_health(self) -> None:
raise RuntimeError(f"{type(self.engine)} does not support health check.")

try:
return await asyncio.wait_for(self.engine.check_health(), timeout=15)
await self.engine.check_health()
except BaseException as e:
logger.exception("Healthcheck failed. The replica will be restarted")
logger.error("Healthcheck failed. The replica will be restarted")
raise e from None

@staticmethod
@@ -1,6 +1,5 @@
"""Using Ray Serve to deploy LLM models with P/D disaggregation.
"""
import asyncio
import logging
import uuid
from typing import Any, AsyncGenerator, Dict, Union
@@ -63,6 +62,7 @@ def parse_configs_and_cast_type(config: Union[str, LLMConfig]) -> LLMConfig:


class PDProxyServer(LLMServer):
_default_engine_cls = None
"""
Proxy between P/D LLM servers.

@@ -83,22 +83,13 @@ async def __init__(
prefill_server: DeploymentHandle,
decode_server: DeploymentHandle,
):
class FakeEngine:
"""Provide a fake engine such that proxy don't really start any engine."""

def __init__(self, *args, **kwargs):
pass

async def start(self, *args, **kwargs):
pass

# We pass `llm_config` here to let super() extract the model_id, such that /v1/models
# endpoint can work correctly.
# TODO(lk-chen): refactor LLMRouter <-> LLMServer such that router query model_id through
# API, instead of passing it in as an argument.
await super().__init__(
llm_config,
engine_cls=FakeEngine,
)

self.prefill_server = prefill_server
@@ -160,13 +151,6 @@ async def _predict(
):
yield chunk

async def check_health(self) -> None:
Contributor Author: These must be removed. In general, the health check of a deployment is not bound to the health check of its child deployments.
Contributor Author: cc @lk-chen fyi
"""Check the health of the llm engine."""
await asyncio.gather(
self.prefill_server.check_health.remote(),
self.decode_server.check_health.remote(),
)

@classmethod
def as_deployment(cls) -> serve.Deployment:
"""Turns PDProxyServer into a Ray Serve deployment."""
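A minimal sketch of the pattern that replaces the FakeEngine hack: a subclass sets _default_engine_cls = None, so the new __init__ resolves the engine class to None, leaves self.engine as None, and never calls _start_engine(); check_health() then returns immediately. Class and argument names are illustrative; the import path is taken from the file header above:

from ray.llm._internal.serve.deployments.llm.llm_server import LLMServer


class EnginelessProxy(LLMServer):
    # Opt out of engine creation entirely.
    _default_engine_cls = None

    async def __init__(self, llm_config, prefill_server, decode_server):
        # super().__init__ still records the model metadata so /v1/models works.
        await super().__init__(llm_config)
        self.prefill_server = prefill_server
        self.decode_server = decode_server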
@@ -232,12 +232,6 @@ async def _setup_handle_and_config_maps(

async def check_health(self):
await self._init_completed.wait()
await asyncio.gather(
Contributor Author: Same thing applies here.
*[
handle.check_health.remote()
for handle in self._default_serve_handles.values()
]
)

def _get_configured_serve_handle(self, model_id: str):
"""Gets a ServeHandle to a model deployment.
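A self-contained toy sketch of the decoupling described in the comments above: after this change the router's health check waits only for its own initialization and no longer fans out to its child deployments. Only check_health mirrors the diff; the rest of the class is illustrative:

import asyncio


class RouterLike:
    def __init__(self):
        self._init_completed = asyncio.Event()

    async def setup(self):
        # ... build handle and config maps, then mark initialization done ...
        self._init_completed.set()

    async def check_health(self):
        # Healthy once initialization has finished; child replicas report
        # their own health through their own check_health implementations.
        await self._init_completed.wait()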
@@ -170,8 +170,6 @@ async def test_check_health(self, llm_config: LLMConfig):

await router.check_health()

assert server.check_health.remote.call_count == 1
Contributor Author: Testing the router's health check has nothing to do with the server's health check.


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
95 changes: 95 additions & 0 deletions release/llm_tests/serve/test_llm_serve_fault_tolerance.py
@@ -0,0 +1,95 @@
import time
from typing import Literal, List, Generator

import pytest
import ray
from ray import serve
from ray.serve.llm import LLMConfig, ModelLoadingConfig, build_llm_deployment

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
RAY_MODEL_ID = "qwen-0.5b"


def get_llm_config(
tensor_parallel_size: int = 1,
) -> LLMConfig:
"""Create LLMConfig with specified parallelism parameters."""
return LLMConfig(
model_loading_config=ModelLoadingConfig(
model_id=RAY_MODEL_ID,
model_source=MODEL_ID,
),
deployment_config=dict(
name="test",
num_replicas=2,
),
engine_kwargs=dict(
tensor_parallel_size=tensor_parallel_size,
enforce_eager=True,
),
runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
)


def find_replica_ids(deployment_name: str) -> List[str]:
actors = ray.util.list_named_actors("serve")
found_replica_ids = []
for actor in actors:
if deployment_name in actor["name"]:
found_replica_ids.append(actor["name"])
return found_replica_ids


def kill_replica(replica_id: str) -> None:
actor = ray.get_actor(replica_id, namespace="serve")
ray.kill(actor)


@pytest.fixture(name="app", scope="function")
def start_ray_serve(
tensor_parallel_size: int = 1,
) -> Generator:
"""Start Ray Serve with specified parallelism parameters."""
llm_config: LLMConfig = get_llm_config(tensor_parallel_size)
app = build_llm_deployment(llm_config, name_prefix="LLM:")
serve.run(app, blocking=False)
yield app
serve.shutdown()


def wait_for_deployment_status(
deployment_name: str, status: Literal["HEALTHY", "UNHEALTHY"], timeout_s: int = 120
) -> None:
s = time.time()
while time.time() - s < timeout_s:
print(f"Waiting for deployment {deployment_name} to become {status}")
state = serve.status()
if state.applications["default"].deployments[deployment_name].status == status:
return
time.sleep(1)
raise TimeoutError(
f"Deployment {deployment_name} did not become "
f"{status} within {timeout_s} seconds"
)


def test_recovery_from_replica_failure(app) -> None:
"""Tests that the deployment recovers from replica failure."""
dname = "LLM:test"
wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)

# Kill both replicas
replica_ids = find_replica_ids(dname)
for replica_id in replica_ids:
print(f"Killing replica {replica_id}")
kill_replica(replica_id)

# wait for deployment to get unhealthy
wait_for_deployment_status(dname, "UNHEALTHY", timeout_s=60)

# Wait again for deployment to get healthy
wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)


if __name__ == "__main__":
pytest.main(["-xvs", __file__])
2 changes: 2 additions & 0 deletions release/llm_tests/serve/test_llm_serve_integration.py
@@ -27,6 +27,7 @@ async def test_engine_metrics():
model="Qwen/Qwen2.5-0.5B-Instruct",
dtype="auto",
disable_log_stats=False,
enforce_eager=True,
Contributor Author: Making the tests faster.

)

engine = AsyncLLM.from_engine_args(
@@ -75,6 +76,7 @@ def remote_model_app(request):
enable_chunked_prefill=True,
enable_prefix_caching=True,
trust_remote_code=remote_code,
enforce_eager=True,
),
}

4 changes: 2 additions & 2 deletions release/release_tests.yaml
@@ -4288,7 +4288,7 @@
long_running: false
script: pytest -vs test_llm_serve_correctness.py

- name: llm_serve_integration
- name: llm_serve_vllm_integration_tests
frequency: nightly
python: "3.11"
group: llm-serve
@@ -4307,7 +4307,7 @@
run:
timeout: 3600
long_running: false
script: pytest -vs test_llm_serve_integration.py
script: pytest -vs test_llm_serve_integration.py test_llm_serve_fault_tolerance.py

- name: llm_serve_llama_3dot1_8B_quantized_tp1_1p1d
frequency: nightly