Skip to content

Commit fa625a6

Browse files
authored
[3/n] [Serve] Defer rank assignment after replica is allocated (#58477)
**Summary** Modified replica rank assignment to defer rank allocation until the replica is actually allocated, rather than assigning it during the startup call. This is necessary when we want to add node local rank in future, in order to support node rank and node local rank we need to know the node_id which is only known after replica is allocated. **Changes** - Changed `start()` method signature to accept `assign_rank_callback` instead of a pre-assigned `rank` parameter - Rank is now assigned after `_allocated_obj_ref` is resolved, ensuring the replica is allocated before rank assignment - Pass rank to `initialize_and_get_metadata()` method on the replica actor, allowing rank to be set during initialization - Updated `ReplicaBase.initialize()` to accept rank as a parameter and set it along with the internal replica context - Added `PENDING_INITIALIZATION` status check to handle cases where `_ready_obj_ref` is not yet set Next PR #58479 --------- Signed-off-by: abrar <abrar@anyscale.com>
1 parent ad35438 commit fa625a6

File tree

4 files changed

+95
-54
lines changed

4 files changed

+95
-54
lines changed

python/ray/serve/_private/deployment_state.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def __init__(
254254
self._docs_path: Optional[str] = None
255255
self._route_patterns: Optional[List[str]] = None
256256
# Rank assigned to the replica.
257+
self._assign_rank_callback: Optional[Callable[[ReplicaID], ReplicaRank]] = None
257258
self._rank: Optional[ReplicaRank] = None
258259
# Populated in `on_scheduled` or `recover`.
259260
self._actor_handle: ActorHandle = None
@@ -445,14 +446,16 @@ def initialization_latency_s(self) -> Optional[float]:
445446
return self._initialization_latency_s
446447

447448
def start(
448-
self, deployment_info: DeploymentInfo, rank: ReplicaRank
449+
self,
450+
deployment_info: DeploymentInfo,
451+
assign_rank_callback: Callable[[ReplicaID], ReplicaRank],
449452
) -> ReplicaSchedulingRequest:
450453
"""Start the current DeploymentReplica instance.
451454
452455
The replica will be in the STARTING and PENDING_ALLOCATION states
453456
until the deployment scheduler schedules the underlying actor.
454457
"""
455-
self._rank = rank # Store the rank assigned to this replica
458+
self._assign_rank_callback = assign_rank_callback
456459
self._actor_resources = deployment_info.replica_config.resource_dict
457460
self._ingress = deployment_info.ingress
458461
# it is currently not possible to create a placement group
@@ -496,7 +499,6 @@ def start(
496499
self._version,
497500
deployment_info.ingress,
498501
deployment_info.route_prefix,
499-
rank,
500502
)
501503
# TODO(simon): unify the constructor arguments across language
502504
elif (
@@ -577,31 +579,11 @@ def on_scheduled(
577579
self._actor_handle = actor_handle
578580
self._placement_group = placement_group
579581

580-
# Perform auto method name translation for java handles.
581-
# See https://github.com/ray-project/ray/issues/21474
582-
deployment_config = copy(self._version.deployment_config)
583-
deployment_config.user_config = self._format_user_config(
584-
deployment_config.user_config
585-
)
586582
if self._is_cross_language:
587583
self._actor_handle = JavaActorHandleProxy(self._actor_handle)
588584
self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
589-
self._ready_obj_ref = self._actor_handle.is_initialized.remote(
590-
deployment_config.to_proto_bytes()
591-
)
592585
else:
593586
self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
594-
replica_ready_check_func = self._actor_handle.initialize_and_get_metadata
595-
self._ready_obj_ref = replica_ready_check_func.remote(
596-
deployment_config,
597-
# Ensure that `is_allocated` will execute
598-
# before `initialize_and_get_metadata`,
599-
# because `initialize_and_get_metadata` runs
600-
# user code that could block the replica
601-
# asyncio loop. If that happens before `is_allocated` is executed,
602-
# the `is_allocated` call won't be able to run.
603-
self._allocated_obj_ref,
604-
)
605587

606588
def _format_user_config(self, user_config: Any):
607589
temp = copy(user_config)
@@ -746,6 +728,28 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]:
746728
logger.exception(msg)
747729
return ReplicaStartupStatus.FAILED, msg
748730

731+
if self._ready_obj_ref is None:
732+
# Perform auto method name translation for java handles.
733+
# See https://github.com/ray-project/ray/issues/21474
734+
deployment_config = copy(self._version.deployment_config)
735+
deployment_config.user_config = self._format_user_config(
736+
deployment_config.user_config
737+
)
738+
if self._is_cross_language:
739+
self._ready_obj_ref = self._actor_handle.is_initialized.remote(
740+
deployment_config.to_proto_bytes()
741+
)
742+
else:
743+
replica_ready_check_func = (
744+
self._actor_handle.initialize_and_get_metadata
745+
)
746+
self._rank = self._assign_rank_callback(self._replica_id.unique_id)
747+
self._ready_obj_ref = replica_ready_check_func.remote(
748+
deployment_config, self._rank
749+
)
750+
751+
return ReplicaStartupStatus.PENDING_INITIALIZATION, None
752+
749753
# Check whether replica initialization has completed.
750754
replica_ready = check_obj_ref_ready_nowait(self._ready_obj_ref)
751755
# In case of deployment constructor failure, ray.get will help to
@@ -1173,12 +1177,16 @@ def initialization_latency_s(self) -> Optional[float]:
11731177
return self._actor.initialization_latency_s
11741178

11751179
def start(
1176-
self, deployment_info: DeploymentInfo, rank: ReplicaRank
1180+
self,
1181+
deployment_info: DeploymentInfo,
1182+
assign_rank_callback: Callable[[ReplicaID], ReplicaRank],
11771183
) -> ReplicaSchedulingRequest:
11781184
"""
11791185
Start a new actor for current DeploymentReplica instance.
11801186
"""
1181-
replica_scheduling_request = self._actor.start(deployment_info, rank=rank)
1187+
replica_scheduling_request = self._actor.start(
1188+
deployment_info, assign_rank_callback=assign_rank_callback
1189+
)
11821190
self._start_time = time.time()
11831191
self._logged_shutdown_message = False
11841192
self.update_actor_details(start_time_s=self._start_time)
@@ -2544,18 +2552,13 @@ def scale_deployment_replicas(
25442552
for _ in range(to_add):
25452553
replica_id = ReplicaID(get_random_string(), deployment_id=self._id)
25462554

2547-
# Assign rank during replica creation (startup process)
2548-
assigned_rank = self._rank_manager.assign_rank(replica_id.unique_id)
2549-
2550-
logger.debug(
2551-
f"Assigned rank {assigned_rank.rank} to new replica {replica_id.unique_id} during startup"
2552-
)
25532555
new_deployment_replica = DeploymentReplica(
25542556
replica_id,
25552557
self._target_state.version,
25562558
)
25572559
scheduling_request = new_deployment_replica.start(
2558-
self._target_state.info, rank=assigned_rank
2560+
self._target_state.info,
2561+
assign_rank_callback=self._rank_manager.assign_rank,
25592562
)
25602563

25612564
upscale.append(scheduling_request)

python/ray/serve/_private/replica.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,6 @@ def __init__(
511511
version: DeploymentVersion,
512512
ingress: bool,
513513
route_prefix: str,
514-
rank: ReplicaRank,
515514
):
516515
self._version = version
517516
self._replica_id = replica_id
@@ -562,7 +561,7 @@ def __init__(
562561

563562
# Set metadata for logs and metrics.
564563
# servable_object will be populated in `initialize_and_get_metadata`.
565-
self._set_internal_replica_context(servable_object=None, rank=rank)
564+
self._set_internal_replica_context(servable_object=None, rank=None)
566565

567566
self._metrics_manager = create_replica_metrics_manager(
568567
replica_id=replica_id,
@@ -576,7 +575,7 @@ def __init__(
576575
self._http_port: Optional[int] = None
577576
self._grpc_port: Optional[int] = None
578577

579-
self._rank = rank
578+
self._rank: Optional[ReplicaRank] = None
580579

581580
@property
582581
def max_ongoing_requests(self) -> int:
@@ -951,7 +950,14 @@ async def handle_request_with_rejection(
951950
async def _on_initialized(self):
952951
raise NotImplementedError
953952

954-
async def initialize(self, deployment_config: DeploymentConfig):
953+
async def initialize(
954+
self, deployment_config: Optional[DeploymentConfig], rank: Optional[ReplicaRank]
955+
):
956+
if rank is not None:
957+
self._rank = rank
958+
self._set_internal_replica_context(
959+
servable_object=self._user_callable_wrapper.user_callable, rank=rank
960+
)
955961
try:
956962
# Ensure that initialization is only performed once.
957963
# When controller restarts, it will call this method again.
@@ -982,7 +988,7 @@ async def initialize(self, deployment_config: DeploymentConfig):
982988
record_autoscaling_stats_fn=self._user_callable_wrapper.call_record_autoscaling_stats,
983989
)
984990

985-
if deployment_config:
991+
if deployment_config is not None:
986992
await self._user_callable_wrapper.set_sync_method_threadpool_limit(
987993
deployment_config.max_ongoing_requests
988994
)
@@ -1227,7 +1233,6 @@ async def __init__(
12271233
version: DeploymentVersion,
12281234
ingress: bool,
12291235
route_prefix: str,
1230-
rank: ReplicaRank,
12311236
):
12321237
deployment_config = DeploymentConfig.from_proto_bytes(
12331238
deployment_config_proto_bytes
@@ -1244,7 +1249,6 @@ async def __init__(
12441249
version=version,
12451250
ingress=ingress,
12461251
route_prefix=route_prefix,
1247-
rank=rank,
12481252
)
12491253

12501254
def push_proxy_handle(self, handle: ActorHandle):
@@ -1287,7 +1291,7 @@ def list_outbound_deployments(self) -> Optional[List[DeploymentID]]:
12871291
return self._replica_impl.list_outbound_deployments()
12881292

12891293
async def initialize_and_get_metadata(
1290-
self, deployment_config: DeploymentConfig = None, _after: Optional[Any] = None
1294+
self, deployment_config: DeploymentConfig = None, rank: ReplicaRank = None
12911295
) -> ReplicaMetadata:
12921296
"""Handles initializing the replica.
12931297
@@ -1300,7 +1304,7 @@ async def initialize_and_get_metadata(
13001304
"""
13011305
# Unused `_after` argument is for scheduling: passing an ObjectRef
13021306
# allows delaying this call until after the `_after` call has returned.
1303-
await self._replica_impl.initialize(deployment_config)
1307+
await self._replica_impl.initialize(deployment_config, rank)
13041308
return self._replica_impl.get_metadata()
13051309

13061310
async def check_health(self):

python/ray/serve/tests/test_cluster.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,30 +176,56 @@ def get_replicas(replica_state):
176176
)
177177
return replicas.get([replica_state])
178178

179-
# wait for serve to start the replica, and catch a reference to it.
179+
# wait for serve to start the replica
180180
wait_for_condition(lambda: len(get_replicas(ReplicaState.STARTING)) > 0)
181-
replica = get_replicas(ReplicaState.STARTING)[0]
182181

183182
# currently there are no resources to allocate the replica
184-
assert replica.check_started()[0] == ReplicaStartupStatus.PENDING_ALLOCATION
183+
def get_starting_replica():
184+
replicas = get_replicas(ReplicaState.STARTING)
185+
return replicas[0] if replicas else None
186+
187+
def is_pending_allocation():
188+
replica = get_starting_replica()
189+
if replica is None:
190+
return False
191+
return replica.check_started()[0] == ReplicaStartupStatus.PENDING_ALLOCATION
192+
193+
wait_for_condition(is_pending_allocation)
185194

186195
# add the necessary resources to allocate the replica
187196
cluster.add_node(num_cpus=4)
188197
wait_for_condition(lambda: (ray.cluster_resources().get("CPU", 0) >= 4))
189198
wait_for_condition(lambda: (ray.available_resources().get("CPU", 0) >= 2))
190199

191200
def is_replica_pending_initialization():
201+
replica = get_starting_replica()
202+
if replica is None:
203+
return False
192204
status, _ = replica.check_started()
193-
print(status)
194205
return status == ReplicaStartupStatus.PENDING_INITIALIZATION
195206

196207
wait_for_condition(is_replica_pending_initialization, timeout=25)
197208

198209
# send signal to complete replica initialization
199-
signal.send.remote()
200-
wait_for_condition(
201-
lambda: replica.check_started()[0] == ReplicaStartupStatus.SUCCEEDED
202-
)
210+
ray.get(signal.send.remote())
211+
212+
def check_succeeded():
213+
# After initialization succeeds, replica transitions to RUNNING state
214+
# So check both STARTING and RUNNING states
215+
replica = get_starting_replica()
216+
if replica:
217+
status, _ = replica.check_started()
218+
if status == ReplicaStartupStatus.SUCCEEDED:
219+
return True
220+
221+
# Check if replica has moved to RUNNING state (which means it succeeded)
222+
running_replicas = get_replicas(ReplicaState.RUNNING)
223+
if running_replicas and len(running_replicas) > 0:
224+
return True
225+
226+
return False
227+
228+
wait_for_condition(check_succeeded)
203229

204230

205231
@serve.deployment

python/ray/serve/tests/unit/test_deployment_state.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
from copy import deepcopy
3-
from typing import Any, Dict, List, Optional, Tuple
3+
from typing import Any, Callable, Dict, List, Optional, Tuple
44
from unittest.mock import Mock, patch
55

66
import pytest
@@ -108,6 +108,7 @@ def __init__(
108108
self._initialization_latency_s = -1
109109
self._docs_path = None
110110
self._rank = replica_rank_context.get(replica_id.unique_id, None)
111+
self._assign_rank_callback = None
111112

112113
@property
113114
def is_cross_language(self) -> bool:
@@ -226,10 +227,15 @@ def set_node_id(self, node_id: str):
226227
def set_actor_id(self, actor_id: str):
227228
self._actor_id = actor_id
228229

229-
def start(self, deployment_info: DeploymentInfo, rank: ReplicaRank):
230+
def start(
231+
self,
232+
deployment_info: DeploymentInfo,
233+
assign_rank_callback: Callable[[ReplicaID], ReplicaRank],
234+
):
230235
self.started = True
231-
self._rank = rank
232-
replica_rank_context[self._replica_id.unique_id] = rank
236+
self._assign_rank_callback = assign_rank_callback
237+
self._rank = assign_rank_callback(self._replica_id.unique_id)
238+
replica_rank_context[self._replica_id.unique_id] = self._rank
233239

234240
def _on_scheduled_stub(*args, **kwargs):
235241
pass
@@ -2685,7 +2691,9 @@ def test_max_concurrency_override(self):
26852691
)
26862692
max_ongoing_requests = DEFAULT_MAX_CONCURRENCY_ASYNC + 1
26872693
d_info, _ = deployment_info(max_ongoing_requests=max_ongoing_requests)
2688-
replica_scheduling_request = actor_replica.start(d_info, rank=0)
2694+
replica_scheduling_request = actor_replica.start(
2695+
d_info, assign_rank_callback=lambda x: 0
2696+
)
26892697
assert (
26902698
"max_concurrency" in replica_scheduling_request.actor_options
26912699
and replica_scheduling_request.actor_options["max_concurrency"]

0 commit comments

Comments
 (0)