[serve] immediately send ping in router when receiving new replica set (ray-project#47053)

When a new set of `RunningReplicaInfos` is broadcast to a router, the
nested actor handles are "empty" and don't hold the actor info (e.g. the
actor address) needed to send a request to that replica. Upon the first
request, the handle fetches that info from the GCS. If the GCS goes down
immediately after a replica set change is broadcast to a router, all
requests are blocked until the GCS recovers.
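For illustration, here is a condensed reproduction of the failure mode, distilled from the `test_new_router_on_gcs_failure` test added in this PR (the `Dummy` deployment, the private `_get_or_create_router()` call, and `kill_gcs_server()` are taken from that test; the full test also disables telemetry recording):

import ray
from ray import serve


@serve.deployment(num_replicas=2)
class Dummy:
    def __call__(self):
        return "ok"


h = serve.run(Dummy.bind())
# Eagerly create the router so it receives the replica set before any
# request is sent.
h._get_or_create_router()

# The router now holds RunningReplicaInfos, but their nested actor handles
# have not yet resolved actor addresses from the GCS.
ray.worker._global_node.kill_gcs_server()

# Without this fix, the first request blocks until the GCS recovers.
print(h.remote().result(timeout_s=1))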

Fix:
- Upon receiving a new replica set, the router actively probes the queue
length of each new replica (see the sketch below).
- On proxies, the router also pushes the proxy's self actor handle to
replicas upon a replica set change; otherwise proxy requests to new
replicas hang while the GCS is down.
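A condensed sketch of the scheduler-side change (simplified from the `pow_2_scheduler.py` diff below; `_probe_queue_lens` is the scheduler's existing helper for probing replica queue lengths):

def update_replicas(self, replicas):
    new_replicas = {r.replica_id: r for r in replicas}

    # Proxies eagerly push their own actor handle to replicas they haven't
    # seen before, so replicas can call back into the proxy (e.g. via
    # `receive_asgi_messages`) without a GCS lookup.
    if self._handle_source == DeploymentHandleSource.PROXY:
        for r in replicas:
            if r.replica_id not in self._replicas:
                r.push_proxy_handle(self._self_actor_handle)

    # Immediately probe the queue lengths of brand-new replicas so their
    # actor handles resolve while the GCS is still reachable.
    new_ids = set(new_replicas) - self._replica_id_set
    self._loop.create_task(
        self._probe_queue_lens([new_replicas[i] for i in new_ids], 0)
    )

    self._replicas = new_replicas
    self._replica_id_set = set(new_replicas)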

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
zcin authored and simonsays1980 committed Aug 15, 2024
1 parent a137979 commit aac0a6e
Showing 7 changed files with 260 additions and 14 deletions.
5 changes: 4 additions & 1 deletion python/ray/serve/_private/replica.py
@@ -16,7 +16,7 @@
import ray
from ray import cloudpickle
from ray._private.utils import get_or_create_event_loop
from ray.actor import ActorClass
from ray.actor import ActorClass, ActorHandle
from ray.remote_function import RemoteFunction
from ray.serve import metrics
from ray.serve._private.common import (
@@ -321,6 +321,9 @@ def _configure_logger_and_profilers(
component_id=self._component_id,
)

def push_proxy_handle(self, handle: ActorHandle):
pass

def get_num_ongoing_requests(self) -> int:
"""Fetch the number of ongoing requests at this replica (queue length).
8 changes: 8 additions & 0 deletions python/ray/serve/_private/replica_scheduler/common.py
@@ -8,6 +8,7 @@

import ray
from ray import ObjectRef, ObjectRefGenerator
from ray.actor import ActorHandle
from ray.serve._private.common import (
ReplicaID,
ReplicaQueueLengthInfo,
@@ -58,6 +59,10 @@ def max_ongoing_requests(self) -> int:
"""Max concurrent requests that can be sent to this replica."""
pass

def push_proxy_handle(self, handle: ActorHandle):
"""When on proxy, push proxy's self handle to replica"""
pass

async def get_queue_len(self, *, deadline_s: float) -> int:
"""Returns current queue len for the replica.
@@ -120,6 +125,9 @@ def max_ongoing_requests(self) -> int:
def is_cross_language(self) -> bool:
return self._replica_info.is_cross_language

def push_proxy_handle(self, handle: ActorHandle):
self._actor_handle.push_proxy_handle.remote(handle)

async def get_queue_len(self, *, deadline_s: float) -> int:
# NOTE(edoakes): the `get_num_ongoing_requests` method name is shared by
# the Python and Java replica implementations. If you change it, you need to
22 changes: 22 additions & 0 deletions python/ray/serve/_private/replica_scheduler/pow_2_scheduler.py
@@ -17,8 +17,10 @@
Tuple,
)

from ray.actor import ActorHandle
from ray.exceptions import ActorDiedError, ActorUnavailableError
from ray.serve._private.common import (
DeploymentHandleSource,
DeploymentID,
ReplicaID,
RequestMetadata,
@@ -89,19 +91,23 @@ def __init__(
self,
event_loop: asyncio.AbstractEventLoop,
deployment_id: DeploymentID,
handle_source: DeploymentHandleSource,
prefer_local_node_routing: bool = False,
prefer_local_az_routing: bool = False,
self_node_id: Optional[str] = None,
self_actor_id: Optional[str] = None,
self_actor_handle: Optional[ActorHandle] = None,
self_availability_zone: Optional[str] = None,
use_replica_queue_len_cache: bool = False,
get_curr_time_s: Optional[Callable[[], float]] = None,
):
self._loop = event_loop
self._deployment_id = deployment_id
self._handle_source = handle_source
self._prefer_local_node_routing = prefer_local_node_routing
self._prefer_local_az_routing = prefer_local_az_routing
self._self_node_id = self_node_id
self._self_actor_handle = self_actor_handle
self._self_availability_zone = self_availability_zone
self._use_replica_queue_len_cache = use_replica_queue_len_cache

@@ -240,7 +246,17 @@ def update_replicas(self, replicas: List[ReplicaWrapper]):
new_replica_id_set = set()
new_colocated_replica_ids = defaultdict(set)
new_multiplexed_model_id_to_replica_ids = defaultdict(set)

for r in replicas:
# If on the proxy, replica needs to call back into the proxy with
# `receive_asgi_messages` which can be blocked when GCS is down.
# To prevent that from happening, push proxy handle eagerly
if (
self._handle_source == DeploymentHandleSource.PROXY
and r.replica_id not in self._replicas
):
r.push_proxy_handle(self._self_actor_handle)

new_replicas[r.replica_id] = r
new_replica_id_set.add(r.replica_id)
if self._self_node_id is not None and r.node_id == self._self_node_id:
@@ -263,6 +279,10 @@ def update_replicas(self, replicas: List[ReplicaWrapper]):
extra={"log_to_stderr": False},
)

# Get list of new replicas
new_ids = new_replica_id_set - self._replica_id_set
replicas_to_ping = [new_replicas.get(id) for id in new_ids]

self._replicas = new_replicas
self._replica_id_set = new_replica_id_set
self._colocated_replica_ids = new_colocated_replica_ids
@@ -272,6 +292,8 @@ def update_replicas(self, replicas: List[ReplicaWrapper]):
self._replica_queue_len_cache.remove_inactive_replicas(
active_replica_ids=new_replica_id_set
)
# Populate cache for new replicas
self._loop.create_task(self._probe_queue_lens(replicas_to_ping, 0))
self._replicas_updated_event.set()
self.maybe_start_scheduling_tasks()

4 changes: 4 additions & 0 deletions python/ray/serve/_private/router.py
@@ -358,10 +358,14 @@ def __init__(
replica_scheduler = PowerOfTwoChoicesReplicaScheduler(
self._event_loop,
deployment_id,
handle_source,
_prefer_local_node_routing,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
self_node_id,
self_actor_id,
ray.get_runtime_context().current_actor
if ray.get_runtime_context().get_actor_id()
else None,
self_availability_zone,
use_replica_queue_len_cache=enable_queue_len_cache,
)
10 changes: 10 additions & 0 deletions python/ray/serve/_private/test_utils.py
@@ -1,5 +1,6 @@
import asyncio
import datetime
import os
import threading
import time
from copy import copy, deepcopy
@@ -197,6 +198,15 @@ def __init__(
self._soft_target_node_id = _soft_target_node_id


@serve.deployment
class GetPID:
def __call__(self):
return os.getpid()


get_pid_entrypoint = GetPID.bind()


def check_ray_stopped():
try:
requests.get("http://localhost:52365/api/ray/version")
149 changes: 149 additions & 0 deletions python/ray/serve/tests/test_gcs_failure.py
@@ -8,9 +8,13 @@
import ray
from ray import serve
from ray._private.test_utils import wait_for_condition
from ray.serve._private.common import DeploymentID, ReplicaState
from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME
from ray.serve._private.storage.kv_store import KVStoreError, RayInternalKVStore
from ray.serve._private.test_utils import check_apps_running, check_replica_counts
from ray.serve.context import _get_global_client
from ray.serve.handle import DeploymentHandle
from ray.serve.schema import ServeDeploySchema
from ray.tests.conftest import external_redis # noqa: F401


@@ -27,6 +31,8 @@ def serve_ha(external_redis, monkeypatch):  # noqa: F811
serve.start()
yield (address_info, _get_global_client())
ray.shutdown()
# Clear cache and global serve client
serve.shutdown()


@pytest.mark.skipif(
@@ -105,6 +111,149 @@ def call():
assert pid == call()


def router_populated_with_replicas(handle: DeploymentHandle, threshold: int = 1):
replicas = handle._router._replica_scheduler._replica_id_set
assert len(replicas) >= threshold
return True


@pytest.mark.parametrize("use_proxy", [True, False])
def test_new_router_on_gcs_failure(serve_ha, use_proxy: bool):
"""Test that a new router can send requests to replicas when GCS is down.
Specifically, if a proxy was just brought up or a deployment handle
was just created, and the GCS goes down BEFORE the router is able to
send its first request, new incoming requests should successfully get
sent to replicas during GCS downtime.
"""

@serve.deployment
class Dummy:
def __call__(self):
return os.getpid()

h = serve.run(Dummy.options(num_replicas=2).bind())
# TODO(zcin): We want to test the behavior for when the router
# didn't get a chance to send even a single request yet. However on
# the very first request we record telemetry for whether the
# deployment handle API was used, which will hang when the GCS is
# down. As a workaround for now, avoid recording telemetry so we
# can properly test router behavior when GCS is down. We should look
# into adding a timeout on the kv cache operation. For now, the proxy
# doesn't run into this because we don't record telemetry on proxy
h._recorded_telemetry = True
# Eagerly create router so it receives the replica set instead of
# waiting for the first request
h._get_or_create_router()

wait_for_condition(router_populated_with_replicas, handle=h)

# Kill GCS server before a single request is sent.
ray.worker._global_node.kill_gcs_server()

returned_pids = set()
if use_proxy:
for _ in range(10):
returned_pids.add(
int(requests.get("http://localhost:8000", timeout=0.1).text)
)
else:
for _ in range(10):
returned_pids.add(int(h.remote().result(timeout_s=0.1)))

print("Returned pids:", returned_pids)
assert len(returned_pids) == 2


def test_handle_router_updated_replicas_then_gcs_failure(serve_ha):
"""Test the router's replica set is updated from 1 to 2 replicas, with the first
replica staying the same. Verify that if the GCS goes down before the router
gets a chance to send a request to the second replica, requests can be handled
during GCS failure.
This test uses a plain handle to send requests.
"""

_, client = serve_ha

config = {
"name": "default",
"import_path": "ray.serve._private.test_utils:get_pid_entrypoint",
"route_prefix": "/",
"deployments": [{"name": "GetPID", "num_replicas": 1}],
}
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))
wait_for_condition(check_apps_running, apps=["default"])

h = serve.get_app_handle("default")
print(h.remote().result())

config["deployments"][0]["num_replicas"] = 2
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))

wait_for_condition(router_populated_with_replicas, handle=h, threshold=2)

# Kill GCS server before router gets to send request to second replica
ray.worker._global_node.kill_gcs_server()

returned_pids = set()
for _ in range(10):
returned_pids.add(int(h.remote().result(timeout_s=0.1)))

print("Returned pids:", returned_pids)
assert len(returned_pids) == 2


def test_proxy_router_updated_replicas_then_gcs_failure(serve_ha):
"""Test the router's replica set is updated from 1 to 2 replicas, with the first
replica staying the same. Verify that if the GCS goes down before the router
gets a chance to send a request to the second replica, requests can be handled
during GCS failure.
This test sends http requests to the proxy.
"""
_, client = serve_ha

config = {
"name": "default",
"import_path": "ray.serve._private.test_utils:get_pid_entrypoint",
"route_prefix": "/",
"deployments": [{"name": "GetPID", "num_replicas": 1}],
}
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))
wait_for_condition(check_apps_running, apps=["default"])

r = requests.post("http://localhost:8000")
assert r.status_code == 200, r.text
print(r.text)

config["deployments"][0]["num_replicas"] = 2
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))

# There is no way to directly check if proxy has received updated replicas,
# so just check for the status. After controller updates status with new
# replicas, proxy should instantly receive updates from long poll
wait_for_condition(
check_replica_counts,
controller=client._controller,
deployment_id=DeploymentID("GetPID", "default"),
total=2,
by_state=[(ReplicaState.RUNNING, 2, None)],
)

# Kill GCS server before router gets to send request to second replica
ray.worker._global_node.kill_gcs_server()

returned_pids = set()
for _ in range(10):
r = requests.post("http://localhost:8000")
assert r.status_code == 200
returned_pids.add(int(r.text))

print("Returned pids:", returned_pids)
assert len(returned_pids) == 2


if __name__ == "__main__":
# When GCS is down, right now some core worker members are not cleared
# properly in ray.shutdown. Given that this is not hi-pri issue,