Skip to content

Commit

Permalink
[Serve] Http proxy & router & handle to support multiplex impl (#35399)
Browse files Browse the repository at this point in the history
- Support handle.options(multiplexed_model_id="")
- Http proxy to extract model id on the fly
- Choose correct replica based on information.
- nit: Move handle metrics pusher to router.
  • Loading branch information
sihanwang41 authored May 17, 2023
1 parent 5ec3a36 commit 9ac0d44
Show file tree
Hide file tree
Showing 9 changed files with 370 additions and 156 deletions.
76 changes: 1 addition & 75 deletions python/ray/serve/_private/autoscaling_metrics.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,14 @@
import bisect
import logging
import threading
import time
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Event
from typing import Callable, DefaultDict, Dict, List, Optional, Type
from typing import DefaultDict, Dict, List, Optional

import ray
from ray.serve._private.constants import SERVE_LOGGER_NAME

logger = logging.getLogger(SERVE_LOGGER_NAME)


def start_metrics_pusher(
interval_s: float,
collection_callback: Callable[[], Dict[str, float]],
metrics_process_func: Callable[[Dict[str, float], float], ray.ObjectRef],
stop_event: Type[Event] = None,
):
"""Start a background thread to push metrics to controller.
We use this background so it will be not blocked by user's code and ensure
consistently metrics delivery. Python GIL will ensure that this thread gets
fair timeshare to execute and run.
Stop_event is passed in only when a RayServeHandle calls this function to
push metrics for scale-to-zero. stop_event is set either when the handle
is garbage collected or when the Serve application shuts down.
Args:
interval_s: the push interval.
collection_callback: a callable that returns the metric data points to
be sent to the the controller. The collection callback should take
no argument and returns a dictionary of str_key -> float_value.
metrics_process_func: actor handle function.
stop_event: the backgroupd thread will be closed when this event is set
Returns:
timer: The background thread created by this function to push
metrics to the controller
"""

def send_once():
data = collection_callback()

# TODO(simon): maybe wait for ack or handle controller failure?
return metrics_process_func(data=data, send_timestamp=time.time())

def send_forever(stop_event):
last_ref: Optional[ray.ObjectRef] = None
last_send_succeeded: bool = True

while True:
start = time.time()
if stop_event and stop_event.is_set():
return

if ray.is_initialized():
try:
if last_ref:
ready_refs, _ = ray.wait([last_ref], timeout=0)
last_send_succeeded = len(ready_refs) == 1
if last_send_succeeded:
last_ref = send_once()
except Exception as e:
logger.warning(
"Autoscaling metrics pusher thread "
"is failing to send metrics to the controller "
f": {e}"
)

duration_s = time.time() - start
remaining_time = interval_s - duration_s
if remaining_time > 0:
time.sleep(remaining_time)

timer = threading.Thread(target=send_forever, args=[stop_event])
# Making this a daemon thread so it doesn't leak upon shutdown, and it
# doesn't need to block the replica's shutdown.
timer.setDaemon(True)
timer.start()
return timer


@dataclass(order=True)
class TimeStampedValue:
timestamp: float
Expand Down
3 changes: 3 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,6 @@ class ServeHandleType(str, Enum):
SERVE_LOG_LEVEL_NAME: "%(levelname)s",
SERVE_LOG_TIME: "%(asctime)s",
}

# Serve HTTP request header key for routing requests.
SERVE_MULTIPLEXED_MODEL_ID = "serve_multiplexed_model_id"
14 changes: 11 additions & 3 deletions python/ray/serve/_private/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
SERVE_LOGGER_NAME,
SERVE_NAMESPACE,
DEFAULT_LATENCY_BUCKET_MS,
SERVE_MULTIPLEXED_MODEL_ID,
)
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
from ray.serve._private.logging_utils import (
Expand Down Expand Up @@ -424,11 +425,18 @@ async def __call__(self, scope, receive, send):
scope["path"] = route_path.replace(route_prefix, "", 1)
scope["root_path"] = root_path + route_prefix

request_context_info = {
"route": route_path,
"request_id": get_random_letters(10),
"app_name": app_name,
}
start_time = time.time()
for key, value in scope["headers"]:
if key.decode() == SERVE_MULTIPLEXED_MODEL_ID:
request_context_info["multiplexed_model_id"] = value.decode()
break
ray.serve.context._serve_request_context.set(
ray.serve.context.RequestContext(
route_path, get_random_letters(10), app_name
)
ray.serve.context.RequestContext(**request_context_info)
)
status_code = await _send_request_to_handle(handle, scope, receive, send)

Expand Down
16 changes: 10 additions & 6 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ray.serve import metrics
from ray._private.async_compat import sync_to_async

from ray.serve._private.autoscaling_metrics import start_metrics_pusher
from ray.serve._private.common import (
HEALTH_CHECK_CONCURRENCY_GROUP,
ReplicaTag,
Expand Down Expand Up @@ -46,6 +45,7 @@
parse_request_item,
wrap_to_ray_error,
merge_dict,
MetricsPusher,
)
from ray.serve._private.version import DeploymentVersion

Expand Down Expand Up @@ -366,11 +366,12 @@ def user_health_check():
if autoscaling_config:
process_remote_func = controller_handle.record_autoscaling_metrics.remote
config = autoscaling_config
start_metrics_pusher(
interval_s=config.metrics_interval_s,
collection_callback=self._collect_autoscaling_metrics,
metrics_process_func=process_remote_func,
self.metrics_pusher = MetricsPusher(
process_remote_func,
config.metrics_interval_s,
self._collect_autoscaling_metrics,
)
self.metrics_pusher.start()

async def check_health(self):
await self.user_health_check()
Expand Down Expand Up @@ -525,7 +526,10 @@ async def handle_request(self, request: Query) -> asyncio.Future:
# handle can pass the correct request context to subsequent replicas.
ray.serve.context._serve_request_context.set(
ray.serve.context.RequestContext(
request.metadata.route, request.metadata.request_id, self.app_name
request.metadata.route,
request.metadata.request_id,
self.app_name,
request.metadata.multiplexed_model_id,
)
)

Expand Down
154 changes: 121 additions & 33 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,29 @@
import random
import sys
from typing import Any, Dict, List, Optional
from collections import defaultdict

import ray
from ray.actor import ActorHandle
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.exceptions import RayActorError, RayTaskError
from ray.util import metrics

from ray.serve._private.common import RunningReplicaInfo
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.common import RunningReplicaInfo, DeploymentInfo
from ray.serve._private.constants import (
SERVE_LOGGER_NAME,
HANDLE_METRIC_PUSH_INTERVAL_S,
)
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
from ray.serve._private.utils import (
compute_iterable_delta,
JavaActorHandleProxy,
)
from ray.serve.generated.serve_pb2 import (
RequestMetadata as RequestMetadataProto,
DeploymentRoute,
)
from ray.serve._private.utils import MetricsPusher

logger = logging.getLogger(SERVE_LOGGER_NAME)

Expand All @@ -43,6 +49,9 @@ class RequestMetadata:
# Application Name
app_name: str = ""

# Multiplexed model ID
multiplexed_model_id: str = ""


@dataclass
class Query:
Expand Down Expand Up @@ -107,6 +116,12 @@ def __init__(
{"deployment": self.deployment_name}
)

# A map from multiplexed model id to a list of replicas that have the
# model loaded.
self.multiplexed_replicas_table: Dict[
str, List[RunningReplicaInfo]
] = defaultdict(list)

def _reset_replica_iterator(self):
"""Reset the iterator used to load balance replicas.
Expand All @@ -118,6 +133,13 @@ def _reset_replica_iterator(self):
random.shuffle(replicas)
self.replica_iterator = itertools.cycle(replicas)

# Update the multiplexed_replicas_table
new_multiplexed_replicas_table = defaultdict(list)
for replica in replicas:
for mdoel_id in replica.multiplexed_model_ids:
new_multiplexed_replicas_table[mdoel_id].append(replica)
self.multiplexed_replicas_table = new_multiplexed_replicas_table

def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
added, removed, _ = compute_iterable_delta(
self.in_flight_queries.keys(), running_replicas
Expand All @@ -137,51 +159,100 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
self._reset_replica_iterator()
self.config_updated_event.set()

def _assign_replica(self, query: Query, replica: RunningReplicaInfo):
"""Assign query to the replica.
Args:
query: Query object, containing the request metadata and args.
replica: Replica object, containing the actor handle to the replica.
Returns: object ref of the requests.
"""

logger.debug(
f"Assigned query {query.metadata.request_id} "
f"to replica {replica.replica_tag}."
)
if replica.is_cross_language:
# Handling requests for Java replica
arg = query.args[0]
if query.metadata.http_arg_is_pickled:
assert isinstance(arg, bytes)
loaded_http_input = pickle.loads(arg)
query_string = loaded_http_input.scope.get("query_string")
if query_string:
arg = query_string.decode().split("=", 1)[1]
elif loaded_http_input.body:
arg = loaded_http_input.body.decode()
user_ref = JavaActorHandleProxy(replica.actor_handle).handle_request.remote(
RequestMetadataProto(
request_id=query.metadata.request_id,
endpoint=query.metadata.endpoint,
call_method=query.metadata.call_method
if query.metadata.call_method != "__call__"
else "call",
).SerializeToString(),
[arg],
)
self.in_flight_queries[replica].add(user_ref)
else:
# Directly passing args because it might contain an ObjectRef.
tracker_ref, user_ref = replica.actor_handle.handle_request.remote(
pickle.dumps(query.metadata), *query.args, **query.kwargs
)
self.in_flight_queries[replica].add(tracker_ref)
return user_ref

def _try_assign_replica(self, query: Query) -> Optional[ray.ObjectRef]:
"""Try to assign query to a replica, return the object ref if succeeded
or return None if it can't assign this query to any replicas.
"""

# Try to find a replica that can handle this query
# If multiplexed model id is not specified, we can assign the query to
# any non-overloaded replica.
# If multiplexed model id is specified, we can try to assign the query
# to a replica that has the specified model loaded and
# is not overloaded with requests.
# If no such replica exists, we can assign the query to any non-overloaded
# replica.
if (
query.metadata.multiplexed_model_id
and query.metadata.multiplexed_model_id in self.multiplexed_replicas_table
):
# Try to find the replica that is already handling the model.
for replica in self.multiplexed_replicas_table[
query.metadata.multiplexed_model_id
]:
if (
len(self.in_flight_queries[replica])
>= replica.max_concurrent_queries
):
# This replica is overloaded, try next one
continue
logger.debug(
f"Assigned query {query.metadata.request_id} "
f"to replica {replica.replica_tag}."
)
return self._assign_replica(query, replica)

for _ in range(len(self.in_flight_queries.keys())):
replica = next(self.replica_iterator)
if len(self.in_flight_queries[replica]) >= replica.max_concurrent_queries:
# This replica is overloaded, try next one
continue

if query.metadata.multiplexed_model_id:
# This query has a multiplexed model id, but the model is not
# loaded on this replica. Save this replica for future queries
# with the same model id.
self.multiplexed_replicas_table[
query.metadata.multiplexed_model_id
].append(replica)

logger.debug(
f"Assigned query {query.metadata.request_id} "
f"to replica {replica.replica_tag}."
)
if replica.is_cross_language:
# Handling requests for Java replica
arg = query.args[0]
if query.metadata.http_arg_is_pickled:
assert isinstance(arg, bytes)
loaded_http_input = pickle.loads(arg)
query_string = loaded_http_input.scope.get("query_string")
if query_string:
arg = query_string.decode().split("=", 1)[1]
elif loaded_http_input.body:
arg = loaded_http_input.body.decode()
user_ref = JavaActorHandleProxy(
replica.actor_handle
).handle_request.remote(
RequestMetadataProto(
request_id=query.metadata.request_id,
endpoint=query.metadata.endpoint,
call_method=query.metadata.call_method
if query.metadata.call_method != "__call__"
else "call",
).SerializeToString(),
[arg],
)
self.in_flight_queries[replica].add(user_ref)
else:
# Directly passing args because it might contain an ObjectRef.
tracker_ref, user_ref = replica.actor_handle.handle_request.remote(
pickle.dumps(query.metadata), *query.args, **query.kwargs
)
self.in_flight_queries[replica].add(tracker_ref)
return user_ref
return self._assign_replica(query, replica)
return None

@property
Expand Down Expand Up @@ -305,6 +376,23 @@ def __init__(
call_in_event_loop=event_loop,
)

# Start the metrics pusher if autoscaling is enabled.
self.deployment_name = deployment_name
deployment_route = DeploymentRoute.FromString(
ray.get(controller_handle.get_deployment_info.remote(self.deployment_name))
)
deployment_info = DeploymentInfo.from_proto(deployment_route.deployment_info)
if deployment_info.deployment_config.autoscaling_config:
self.metrics_pusher = MetricsPusher(
controller_handle.record_handle_metrics.remote,
HANDLE_METRIC_PUSH_INTERVAL_S,
self._collect_handle_queue_metrics,
)
self.metrics_pusher.start()

def _collect_handle_queue_metrics(self) -> Dict[str, int]:
return {self.deployment_name: self.get_num_queued_queries()}

def get_num_queued_queries(self):
return self._replica_set.num_queued_queries

Expand Down
Loading

0 comments on commit 9ac0d44

Please sign in to comment.