Commit 7e87283

[serve][llm] Fix ReplicaContext serialization error in DPRankAssigner (#58504)
Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
1 parent cd09d10 commit 7e87283

3 files changed (+5, -7 lines)
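
The fix: DPServer was passing the ReplicaContext returned by serve.get_replica_context() through the rank assigner's DeploymentHandle, which triggered a serialization error when the argument was shipped to the DPRankAssigner deployment. After this change only the node id string crosses the handle, and DPRankAssigner.register() drops its replica_ctx parameter. A minimal sketch of the resulting caller-side pattern, not the actual DPServer code (the helper name is illustrative; ray.get_runtime_context() and the awaitable handle call are existing Ray APIs used in the diff below):

import ray

async def register_with_assigner(dp_rank_assigner) -> int:
    # Only a plain, picklable string crosses the deployment handle; the
    # ReplicaContext (which previously failed to serialize) never leaves
    # the calling replica.
    node_id = ray.get_runtime_context().get_node_id()
    return await dp_rank_assigner.register.remote(node_id)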

doc/source/data/doc_code/working-with-llms/minimal_quickstart.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@
     model_source="unsloth/Llama-3.1-8B-Instruct",
     concurrency=1,  # 1 vLLM engine replica
     batch_size=32,  # 32 samples per batch
+    engine_kwargs={
+        "max_model_len": 4096,  # Fit into test GPU memory
+    }
 )

 # Build processor
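
For context, a sketch (not part of this commit) of how the quickstart config reads with the new engine_kwargs, assuming the vLLMEngineProcessorConfig and build_llm_processor API from ray.data.llm that this doc_code file exercises; the preprocess/postprocess callables below are illustrative placeholders, not the exact ones in minimal_quickstart.py:

from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    concurrency=1,  # 1 vLLM engine replica
    batch_size=32,  # 32 samples per batch
    engine_kwargs={
        "max_model_len": 4096,  # Fit into test GPU memory
    },
)

# Build processor (placeholder preprocess/postprocess for illustration).
processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["prompt"]}],
        sampling_params={"temperature": 0.0, "max_tokens": 64},
    ),
    postprocess=lambda row: {"answer": row["generated_text"]},
)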

python/ray/llm/_internal/serve/serving_patterns/data_parallel/dp_rank_assigner.py

Lines changed: 1 addition & 4 deletions
@@ -54,14 +54,11 @@ def __init__(self, dp_size: int, dp_size_per_node: Optional[int] = None):
             f"with dp_size_per_node {self.dp_size_per_node}"
         )

-    async def register(
-        self, replica_ctx: "serve.context.ReplicaContext", node_id: Optional[str] = None
-    ):
+    async def register(self, node_id: Optional[str] = None):
         """
         Register a replica and assign a rank to it.

         Args:
-            replica_ctx: The replica context.
             node_id: The node id of the replica.

         Returns:
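
To illustrate the new signature, a minimal sketch of a rank assigner whose register() needs nothing beyond an optional node id, so callers never have to ship a ReplicaContext across the handle. This is not the actual DPRankAssigner implementation; the class and attribute names are illustrative:

import asyncio
from typing import Dict, Optional

class RankAssignerSketch:
    """Hands out data-parallel ranks, tracking replica counts per node id."""

    def __init__(self, dp_size: int, dp_size_per_node: Optional[int] = None):
        self.dp_size = dp_size
        self.dp_size_per_node = dp_size_per_node
        self._next_rank = 0
        self._replicas_per_node: Dict[str, int] = {}
        self._lock = asyncio.Lock()

    async def register(self, node_id: Optional[str] = None) -> int:
        """Register a replica and assign a rank to it."""
        async with self._lock:
            if self._next_rank >= self.dp_size:
                raise RuntimeError("All data-parallel ranks are already assigned")
            if node_id is not None:
                self._replicas_per_node[node_id] = self._replicas_per_node.get(node_id, 0) + 1
            rank = self._next_rank
            self._next_rank += 1
            return rank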

python/ray/llm/_internal/serve/serving_patterns/data_parallel/dp_server.py

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,6 @@
 import logging
 import time

-from ray import serve
 from ray.experimental.collective.util import get_address_and_port
 from ray.llm._internal.serve.core.configs.llm_config import LLMConfig
 from ray.llm._internal.serve.core.server.llm_server import LLMServer
@@ -24,9 +23,8 @@ class DPServer(LLMServer):
     async def __init__(self, llm_config: LLMConfig, dp_rank_assigner: DeploymentHandle):
         self.dp_rank_assigner = dp_rank_assigner

-        replica_ctx = serve.get_replica_context()
         node_id = get_runtime_context().get_node_id()
-        self.dp_rank = await self.dp_rank_assigner.register.remote(replica_ctx, node_id)
+        self.dp_rank = await self.dp_rank_assigner.register.remote(node_id)

         logger.info(f"DP rank {self.dp_rank} registered with rank assigner")
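
A quick sanity check of what now crosses the handle: the only argument is a plain node id string, which round-trips through serialization trivially (standard pickle stands in here for Ray's own serialization; the node id value below is fake, for illustration only):

import pickle

node_id = "fake-node-id-for-illustration"
assert pickle.loads(pickle.dumps(node_id)) == node_id  # plain strings serialize cleanly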