Commit 4706bf1

[serve][llm] Unify and Extend Builder Configuration for LLM Deployments (#57724)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
1 parent c8c446d commit 4706bf1

File tree

14 files changed (+724, -391 lines changed)

python/ray/llm/_internal/common/base_pydantic.py

Lines changed: 6 additions & 0 deletions
@@ -23,3 +23,9 @@ def parse_yaml(cls: Type[ModelT], file, **kwargs) -> ModelT:
         kwargs.setdefault("Loader", yaml.SafeLoader)
         dict_args = yaml.load(file, **kwargs)
         return cls.model_validate(dict_args)
+
+    @classmethod
+    def from_file(cls: Type[ModelT], path: str, **kwargs) -> ModelT:
+        """Load a model from a YAML file path."""
+        with open(path, "r") as f:
+            return cls.parse_yaml(f, **kwargs)
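
For reference, a minimal usage sketch of the new from_file helper, assuming a hypothetical config.yaml whose keys match the target model's schema (the file name and YAML contents are illustrative, not part of this commit):

    from ray.serve.llm import LLMConfig

    # Hypothetical YAML file holding LLMConfig fields, e.g.:
    #   model_loading_config:
    #     model_id: my-model
    # from_file opens the path and delegates to parse_yaml, which
    # safe-loads the YAML and runs pydantic validation.
    config = LLMConfig.from_file("config.yaml")

Since from_file lives on BaseModelExtended, every model that extends it (LLMConfig included) picks up the same loader, which is what lets PDServingArgs below accept plain file paths for its prefill and decode configs.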

python/ray/llm/_internal/serve/configs/server_models.py

Lines changed: 1 addition & 1 deletion
@@ -375,7 +375,7 @@ def validate_experimental_configs(cls, value: Dict[str, Any]) -> Dict[str, Any]:
 
     @model_validator(mode="after")
     def _check_log_stats_with_metrics(self):
-        # Require disable_log_stats is not set to True when log_engine_metrics is enabled.
+        """Validate that disable_log_stats isn't enabled when log_engine_metrics is enabled."""
         if self.log_engine_metrics and self.engine_kwargs.get("disable_log_stats"):
             raise ValueError(
                 "disable_log_stats cannot be set to True when log_engine_metrics is enabled. "
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+"""Using Ray Serve to deploy LLM models with P/D disaggregation.
+"""
+from typing import Any, Optional, Union
+
+from pydantic import Field, field_validator, model_validator
+
+from ray import serve
+from ray.llm._internal.common.base_pydantic import BaseModelExtended
+from ray.llm._internal.common.dict_utils import deep_merge_dicts
+from ray.llm._internal.serve.deployments.prefill_decode_disagg.pd import PDProxyServer
+from ray.llm._internal.serve.deployments.routers.builder_ingress import (
+    IngressClsConfig,
+    load_class,
+)
+from ray.llm._internal.serve.deployments.routers.router import (
+    make_fastapi_ingress,
+)
+from ray.serve.deployment import Application
+from ray.serve.llm import (
+    LLMConfig,
+    build_llm_deployment,
+)
+
+
+class ProxyClsConfig(BaseModelExtended):
+    proxy_cls: Union[str, type[PDProxyServer]] = Field(
+        default=PDProxyServer,
+        description="The proxy class or the class module path to use.",
+    )
+
+    proxy_extra_kwargs: Optional[dict] = Field(
+        default_factory=dict,
+        description="The kwargs to bind to the proxy deployment. This will be passed to the proxy class constructor.",
+    )
+
+    @field_validator("proxy_cls")
+    @classmethod
+    def validate_class(
+        cls, value: Union[str, type[PDProxyServer]]
+    ) -> type[PDProxyServer]:
+        if isinstance(value, str):
+            return load_class(value)
+        return value
+
+
+class PDServingArgs(BaseModelExtended):
+    """Schema for P/D serving args."""
+
+    prefill_config: Union[str, dict, LLMConfig]
+    decode_config: Union[str, dict, LLMConfig]
+    proxy_cls_config: Union[dict, ProxyClsConfig] = Field(
+        default_factory=ProxyClsConfig,
+        description="The configuration for the proxy class.",
+    )
+    proxy_deployment_config: Optional[dict] = Field(
+        default_factory=dict,
+        description="The Ray @server.deployment options for the proxy server.",
+    )
+    ingress_cls_config: Union[dict, IngressClsConfig] = Field(
+        default_factory=IngressClsConfig,
+        description="The configuration for the ingress class.",
+    )
+    ingress_deployment_config: Optional[dict] = Field(
+        default_factory=dict,
+        description="The Ray @server.deployment options for the ingress.",
+    )
+
+    @field_validator("prefill_config", "decode_config")
+    @classmethod
+    def _validate_llm_config(cls, value: Any) -> LLMConfig:
+        if isinstance(value, str):
+            return LLMConfig.from_file(value)
+        elif isinstance(value, dict):
+            return LLMConfig.model_validate(value)
+        elif isinstance(value, LLMConfig):
+            return value
+        else:
+            raise TypeError(f"Invalid LLMConfig type: {type(value)}")
+
+    @field_validator("proxy_cls_config")
+    @classmethod
+    def _validate_proxy_cls_config(
+        cls, value: Union[dict, ProxyClsConfig]
+    ) -> ProxyClsConfig:
+        if isinstance(value, dict):
+            return ProxyClsConfig.model_validate(value)
+        return value
+
+    @field_validator("ingress_cls_config")
+    @classmethod
+    def _validate_ingress_cls_config(
+        cls, value: Union[dict, IngressClsConfig]
+    ) -> IngressClsConfig:
+        if isinstance(value, dict):
+            return IngressClsConfig.model_validate(value)
+        return value
+
+    @model_validator(mode="after")
+    def _validate_model_ids(self):
+        """Validate that prefill and decode configs use the same model ID."""
+        if self.prefill_config.model_id != self.decode_config.model_id:
+            raise ValueError("P/D model id mismatch")
+        return self
+
+    @model_validator(mode="after")
+    def _validate_kv_transfer_config(self):
+        """Validate that kv_transfer_config is set for both prefill and decode configs."""
+        for config in [self.prefill_config, self.decode_config]:
+            if config.engine_kwargs.get("kv_transfer_config") is None:
+                raise ValueError(
+                    "kv_transfer_config is required for P/D disaggregation"
+                )
+        return self
+
+
+def build_pd_openai_app(pd_serving_args: dict) -> Application:
+    """Build a deployable application utilizing prefill/decode disaggregation."""
+    pd_config = PDServingArgs.model_validate(pd_serving_args)
+
+    prefill_deployment = build_llm_deployment(
+        pd_config.prefill_config, name_prefix="Prefill:"
+    )
+    decode_deployment = build_llm_deployment(
+        pd_config.decode_config, name_prefix="Decode:"
+    )
+
+    # Get the default deployment options from the PDProxyServer class based on the prefill and decode configs.
+    proxy_cls_config = pd_config.proxy_cls_config
+
+    pd_proxy_server_options = proxy_cls_config.proxy_cls.get_deployment_options(
+        pd_config.prefill_config, pd_config.decode_config
+    )
+
+    # Override if the proxy deployment config is provided.
+    if pd_config.proxy_deployment_config:
+        pd_proxy_server_options = deep_merge_dicts(
+            pd_proxy_server_options, pd_config.proxy_deployment_config
+        )
+
+    proxy_server_deployment = (
+        serve.deployment(proxy_cls_config.proxy_cls)
+        .options(**pd_proxy_server_options)
+        .bind(
+            prefill_server=prefill_deployment,
+            decode_server=decode_deployment,
+            **proxy_cls_config.proxy_extra_kwargs,
+        )
+    )
+
+    ingress_cls_config = pd_config.ingress_cls_config
+    ingress_options = ingress_cls_config.ingress_cls.get_deployment_options(
+        [pd_config.prefill_config, pd_config.decode_config]
+    )
+
+    if pd_config.ingress_deployment_config:
+        ingress_options = deep_merge_dicts(
+            ingress_options, pd_config.ingress_deployment_config
+        )
+
+    ingress_cls = make_fastapi_ingress(ingress_cls_config.ingress_cls)
+    return serve.deployment(ingress_cls, **ingress_options).bind(
+        llm_deployments=[proxy_server_deployment],
+        **ingress_cls_config.ingress_extra_kwargs,
+    )
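
The new builder accepts file paths, dicts, or LLMConfig objects for the prefill and decode configs, plus optional proxy and ingress overrides that are deep-merged over the class defaults. A usage sketch, assuming build_pd_openai_app keeps its public re-export from ray.serve.llm; all model ids, paths, and option values are illustrative, and kv_transfer_config must now be supplied explicitly since the old NixlConnector default (removed in the next diff) is gone:

    from ray import serve
    from ray.serve.llm import build_pd_openai_app

    # Illustrative NIXL-style transfer config, now required for both roles.
    kv_transfer_config = dict(kv_connector="NixlConnector", kv_role="kv_both")

    pd_app = build_pd_openai_app(
        dict(
            prefill_config=dict(
                model_loading_config=dict(model_id="my-model"),
                engine_kwargs=dict(kv_transfer_config=kv_transfer_config),
            ),
            decode_config="decode_config.yaml",  # hypothetical YAML path
            # Deep-merged over the proxy class's default deployment options.
            proxy_deployment_config=dict(max_ongoing_requests=128),
        )
    )
    serve.run(pd_app)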
Lines changed: 10 additions & 89 deletions
@@ -1,13 +1,9 @@
 """Using Ray Serve to deploy LLM models with P/D disaggregation.
 """
 import logging
-import uuid
 from typing import Any, AsyncGenerator, Dict, Union
 
-from pydantic import Field
-
-from ray import serve
-from ray.llm._internal.common.base_pydantic import BaseModelExtended
+from ray.llm._internal.serve.configs.constants import DEFAULT_MAX_ONGOING_REQUESTS
 from ray.llm._internal.serve.configs.openai_api_models import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -18,53 +14,15 @@
     ErrorResponse,
 )
 from ray.llm._internal.serve.deployments.llm.llm_server import LLMServer
-from ray.llm._internal.serve.deployments.routers.builder_ingress import (
-    parse_args as parse_llm_configs,
-)
-from ray.llm._internal.serve.deployments.routers.router import (
-    OpenAiIngress,
-    make_fastapi_ingress,
-)
-from ray.serve.deployment import Application
 from ray.serve.handle import DeploymentHandle
-from ray.serve.llm import (
-    LLMConfig,
-    build_llm_deployment,
-)
+from ray.serve.llm import LLMConfig
 
 logger = logging.getLogger(__name__)
 RequestType = Union[ChatCompletionRequest, CompletionRequest]
 
-
-class PDServingArgs(BaseModelExtended):
-    """Schema for P/D serving args."""
-
-    prefill_config: Union[str, LLMConfig]
-    decode_config: Union[str, LLMConfig]
-    proxy_deployment_config: Dict[str, Any] = Field(
-        default_factory=dict,
-        description="""
-        The Ray @server.deployment options for the proxy server.
-        """,
-    )
-
-    def parse_args(self) -> "PDServingArgs":
-        """Converts this LLMServingArgs object into an DeployArgs object."""
-
-        def parse_configs_and_cast_type(config: Union[str, LLMConfig]) -> LLMConfig:
-            # ray.serve.llm.__init__ imports internal LLMConfig, and extends it to external-facing LLMConfig.
-            # parse_llm_configs returns internal LLMConfig, while {prefill, decode}_configs expect external-facing LLMConfig.
-            # So the model_dump() here is to convert the type, to satisfy pydantic.
-            # TODO(lk-chen): refactor llm_config parsing to avoid this model_dump, and make llm_config more reusable.
-            config = parse_llm_configs([config])[0]
-            return LLMConfig(**config.model_dump())
-
-        return PDServingArgs(
-            # Parse string file path into LLMConfig
-            prefill_config=parse_configs_and_cast_type(self.prefill_config),
-            decode_config=parse_configs_and_cast_type(self.decode_config),
-            proxy_deployment_config=self.proxy_deployment_config,
-        )
+DEFAULT_PD_PROXY_SERVER_OPTIONS = {
+    "max_ongoing_requests": DEFAULT_MAX_ONGOING_REQUESTS,
+}
 
 
 class PDProxyServer(LLMServer):
@@ -171,45 +129,8 @@ async def completions(
     ) -> AsyncGenerator[Union[str, CompletionResponse, ErrorResponse], None]:
         return self._handle_request(request)
 
-
-def build_pd_openai_app(pd_serving_args: dict) -> Application:
-    """Build a deployable application utilizing prefill/decode disaggregation."""
-
-    pd_config = PDServingArgs.model_validate(pd_serving_args).parse_args()
-
-    model_id = pd_config.decode_config.model_id
-    assert model_id == pd_config.prefill_config.model_id, "P/D model id mismatch"
-
-    for config in [pd_config.prefill_config, pd_config.decode_config]:
-        if "kv_transfer_config" not in config.engine_kwargs:
-            config.update_engine_kwargs(
-                kv_transfer_config=dict(
-                    kv_connector="NixlConnector",
-                    kv_role="kv_both",
-                    engine_id=str(uuid.uuid4()),
-                )
-            )
-
-    prefill_deployment = build_llm_deployment(
-        pd_config.prefill_config, name_prefix="Prefill:"
-    )
-    decode_deployment = build_llm_deployment(
-        pd_config.decode_config, name_prefix="Decode:"
-    )
-
-    proxy_server_deployment = (
-        serve.deployment(PDProxyServer)
-        .options(**pd_config.proxy_deployment_config)
-        .bind(
-            prefill_server=prefill_deployment,
-            decode_server=decode_deployment,
-        )
-    )
-
-    ingress_options = OpenAiIngress.get_deployment_options(
-        [pd_config.prefill_config, pd_config.decode_config]
-    )
-    ingress_cls = make_fastapi_ingress(OpenAiIngress)
-    return serve.deployment(ingress_cls, **ingress_options).bind(
-        llm_deployments=[proxy_server_deployment]
-    )
+    @classmethod
+    def get_deployment_options(
+        cls, prefill_config: "LLMConfig", decode_config: "LLMConfig"
+    ) -> Dict[str, Any]:
+        return DEFAULT_PD_PROXY_SERVER_OPTIONS
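
Because deployment options now come from the proxy class itself rather than from the builder, a custom proxy can be swapped in through proxy_cls_config. A sketch of overriding get_deployment_options in a hypothetical subclass (the class name and option values are illustrative):

    from typing import Any, Dict

    from ray.llm._internal.serve.deployments.prefill_decode_disagg.pd import (
        PDProxyServer,
    )
    from ray.serve.llm import LLMConfig


    class ThrottledPDProxyServer(PDProxyServer):  # hypothetical subclass
        @classmethod
        def get_deployment_options(
            cls, prefill_config: LLMConfig, decode_config: LLMConfig
        ) -> Dict[str, Any]:
            # Copy the class defaults, then tighten concurrency.
            options = dict(
                super().get_deployment_options(prefill_config, decode_config)
            )
            options["max_ongoing_requests"] = 64
            return options

The builder would pick this up via proxy_cls_config, passed either as the class object or as its module path string, e.g. proxy_cls_config=dict(proxy_cls=ThrottledPDProxyServer).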
