[Serve.llm][P/D] Fix health check in prefill disagg #53937
@@ -437,10 +437,14 @@ async def __init__(
         """
         await super().__init__(llm_config)
 
-        self._engine_cls = engine_cls or self._default_engine_cls
-        self.engine = self._get_engine_class(self._llm_config)
-        await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)
+        self._engine_cls = engine_cls or self._get_default_engine_class()
+        self.engine: Optional[LLMEngine] = None
+        if self._engine_cls is not None:
+            self.engine = self._engine_cls(self._llm_config)
+            await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)
 
         # TODO (Kourosh): I think we can completely remove image retriever.
         # It was missed to get removed.
         self.image_retriever = (
             image_retriever_cls()
             if image_retriever_cls
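A minimal sketch of the opt-out path this hunk enables, assuming only what the diff shows (the subclass name is made up; the real consumer is PDProxyServer further down, which sets _default_engine_cls = None):

# Hypothetical subclass, not part of this PR: with no default engine class and
# no engine_cls argument, __init__ leaves self.engine as None and never awaits
# _start_engine().
class EnginelessServer(LLMServer):
    _default_engine_cls = None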
@@ -466,25 +470,20 @@ async def __init__(
 
         self.response_postprocessor = ResponsePostprocessor()
 
-    @property
Contributor (Author): This is useless. I don't know why we had it. Removing.
-    def _get_engine_class(self) -> Type[LLMEngine]:
+    def _get_default_engine_class(self) -> Type[LLMEngine]:
         """Helper to load the engine class from the environment variable.
 
         This is used for testing or escape-hatch for patching purposes.
         If env variable is not set, it will fallback to the default engine class.
         """
         engine_cls_path = os.environ.get(RAYLLM_VLLM_ENGINE_CLS_ENV)
         if engine_cls_path:
-            try:
-                return import_attr(engine_cls_path)
-            except AttributeError:
-                logger.warning(
-                    f"Failed to import engine class {engine_cls_path}. "
-                    f"Using the default engine class {self._engine_cls}."
-                )
-        return self._engine_cls
+            return import_attr(engine_cls_path)
+        return self._default_engine_cls
 
     async def _start_engine(self):
+        if self.engine is None:
+            raise ValueError("Engine is not set")
+
         await self.engine.start()
 
         # Push telemetry reports for the model in the current deployment.
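For context, a hedged sketch of how the environment-variable escape hatch in _get_default_engine_class could be exercised; the constant's value and the dotted path below are assumptions, not taken from this PR:

import os

# Assumption: RAYLLM_VLLM_ENGINE_CLS_ENV holds the env var's name; the value
# shown here is hypothetical, as is the target module path.
RAYLLM_VLLM_ENGINE_CLS_ENV = "RAYLLM_VLLM_ENGINE_CLS"
os.environ[RAYLLM_VLLM_ENGINE_CLS_ENV] = "my_project.engines.StubEngine"

# With the env var set, _get_default_engine_class() resolves the class via
# import_attr(); after this PR, a path that fails to import raises immediately
# instead of logging a warning and falling back to the default engine class.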
@@ -616,7 +615,13 @@ async def check_health(self) -> None:
         Check the health of the replica. Does not return anything. Raise error when
         the engine is dead and needs to be restarted.
         """
-        return await self.engine.check_health()
+        if self.engine is None:
+            return
+        try:
+            return await self.engine.check_health()
+        except Exception as e:
+            logger.error("Engine health check failed in LLMServer.check_health: %s", e)
+            raise e
 
     async def embeddings(self, request: EmbeddingRequest) -> LLMEmbeddingsResponse:
         """Runs an embeddings request to the vllm engine, and return the response.
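A rough sketch, not from this PR, of how Ray Serve drives a replica's check_health: the controller calls it periodically and restarts the replica if it raises. The deployment name and intervals below are illustrative.

from ray import serve

@serve.deployment(health_check_period_s=10, health_check_timeout_s=30)
class EngineReplica:
    def __init__(self):
        self.engine = None  # the engine may not exist, as in engine-less servers

    async def check_health(self):
        # Mirrors the behavior above: with no engine there is nothing to probe,
        # so the replica reports healthy; any raised exception marks it
        # unhealthy and the Serve controller restarts it.
        if self.engine is None:
            return
        await self.engine.check_health()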
@@ -1,6 +1,5 @@
 """Using Ray Serve to deploy LLM models with P/D disaggregation.
 """
-import asyncio
 import logging
 import uuid
 from typing import Any, AsyncGenerator, Dict, Union
 
@@ -63,6 +62,7 @@ def parse_configs_and_cast_type(config: Union[str, LLMConfig]) -> LLMConfig:
 
 
 class PDProxyServer(LLMServer):
+    _default_engine_cls = None
     """
     Proxy between P/D LLM servers.
@@ -83,22 +83,13 @@ async def __init__(
         prefill_server: DeploymentHandle,
         decode_server: DeploymentHandle,
     ):
-        class FakeEngine:
-            """Provide a fake engine such that proxy don't really start any engine."""
-
-            def __init__(self, *args, **kwargs):
-                pass
-
-            async def start(self, *args, **kwargs):
-                pass
-
         # We pass `llm_config` here to let super() extract the model_id, such that /v1/models
         # endpoint can work correctly.
         # TODO(lk-chen): refactor LLMRouter <-> LLMServer such that router query model_id through
         # API, instead of passing it in as an argument.
         await super().__init__(
             llm_config,
-            engine_cls=FakeEngine,
         )
 
         self.prefill_server = prefill_server
@@ -160,13 +151,6 @@ async def _predict(
         ):
             yield chunk
 
-    async def check_health(self) -> None:
Contributor (Author): These must be removed. In general, a deployment's health check is not tied to the health checks of its child deployments.

Contributor (Author): cc @lk-chen fyi
-        """Check the health of the llm engine."""
-        await asyncio.gather(
-            self.prefill_server.check_health.remote(),
-            self.decode_server.check_health.remote(),
-        )
 
     @classmethod
     def as_deployment(cls) -> serve.Deployment:
         """Turns PDProxyServer into a Ray Serve deployment."""
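A hedged illustration of the review comment's point: the Serve controller health-checks every deployment's replicas independently, so a parent deployment gains nothing by fanning out to its children. The deployment names below are made up.

from ray import serve
from ray.serve.handle import DeploymentHandle

@serve.deployment
class Prefill:
    async def check_health(self):
        ...  # probed directly by the Serve controller

@serve.deployment
class Proxy:
    def __init__(self, prefill: DeploymentHandle):
        self.prefill = prefill

    async def check_health(self):
        # No `await self.prefill.check_health.remote()` here: if Prefill's
        # replicas become unhealthy, the controller restarts them on its own;
        # this method only reports the proxy's own health.
        return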
@@ -232,12 +232,6 @@ async def _setup_handle_and_config_maps(
 
-    async def check_health(self):
-        await self._init_completed.wait()
-        await asyncio.gather(
Contributor (Author): Same thing applies here.
-            *[
-                handle.check_health.remote()
-                for handle in self._default_serve_handles.values()
-            ]
-        )
 
     def _get_configured_serve_handle(self, model_id: str):
         """Gets a ServeHandle to a model deployment.
@@ -170,8 +170,6 @@ async def test_check_health(self, llm_config: LLMConfig):
 
         await router.check_health()
 
-        assert server.check_health.remote.call_count == 1
Contributor (Author): Testing the router's health check has nothing to do with the server's health check.
 
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
@@ -0,0 +1,95 @@
+import time
+from typing import Literal, List, Generator
+
+import pytest
+import ray
+from ray import serve
+from ray.serve.llm import LLMConfig, ModelLoadingConfig, build_llm_deployment
+
+MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
+RAY_MODEL_ID = "qwen-0.5b"
+
+
+def get_llm_config(
+    tensor_parallel_size: int = 1,
+) -> LLMConfig:
+    """Create LLMConfig with specified parallelism parameters."""
+    return LLMConfig(
+        model_loading_config=ModelLoadingConfig(
+            model_id=RAY_MODEL_ID,
+            model_source=MODEL_ID,
+        ),
+        deployment_config=dict(
+            name="test",
+            num_replicas=2,
+        ),
+        engine_kwargs=dict(
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=True,
+        ),
+        runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
+    )
+
+
+def find_replica_ids(deployment_name: str) -> List[str]:
+    actors = ray.util.list_named_actors("serve")
+    found_replica_ids = []
+    for actor in actors:
+        if deployment_name in actor["name"]:
+            found_replica_ids.append(actor["name"])
+    return found_replica_ids
+
+
+def kill_replica(replica_id: str) -> None:
+    actor = ray.get_actor(replica_id, namespace="serve")
+    ray.kill(actor)
+
+
+@pytest.fixture(name="app", scope="function")
+def start_ray_serve(
+    tensor_parallel_size: int = 1,
+) -> Generator:
+    """Start Ray Serve with specified parallelism parameters."""
+    llm_config: LLMConfig = get_llm_config(tensor_parallel_size)
+    app = build_llm_deployment(llm_config, name_prefix="LLM:")
+    serve.run(app, blocking=False)
+    yield app
+    serve.shutdown()
+
+
+def wait_for_deployment_status(
+    deployment_name: str, status: Literal["HEALTHY", "UNHEALTHY"], timeout_s: int = 120
+) -> None:
+    s = time.time()
+    while time.time() - s < timeout_s:
+        print(f"Waiting for deployment {deployment_name} to become {status}")
+        state = serve.status()
+        if state.applications["default"].deployments[deployment_name].status == status:
+            return
+        time.sleep(1)
+    raise TimeoutError(
+        f"Deployment {deployment_name} did not become "
+        f"{status} within {timeout_s} seconds"
+    )
+
+
+def test_recovery_from_replica_failure(app) -> None:
+    """Tests that the deployment recovers from replica failure."""
+    dname = "LLM:test"
+    wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)
+
+    # Kill both replicas
+    replica_ids = find_replica_ids(dname)
+    for replica_id in replica_ids:
+        print(f"Killing replica {replica_id}")
+        kill_replica(replica_id)
+
+    # wait for deployment to get unhealthy
+    wait_for_deployment_status(dname, "UNHEALTHY", timeout_s=60)
+
+    # Wait again for deployment to get healthy
+    wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)
+
+
+if __name__ == "__main__":
+    pytest.main(["-xvs", __file__])
@@ -27,6 +27,7 @@ async def test_engine_metrics():
         model="Qwen/Qwen2.5-0.5B-Instruct",
         dtype="auto",
         disable_log_stats=False,
+        enforce_eager=True,
Contributor (Author): Making the tests faster.
     )
 
     engine = AsyncLLM.from_engine_args(
 
@@ -75,6 +76,7 @@ def remote_model_app(request):
             enable_chunked_prefill=True,
             enable_prefix_caching=True,
             trust_remote_code=remote_code,
+            enforce_eager=True,
         ),