From 1f4082c1826834e3231dedd3d868f55f15313288 Mon Sep 17 00:00:00 2001
From: Alind Khare
Date: Tue, 7 Jan 2020 12:41:23 -0500
Subject: [PATCH] [Serve] Added conditional enqueue

Added conditional pipeline execution -
1. Added a feature for conditional services, where a backend returns both
   the prediction result and a predicate boolean (to be passed along later
   in the pipeline)
2. Added the __predicate__ feature
3. Added a service consistency check
---
 python/ray/experimental/serve/__init__.py    |   3 +-
 python/ray/experimental/serve/api.py         |  30 ++++-
 .../ray/experimental/serve/backend_config.py |  10 +-
 python/ray/experimental/serve/constants.py   |   7 ++
 .../examples/echo_conditional_pipeline.py    |  74 +++++++++++++
 .../serve/examples/echo_pipeline.py          |   6 +-
 python/ray/experimental/serve/handle.py      |  38 ++++++-
 .../experimental/serve/kv_store_service.py   |   5 +
 python/ray/experimental/serve/queues.py      | 104 +++++++++++++-----
 .../ray/experimental/serve/request_params.py |  22 +++-
 python/ray/experimental/serve/server.py      |   8 +-
 python/ray/experimental/serve/task_runner.py |  98 ++++++++++++++---
 .../ray/experimental/serve/tests/test_api.py |   2 +
 .../experimental/serve/tests/test_queue.py   |  36 ++++--
 .../serve/tests/test_task_runner.py          |  17 ++-
 15 files changed, 380 insertions(+), 80 deletions(-)
 create mode 100644 python/ray/experimental/serve/examples/echo_conditional_pipeline.py

diff --git a/python/ray/experimental/serve/__init__.py b/python/ray/experimental/serve/__init__.py
index 4d86ca222ec7d..0664898ac4cc0 100644
--- a/python/ray/experimental/serve/__init__.py
+++ b/python/ray/experimental/serve/__init__.py
@@ -1,6 +1,7 @@
 import sys
 from ray.experimental.serve.backend_config import BackendConfig
 from ray.experimental.serve.policy import RoutePolicy
+from ray.experimental.serve.constants import RESULT_KEY, PREDICATE_KEY
 
 if sys.version_info < (3, 0):
     raise ImportError("serve is Python 3 only.")
@@ -10,5 +11,5 @@
 __all__ = [
     "init", "create_backend", "create_endpoint", "link", "split",
     "get_handle", "stat", "set_backend_config", "get_backend_config",
     "BackendConfig",
-    "RoutePolicy", "accept_batch"
+    "RoutePolicy", "accept_batch", "RESULT_KEY", "PREDICATE_KEY"
 ]
diff --git a/python/ray/experimental/serve/api.py b/python/ray/experimental/serve/api.py
index e1ed0e2bd3325..a3814d4578fd4 100644
--- a/python/ray/experimental/serve/api.py
+++ b/python/ray/experimental/serve/api.py
@@ -204,7 +204,8 @@ def get_backend_config(backend_tag):
 def create_backend(func_or_class,
                    backend_tag,
                    *actor_init_args,
-                   backend_config=BackendConfig()):
+                   backend_config=BackendConfig(),
+                   predicate_function=None):
     """Create a backend using func_or_class and assign backend_tag.
 
     Args:
@@ -216,6 +217,8 @@ def create_backend(func_or_class,
             for starting a backend.
         *actor_init_args (optional): the argument to pass to the class
             initialization method.
+        predicate_function(callable): a function that returns a boolean
+            value, used for conditional enqueuing.
""" assert isinstance(backend_config, BackendConfig), ("backend_config must be" @@ -234,16 +237,26 @@ def create_backend(func_or_class, if should_accept_batch and not hasattr(func_or_class, "serve_accept_batch"): raise batch_annotation_not_found - + if backend_config.enable_predicate and predicate_function is None: + raise RayServeException( + "For enabling predicate, Specify predicate_function.") # arg list for a fn is function itself arg_list = [func_or_class] + # add predicate function to args + if backend_config.enable_predicate: + arg_list.append(predicate_function) + # ignore lint on lambda expression creator = lambda kwrgs: TaskRunnerActor._remote(**kwrgs) # noqa: E731 elif inspect.isclass(func_or_class): if should_accept_batch and not hasattr(func_or_class.__call__, "serve_accept_batch"): raise batch_annotation_not_found - + if backend_config.enable_predicate and not hasattr( + func_or_class, "__predicate__"): + raise RayServeException( + "For enabling predicate, implement __predicate__ function " + "in backend class.") # Python inheritance order is right-to-left. We put RayServeMixin # on the left to make sure its methods are not overriden. @ray.remote @@ -297,7 +310,10 @@ def _start_replica(backend_tag): # Setup the worker ray.get( runner_handle._ray_serve_setup.remote( - backend_tag, global_state.init_or_get_router(), runner_handle)) + backend_tag, + global_state.init_or_get_router(), + runner_handle, + predicate_required=backend_config.enable_predicate)) runner_handle._ray_serve_fetch.remote() # Register the worker in config tables as well as metric monitor @@ -392,10 +408,16 @@ def split(endpoint_name, traffic_policy_dictionary): assert isinstance(traffic_policy_dictionary, dict), "Traffic policy must be dictionary" prob = 0 + backend_predicates = [] for backend, weight in traffic_policy_dictionary.items(): prob += weight assert (backend in global_state.backend_table.list_backends() ), "backend {} is not registered".format(backend) + backend_predicates.append( + global_state.backend_table.get_backend_predicate(backend)) + assert len(set(backend_predicates)) == 1, ("Provided backends are not" + "consistent wrt to predicate" + "feature") assert np.isclose( prob, 1, atol=0.02), "weights must sum to 1, currently it sums to {}".format( diff --git a/python/ray/experimental/serve/backend_config.py b/python/ray/experimental/serve/backend_config.py index d4cde75f54e78..ea60aabeb5de2 100644 --- a/python/ray/experimental/serve/backend_config.py +++ b/python/ray/experimental/serve/backend_config.py @@ -4,11 +4,13 @@ class BackendConfig: # configs not needed for actor creation when # instantiating a replica - _serve_configs = ["_num_replicas", "max_batch_size"] + _serve_configs = ["_num_replicas", "max_batch_size", "enable_predicate"] # configs which when changed leads to restarting # the existing replicas. - restart_on_change_fields = ["resources", "num_cpus", "num_gpus"] + restart_on_change_fields = [ + "resources", "num_cpus", "num_gpus", "enable_predicate" + ] def __init__(self, num_replicas=1, @@ -17,7 +19,8 @@ def __init__(self, num_cpus=None, num_gpus=None, memory=None, - object_store_memory=None): + object_store_memory=None, + enable_predicate=False): """ Class for defining backend configuration. 
""" @@ -32,6 +35,7 @@ def __init__(self, self.num_gpus = num_gpus self.memory = memory self.object_store_memory = object_store_memory + self.enable_predicate = enable_predicate @property def num_replicas(self): diff --git a/python/ray/experimental/serve/constants.py b/python/ray/experimental/serve/constants.py index 1b15c1e96e257..29c15527dcc36 100644 --- a/python/ray/experimental/serve/constants.py +++ b/python/ray/experimental/serve/constants.py @@ -15,3 +15,10 @@ #: HTTP Port DEFAULT_HTTP_PORT = 8000 + +#: Return ObjectIDs keys for a dictionary +RESULT_KEY = "result" +PREDICATE_KEY = "predicate" + +# default value to pass when Enqeueue Predicate is False +PREDICATE_DEFAULT_VALUE = "predicate-false" diff --git a/python/ray/experimental/serve/examples/echo_conditional_pipeline.py b/python/ray/experimental/serve/examples/echo_conditional_pipeline.py new file mode 100644 index 0000000000000..a5b02d4734e58 --- /dev/null +++ b/python/ray/experimental/serve/examples/echo_conditional_pipeline.py @@ -0,0 +1,74 @@ +""" +Ray serve conditional pipeline example +""" +import ray +import ray.experimental.serve as serve +from ray.experimental.serve import BackendConfig + +# initialize ray serve system. +# blocking=True will wait for HTTP server to be ready to serve request. +serve.init(blocking=True) + + +# This is an example of conditional backend implementation +def echo_v1(_, num): + return num + + +def echo_v1_predicate(num): + return num < 0.5 + + +def echo_v2(_, relay=""): + return f"echo_v2({relay})" + + +def echo_v3(_, relay=""): + return f"echo_v3({relay})" + + +# an endpoint is associated with an http URL. +serve.create_endpoint("my_endpoint1", "/echo1") +serve.create_endpoint("my_endpoint2", "/echo2") +serve.create_endpoint("my_endpoint3", "/echo3") + +# create backends +serve.create_backend( + echo_v1, + "echo:v1", + backend_config=BackendConfig(enable_predicate=True), + predicate_function=echo_v1_predicate) +serve.create_backend(echo_v2, "echo:v2") +serve.create_backend(echo_v3, "echo:v3") + +# link service to backends +serve.link("my_endpoint1", "echo:v1") +serve.link("my_endpoint2", "echo:v2") +serve.link("my_endpoint3", "echo:v3") + +# get the handle of the endpoints +handle1 = serve.get_handle("my_endpoint1") +handle2 = serve.get_handle("my_endpoint2") +handle3 = serve.get_handle("my_endpoint3") + +for number in [0.2, 0.8]: + first_object_id = ray.ObjectID.from_random() + predicate_object_id = ray.ObjectID.from_random() + handle1.remote( + num=number, + return_object_ids={ + serve.RESULT_KEY: first_object_id, + serve.PREDICATE_KEY: predicate_object_id + }) + second_object_id = ray.ObjectID.from_random() + + return_val = handle2.remote( + relay=first_object_id, + predicate_condition=predicate_object_id, + default_value=("kwargs", "relay"), + return_object_ids={serve.RESULT_KEY: second_object_id}) + + assert return_val is None + result = ray.get(handle3.remote(relay=second_object_id)) + print("For number : {} the whole pipeline output is : {}".format( + number, result)) diff --git a/python/ray/experimental/serve/examples/echo_pipeline.py b/python/ray/experimental/serve/examples/echo_pipeline.py index f25b2b188f0d5..4df33fbdc8176 100644 --- a/python/ray/experimental/serve/examples/echo_pipeline.py +++ b/python/ray/experimental/serve/examples/echo_pipeline.py @@ -82,7 +82,7 @@ def echo_v4(_, relay1="", relay2=""): # asynchronous! 
All the remote calls below are completely asynchronous temp1 = handle2.remote( relay=first_object_id, - return_object_ids=[second_object_id], + return_object_ids={serve.RESULT_KEY: second_object_id}, slo_ms=wall_clock_slo, is_wall_clock_time=True) @@ -90,14 +90,14 @@ def echo_v4(_, relay1="", relay2=""): assert temp1 is None handle3.remote( relay=first_object_id, - return_object_ids=[third_object_id], + return_object_ids={serve.RESULT_KEY: third_object_id}, slo_ms=wall_clock_slo, is_wall_clock_time=True) fourth_object_id = ray.ObjectID.from_random() temp2 = handle4.remote( relay1=second_object_id, relay2=third_object_id, - return_object_ids=[fourth_object_id], + return_object_ids={serve.RESULT_KEY: fourth_object_id}, slo_ms=wall_clock_slo, is_wall_clock_time=True) assert temp2 is None diff --git a/python/ray/experimental/serve/handle.py b/python/ray/experimental/serve/handle.py index 36cad11eb9c68..74892a519e280 100644 --- a/python/ray/experimental/serve/handle.py +++ b/python/ray/experimental/serve/handle.py @@ -2,7 +2,8 @@ from ray.experimental import serve from ray.experimental.serve.context import TaskContext from ray.experimental.serve.exceptions import RayServeException -from ray.experimental.serve.constants import DEFAULT_HTTP_ADDRESS +from ray.experimental.serve.constants import (DEFAULT_HTTP_ADDRESS, + PREDICATE_DEFAULT_VALUE) from ray.experimental.serve.request_params import RequestParams, RequestInfo @@ -40,6 +41,26 @@ def _fix_kwarg_name(self, name): return "request_slo_ms" return name + def _check_default_value(self, default_value, len_args, kwargs_keys): + if default_value != PREDICATE_DEFAULT_VALUE: + if not isinstance(default_value, tuple): + raise ValueError("The default value must be a tuple.") + if len(default_value) != 2: + raise ValueError( + "Specify default_value in format: ('args',arg_index)" + " or ('kwargs', kwargs_key)") + val = default_value[0] + if val not in ["args", "kwargs"]: + raise ValueError( + "First value of default_value must be: 'args' or 'kwargs'." 
+                )
+            if val == "args":
+                if default_value[1] >= len_args:
+                    raise ValueError("Specify the args index correctly!")
+            else:
+                if default_value[1] not in kwargs_keys:
+                    raise ValueError("Specify the kwargs key correctly!")
+
     def __init__(self, router_handle, endpoint_name):
         self.router_handle = router_handle
         self.endpoint_name = endpoint_name
@@ -65,11 +86,22 @@ def remote(self, *args, **kwargs):
         except ValueError as e:
             raise RayServeException(str(e))
 
-        # create request parameters required for enqueuing the request
+        # check and pop predicate_condition and default_value
+        # specified while enqueuing
+        predicate_condition = kwargs.pop("predicate_condition", True)
+        default_value = kwargs.pop("default_value", PREDICATE_DEFAULT_VALUE)
+        try:
+            self._check_default_value(default_value, len(args),
+                                      list(kwargs.keys()))
+        except ValueError as e:
+            raise RayServeException(str(e))
+
+        # create request parameters required for enqueuing the request
         request_params = RequestParams(self.endpoint_name, TaskContext.Python,
                                        **request_param_kwargs)
         req_info_object_id = self.router_handle.enqueue_request.remote(
-            request_params, *args, **kwargs)
+            request_params, predicate_condition, default_value, *args,
+            **kwargs)
 
         # check if it is necessary to wait for enqueue to be completed
         # NOTE: This will make remote call completely non-blocking for
diff --git a/python/ray/experimental/serve/kv_store_service.py b/python/ray/experimental/serve/kv_store_service.py
index 83aacc5f9774f..87035114a79d9 100644
--- a/python/ray/experimental/serve/kv_store_service.py
+++ b/python/ray/experimental/serve/kv_store_service.py
@@ -238,6 +238,11 @@ def get_backend_creator(self, backend_tag):
     def list_backends(self):
         return list(self.backend_table.as_dict().keys())
 
+    def get_backend_predicate(self, backend_tag):
+        backend_info = json.loads(self.backend_info.get(backend_tag, "{}"))
+        if "enable_predicate" in backend_info:
+            return backend_info["enable_predicate"]
+
     def list_replicas(self, backend_tag: str):
         return json.loads(self.replica_table.get(backend_tag, "[]"))
diff --git a/python/ray/experimental/serve/queues.py b/python/ray/experimental/serve/queues.py
index 005f1c13e6fd1..e3bad7fb4576d 100644
--- a/python/ray/experimental/serve/queues.py
+++ b/python/ray/experimental/serve/queues.py
@@ -8,6 +8,10 @@
 from blist import sortedlist
 import time
 from ray.experimental.serve.request_params import RequestInfo
+from ray.experimental.serve.constants import (RESULT_KEY, PREDICATE_KEY,
+                                              PREDICATE_DEFAULT_VALUE)
+
+from copy import deepcopy
 
 
 class Query:
@@ -16,15 +20,14 @@ def __init__(self,
                  request_kwargs,
                  request_context,
                  request_slo_ms,
                  result_object_id=None):
         self.request_args = request_args
         self.request_kwargs = request_kwargs
         self.request_context = request_context
 
-        if result_object_id is None:
-            self.result_object_id = [ray.ObjectID.from_random()]
-        else:
-            self.result_object_id = result_object_id
+        # keep None as the default to avoid a shared mutable dict
+        if result_object_id is None:
+            result_object_id = {}
+        if RESULT_KEY not in result_object_id:
+            result_object_id[RESULT_KEY] = ray.ObjectID.from_random()
+        self.result_object_id = result_object_id
 
         # Service level objective in milliseconds. This is expected to be the
         # absolute time since unix epoch.
@@ -114,7 +117,12 @@ def _serve_metric(self):
 
     # request_slo_ms is time specified in milliseconds till which the
     # answer of the query should be calculated
-    def enqueue_request(self, request_params, *request_args, **request_kwargs):
+    def enqueue_request(self,
+                        request_params,
+                        predicate_condition=True,
+                        default_value=PREDICATE_DEFAULT_VALUE,
+                        *request_args,
+                        **request_kwargs):
         """
         Enqueues a request in the service queue.
 
@@ -122,6 +130,11 @@ def enqueue_request(self, request_params, *request_args, **request_kwargs):
             request_params(RequestParams): Argument specified for enqueuing
                 request and getting information correspondingly after
                 the enqueue is called.
+            predicate_condition(bool): Specifies whether to enqueue the
+                request or not. Intended to be used when this function is
+                called with a ray.ObjectID for this argument.
+            default_value: Only used when predicate_condition is `False`. The
+                RequestInfo returned will carry this value.
             *request_args(optional): The arguments that need to be passed
                 to backend class `__call__` method.
             **request_kwargs(optional): The keyword arguments that need to be
@@ -129,30 +142,61 @@ def enqueue_request(self, request_params, *request_args, **request_kwargs):
 
         :rtype: RequestInfo
         """
-        request_slo_ms = request_params.request_slo_ms
-        if request_slo_ms is None:
-            # if request_slo_ms is not specified then set it to a high level
-            request_slo_ms = 1e9
-
-        # add wall clock time to specify the deadline for completion of query
-        # this also assures FIFO behaviour if request_slo_ms is not specified
-        # if request_slo_ms is not wall clock time
-        if not request_params.is_wall_clock_time:
-            request_slo_ms += (time.time() * 1000)
-        query = Query(
-            request_args,
-            request_kwargs,
-            request_params.request_context,
-            request_slo_ms,
-            result_object_id=request_params.return_object_ids)
-
-        self.queues[request_params.service].append(query)
-        self.flush()
-        # create request information to be returned
-        req_info = RequestInfo(
-            query.result_object_id, request_params.return_object_ids is None,
-            request_slo_ms, request_params.return_wall_clock_time)
-        return req_info
+        if predicate_condition:
+            request_slo_ms = request_params.request_slo_ms
+            return_object_id = (
+                RESULT_KEY not in request_params.return_object_ids)
+            if request_slo_ms is None:
+                # if request_slo_ms is not specified then set
+                # it to a high value
+                request_slo_ms = 1e9
+
+            # add wall clock time to set the deadline for completing the
+            # query; this also assures FIFO behaviour when request_slo_ms
+            # is not specified, provided it is not already wall clock time
+            if not request_params.is_wall_clock_time:
+                request_slo_ms += (time.time() * 1000)
+            query = Query(
+                request_args,
+                request_kwargs,
+                request_params.request_context,
+                request_slo_ms,
+                result_object_id=deepcopy(request_params.return_object_ids))
+
+            self.queues[request_params.service].append(query)
+            self.flush()
+            # create request information to be returned
+            req_info = RequestInfo(query.result_object_id, return_object_id,
+                                   request_slo_ms,
+                                   request_params.return_wall_clock_time)
+            return req_info
+        else:
+            # set the default value
+            if default_value != PREDICATE_DEFAULT_VALUE:
+                args_kwargs_identifier, param = default_value
+                if args_kwargs_identifier == "args":
+                    default_value = request_args[param]
+                else:
+                    default_value = request_kwargs[param]
+            object_id_dict = request_params.return_object_ids
+            request_slo_ms = request_params.request_slo_ms
+            if request_slo_ms is None:
+                request_slo_ms = 0
+            if not request_params.is_wall_clock_time:
+                request_slo_ms += (time.time() * 1000)
+            return_object_id = False
+            if RESULT_KEY not in object_id_dict:
+                return_object_id = True
+                object_id_dict[RESULT_KEY] = ray.ObjectID.from_random()
+            ray.worker.global_worker.put_object(default_value,
+                                                object_id_dict[RESULT_KEY])
+            if PREDICATE_KEY in object_id_dict:
+                ray.worker.global_worker.put_object(
+                    predicate_condition, object_id_dict[PREDICATE_KEY])
+            req_info = RequestInfo(object_id_dict, return_object_id,
+                                   request_slo_ms,
+                                   request_params.return_wall_clock_time)
+            return req_info
 
     def dequeue_request(self, backend, replica_handle):
         intention = WorkIntent(replica_handle)
diff --git a/python/ray/experimental/serve/request_params.py b/python/ray/experimental/serve/request_params.py
index 6f098fdde34d3..586db0b690ba5 100644
--- a/python/ray/experimental/serve/request_params.py
+++ b/python/ray/experimental/serve/request_params.py
@@ -1,4 +1,5 @@
 import inspect
+from ray.experimental.serve.constants import RESULT_KEY, PREDICATE_KEY
 
 
 class RequestParams:
@@ -11,8 +12,9 @@ class RequestParams:
         request_context(TaskContext): Context of a request.
         request_slo_ms(float): Expected time for the query to get
             completed.
-        return_object_ids(list[ray.ObjectID]): List of ObjectIds where
-            result of the request will be put.
+        return_object_ids(dict[str, ray.ObjectID]): Dictionary of ObjectIDs
+            where the result or predicate of the request will be put.
+            Supported keys are: ['result', 'predicate'].
         is_wall_clock_time(bool): if True, router won't add wall clock
             time to `request_slo_ms`.
        return_wall_clock_time(bool): if True, wall clock time for query
@@ -24,13 +26,22 @@ def __init__(self,
                  service,
                  request_context,
                  request_slo_ms=None,
                  return_object_ids=None,
                  is_wall_clock_time=False,
                  return_wall_clock_time=False):
         self.service = service
         self.request_context = request_context
         self.request_slo_ms = request_slo_ms
+        # normalize to a fresh dict; avoids a shared mutable default
+        if return_object_ids is None:
+            return_object_ids = {}
+        # check for dictionary
+        assert isinstance(return_object_ids, dict), ("return_object_ids"
+                                                     " must be a "
+                                                     "dictionary.")
+        # keys must be a subset of the supported return keys
+        assert (set(return_object_ids.keys()) <= set(
+            [RESULT_KEY, PREDICATE_KEY])), ("return_object_ids "
+                                            "specified incorrectly")
         self.return_object_ids = return_object_ids
         self.is_wall_clock_time = is_wall_clock_time
         self.return_wall_clock_time = return_wall_clock_time
@@ -73,14 +84,13 @@ def __init__(self, result_object_id, return_object_id, request_slo_ms,
 
     def __iter__(self):
         if self.return_object_id:
-            for object_id in self.result_object_id:
-                yield object_id
+            yield self.result_object_id[RESULT_KEY]
         if self.return_wall_clock_time:
             yield self.request_slo_ms
 
     @staticmethod
     def wait_for_requestInfo(request_params):
-        if (request_params.return_object_ids is None
+        if (RESULT_KEY not in request_params.return_object_ids
                 or request_params.return_wall_clock_time):
             return True
         return False
diff --git a/python/ray/experimental/serve/server.py b/python/ray/experimental/serve/server.py
index be6d440e0a7a4..94220da18c348 100644
--- a/python/ray/experimental/serve/server.py
+++ b/python/ray/experimental/serve/server.py
@@ -5,7 +5,8 @@
 import ray
 from ray.experimental.async_api import _async_init
-from ray.experimental.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S
+from ray.experimental.serve.constants import (HTTP_ROUTER_CHECKER_INTERVAL_S,
+                                              PREDICATE_DEFAULT_VALUE)
 from ray.experimental.serve.context import TaskContext
 from ray.experimental.serve.utils import BytesEncoder
 from urllib.parse import parse_qs
@@ -162,8 +163,9 @@ async def __call__(self, scope, receive, send):
 
         # await for request info to get back
         req_info = await (self.serve_global_state.init_or_get_router()
-                          .enqueue_request.remote(request_params, *args,
-                                                  **kwargs))
+                          .enqueue_request.remote(request_params, True,
+                                                  PREDICATE_DEFAULT_VALUE,
+                                                  *args, **kwargs))
 
         # await for result
         result = await next(iter(req_info))
diff --git a/python/ray/experimental/serve/task_runner.py b/python/ray/experimental/serve/task_runner.py
index d3288883e0733..35b1e81afee5e 100644
--- a/python/ray/experimental/serve/task_runner.py
+++ b/python/ray/experimental/serve/task_runner.py
@@ -7,6 +7,7 @@
 from collections import defaultdict
 from ray.experimental.serve.utils import parse_request_item
 from ray.experimental.serve.exceptions import RayServeException
+from ray.experimental.serve.constants import RESULT_KEY, PREDICATE_KEY
 
 
 class TaskRunner:
@@ -16,12 +17,16 @@ class TaskRunner:
 
     That is, a ray serve actor should implement the TaskRunner interface.
""" - def __init__(self, func_to_run): + def __init__(self, func_to_run, predicate_to_run=None): self.func = func_to_run + self.predicate_to_run = predicate_to_run def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) + def __predicate__(self, result): + return self.predicate_to_run(result) + def wrap_to_ray_error(exception): """Utility method that catch and seal exceptions in execution""" @@ -53,6 +58,7 @@ class RayServeActor(RayServeMixin, MyClass): _ray_serve_router_handle = None _ray_serve_setup_completed = False _ray_serve_dequeue_requester_name = None + _predicate_required = False # Work token can be unfullfilled from last iteration. # This cache will be used to determine whether or not we should @@ -81,11 +87,16 @@ def _serve_metric(self): }, } - def _ray_serve_setup(self, my_name, router_handle, my_handle): + def _ray_serve_setup(self, + my_name, + router_handle, + my_handle, + predicate_required=False): self._ray_serve_dequeue_requester_name = my_name self._ray_serve_router_handle = router_handle self._ray_serve_self_handle = my_handle self._ray_serve_setup_completed = True + self._predicate_required = predicate_required def _ray_serve_fetch(self): assert self._ray_serve_setup_completed @@ -94,21 +105,43 @@ def _ray_serve_fetch(self): self._ray_serve_dequeue_requester_name, self._ray_serve_self_handle) + def _check_predicate_demand(self, predicate_required): + if predicate_required != self._predicate_required: + if predicate_required: + raise RayServeException( + "Backend doesn't support predicate feature!") + else: + raise RayServeException( + "Backend requires predicate objectid to be specified" + " while issuing the request!") + def invoke_single(self, request_item): args, kwargs, is_web_context, result_object_id = parse_request_item( request_item) serve_context.web = is_web_context + predicate_required = PREDICATE_KEY in result_object_id start_timestamp = time.time() try: + # check predicate demand + self._check_predicate_demand(predicate_required) result = self.__call__(*args, **kwargs) - for rid in result_object_id: - ray.worker.global_worker.put_object(result, rid) + if self._predicate_required: + predicate_result = self.__predicate__(result) + + ray.worker.global_worker.put_object(result, + result_object_id[RESULT_KEY]) + if self._predicate_required: + ray.worker.global_worker.put_object( + predicate_result, result_object_id[PREDICATE_KEY]) + except Exception as e: wrapped_exception = wrap_to_ray_error(e) self._serve_metric_error_counter += 1 - for return_id in result_object_id: - ray.worker.global_worker.put_object(wrapped_exception, - return_id) + ray.worker.global_worker.put_object(wrapped_exception, + result_object_id[RESULT_KEY]) + if PREDICATE_KEY in result_object_id: + ray.worker.global_worker.put_object( + wrapped_exception, result_object_id[PREDICATE_KEY]) self._serve_metric_latency_list.append(time.time() - start_timestamp) def invoke_batch(self, request_item_list): @@ -129,11 +162,13 @@ def invoke_batch(self, request_item_list): kwargs_list = defaultdict(list) result_object_ids, context_flag_list, arg_list = [], [], [] curr_batch_size = len(request_item_list) - + predicate_list_flag = list() for item in request_item_list: args, kwargs, is_web_context, result_object_id = ( parse_request_item(item)) context_flag_list.append(is_web_context) + predicate_required = PREDICATE_KEY in result_object_id + predicate_list_flag.append(predicate_required) # Python context only have kwargs # Web context only have one positional argument @@ -150,7 +185,19 @@ 
def invoke_batch(self, request_item_list):
             if len(set(context_flag_list)) != 1:
                 raise RayServeException(
                     "Batched queries contain mixed context.")
+
+            # a batch must be uniform in its predicate demand: either
+            # every query returns a predicate or none of them does
+            if len(set(predicate_list_flag)) != 1:
+                raise RayServeException(
+                    "Batched queries contain mixed predicate demand.")
+
             serve_context.web = all(context_flag_list)
+            predicate_required = all(predicate_list_flag)
+
+            # check predicate demand
+            self._check_predicate_demand(predicate_required)
+
             if serve_context.web:
                 args = (arg_list, )
             else:
@@ -163,8 +210,11 @@ def invoke_batch(self, request_item_list):
                 args = (fake_flask_request_lst, )
             # set the current batch size (n) for serve_context
             serve_context.batch_size = len(result_object_ids)
+
             start_timestamp = time.time()
             result_list = self.__call__(*args, **kwargs_list)
+            if self._predicate_required:
+                predicate_list = self.__predicate__(result_list)
             if (not isinstance(result_list,
                                list)) or (len(result_list) !=
                                           len(result_object_ids)):
                 raise RayServeException("__call__ function "
@@ -173,20 +223,40 @@ def invoke_batch(self, request_item_list):
                                         "with length equals to the batch "
                                         "size.")
 
-            for result, result_object_id in zip(result_list,
-                                                result_object_ids):
-                for return_id in result_object_id:
-                    ray.worker.global_worker.put_object(result, return_id)
+            if self._predicate_required:
+                if ((not isinstance(predicate_list, list))
+                        or (len(predicate_list) != len(result_object_ids))):
+                    raise RayServeException("__predicate__ function "
+                                            "doesn't preserve batch-size. "
+                                            "Please return a list of results "
+                                            "with length equal to the batch "
+                                            "size.")
+
+                for result, predicate_result, result_object_id in zip(
+                        result_list, predicate_list, result_object_ids):
+                    ray.worker.global_worker.put_object(
+                        result, result_object_id[RESULT_KEY])
+                    ray.worker.global_worker.put_object(
+                        predicate_result, result_object_id[PREDICATE_KEY])
+            else:
+                for result, result_object_id in zip(result_list,
+                                                    result_object_ids):
+                    ray.worker.global_worker.put_object(
+                        result, result_object_id[RESULT_KEY])
             self._serve_metric_latency_list.append(time.time() -
                                                    start_timestamp)
+
         except Exception as e:
             wrapped_exception = wrap_to_ray_error(e)
             self._serve_metric_error_counter += len(result_object_ids)
             for result_object_id in result_object_ids:
-                for return_id in result_object_id:
+                ray.worker.global_worker.put_object(
+                    wrapped_exception, result_object_id[RESULT_KEY])
+                if PREDICATE_KEY in result_object_id:
                     ray.worker.global_worker.put_object(
-                        wrapped_exception, return_id)
+                        wrapped_exception, result_object_id[PREDICATE_KEY])
 
     def _ray_serve_call(self, request):
         work_item = request
diff --git a/python/ray/experimental/serve/tests/test_api.py b/python/ray/experimental/serve/tests/test_api.py
index 7cf360b318ca2..b82ef2d7a72ba 100644
--- a/python/ray/experimental/serve/tests/test_api.py
+++ b/python/ray/experimental/serve/tests/test_api.py
@@ -84,6 +84,8 @@ def __init__(self):
 
     @serve.accept_batch
     def __call__(self, flask_request, temp=None):
+        # simulate some intensive work
+        time.sleep(1)
         self.count += 1
         batch_size = serve.context.batch_size
         return [self.count] * batch_size
diff --git a/python/ray/experimental/serve/tests/test_queue.py b/python/ray/experimental/serve/tests/test_queue.py
index 84c57d8fa1ccb..6f097a2d94781 100644
--- a/python/ray/experimental/serve/tests/test_queue.py
+++ b/python/ray/experimental/serve/tests/test_queue.py
@@ -4,6 +4,8 @@
 
 from ray.experimental.serve.queues import
(RoundRobinPolicyQueue, FixedPackingPolicyQueue) from ray.experimental.serve.request_params import RequestParams +from ray.experimental.serve.constants import (RESULT_KEY, + PREDICATE_DEFAULT_VALUE) @pytest.fixture(scope="session") @@ -28,13 +30,16 @@ def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): q.link("svc", "backend") result_object_id = next( - iter(q.enqueue_request(RequestParams("svc", None), 1))) + iter( + q.enqueue_request( + RequestParams("svc", None), True, PREDICATE_DEFAULT_VALUE, 1))) q.dequeue_request("backend", task_runner_mock_actor) got_work = ray.get(task_runner_mock_actor.get_recent_call.remote()) assert got_work.request_args[0] == 1 assert got_work.request_kwargs == {} - ray.worker.global_worker.put_object(2, got_work.result_object_id[0]) + ray.worker.global_worker.put_object(2, + got_work.result_object_id[RESULT_KEY]) assert ray.get(result_object_id) == 2 @@ -44,7 +49,9 @@ def test_slo(serve_instance, task_runner_mock_actor): for i in range(10): slo_ms = 1000 - 100 * i - q.enqueue_request(RequestParams("svc", None, request_slo_ms=slo_ms), i) + q.enqueue_request( + RequestParams("svc", None, request_slo_ms=slo_ms), True, + PREDICATE_DEFAULT_VALUE, i) for i in range(10): q.dequeue_request("backend", task_runner_mock_actor) got_work = ray.get(task_runner_mock_actor.get_recent_call.remote()) @@ -56,20 +63,26 @@ def test_alter_backend(serve_instance, task_runner_mock_actor): q.set_traffic("svc", {"backend-1": 1}) result_object_id = next( - iter(q.enqueue_request(RequestParams("svc", None), 1))) + iter( + q.enqueue_request( + RequestParams("svc", None), True, PREDICATE_DEFAULT_VALUE, 1))) q.dequeue_request("backend-1", task_runner_mock_actor) got_work = ray.get(task_runner_mock_actor.get_recent_call.remote()) assert got_work.request_args[0] == 1 - ray.worker.global_worker.put_object(2, got_work.result_object_id[0]) + ray.worker.global_worker.put_object(2, + got_work.result_object_id[RESULT_KEY]) assert ray.get(result_object_id) == 2 q.set_traffic("svc", {"backend-2": 1}) result_object_id = next( - iter(q.enqueue_request(RequestParams("svc", None), 1))) + iter( + q.enqueue_request( + RequestParams("svc", None), True, PREDICATE_DEFAULT_VALUE, 1))) q.dequeue_request("backend-2", task_runner_mock_actor) got_work = ray.get(task_runner_mock_actor.get_recent_call.remote()) assert got_work.request_args[0] == 1 - ray.worker.global_worker.put_object(2, got_work.result_object_id[0]) + ray.worker.global_worker.put_object(2, + got_work.result_object_id[RESULT_KEY]) assert ray.get(result_object_id) == 2 @@ -80,7 +93,8 @@ def test_split_traffic(serve_instance, task_runner_mock_actor): # assume 50% split, the probability of all 20 requests goes to a # single queue is 0.5^20 ~ 1-6 for _ in range(20): - q.enqueue_request(RequestParams("svc", None), 1) + q.enqueue_request( + RequestParams("svc", None), True, PREDICATE_DEFAULT_VALUE, 1) q.dequeue_request("backend-1", task_runner_mock_actor) result_one = ray.get(task_runner_mock_actor.get_recent_call.remote()) q.dequeue_request("backend-2", task_runner_mock_actor) @@ -96,7 +110,8 @@ def test_split_traffic_round_robin(serve_instance, task_runner_mock_actor): # since round robin policy is stateful firing two queries consecutively # would transfer the queries to two different backends for _ in range(2): - q.enqueue_request(RequestParams("svc", None), 1) + q.enqueue_request( + RequestParams("svc", None), True, PREDICATE_DEFAULT_VALUE, 1) q.dequeue_request("backend-1", task_runner_mock_actor) result_one = 
ray.get(task_runner_mock_actor.get_recent_call.remote()) q.dequeue_request("backend-2", task_runner_mock_actor) @@ -113,7 +128,8 @@ def test_split_traffic_fixed_packing(serve_instance, task_runner_mock_actor): # fire twice the number of queries as the packing number for i in range(2 * packing_num): - q.enqueue_request(RequestParams("svc", None), i) + q.enqueue_request( + RequestParams("svc", None), True, PREDICATE_DEFAULT_VALUE, i) # both the backends will get equal number of queries # as it is packed round robin diff --git a/python/ray/experimental/serve/tests/test_task_runner.py b/python/ray/experimental/serve/tests/test_task_runner.py index 941b47672ed51..af3b61ec30778 100644 --- a/python/ray/experimental/serve/tests/test_task_runner.py +++ b/python/ray/experimental/serve/tests/test_task_runner.py @@ -7,6 +7,8 @@ RayServeMixin, TaskRunner, TaskRunnerActor, wrap_to_ray_error) from ray.experimental.serve.request_params import RequestParams +from ray.experimental.serve.constants import PREDICATE_DEFAULT_VALUE + def test_runner_basic(): def echo(i): @@ -40,7 +42,10 @@ def echo(flask_request, i=None): for query in [333, 444, 555]: query_param = RequestParams(PRODUCER_NAME, context.TaskContext.Python) result_token = next( - iter(ray.get(q.enqueue_request.remote(query_param, i=query)))) + iter( + ray.get( + q.enqueue_request.remote( + query_param, True, PREDICATE_DEFAULT_VALUE, i=query)))) assert ray.get(result_token) == query @@ -72,7 +77,10 @@ class CustomActor(MyAdder, RayServeMixin): for query in [333, 444, 555]: query_param = RequestParams(PRODUCER_NAME, context.TaskContext.Python) result_token = next( - iter(ray.get(q.enqueue_request.remote(query_param, i=query)))) + iter( + ray.get( + q.enqueue_request.remote( + query_param, True, PREDICATE_DEFAULT_VALUE, i=query)))) assert ray.get(result_token) == query + 3 @@ -94,7 +102,10 @@ def echo(flask_request, i=None): q.link.remote(PRODUCER_NAME, CONSUMER_NAME) query_param = RequestParams(PRODUCER_NAME, context.TaskContext.Python) result_token = next( - iter(ray.get(q.enqueue_request.remote(query_param, i=42)))) + iter( + ray.get( + q.enqueue_request.remote( + query_param, True, PREDICATE_DEFAULT_VALUE, i=42)))) with pytest.raises(ray.exceptions.RayTaskError): ray.get(result_token)
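
Two usage sketches follow, complementing echo_conditional_pipeline.py. First, the class-backend path: create_backend above requires a class backend with enable_predicate=True to implement __predicate__ itself (the predicate_function argument is only honored on the function path). A minimal sketch, where the Scorer class, its endpoint names, and the 1.0 threshold are all assumed for illustration:

import ray.experimental.serve as serve
from ray.experimental.serve import BackendConfig


class Scorer:
    # Hypothetical backend; the class path must define __predicate__
    # itself instead of passing predicate_function to create_backend.
    def __call__(self, flask_request, value=None):
        return value * 2

    def __predicate__(self, result):
        # Receives the result of __call__ and returns the boolean that
        # the router will put into the PREDICATE_KEY ObjectID.
        return result > 1.0


serve.init(blocking=True)
serve.create_endpoint("scorer", "/scorer")
serve.create_backend(
    Scorer,
    "scorer:v1",
    backend_config=BackendConfig(enable_predicate=True))
serve.link("scorer", "scorer:v1")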
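
For batched backends, invoke_batch calls __predicate__ once with the full result list and requires a list of booleans of the same length back (otherwise it raises the "__predicate__ function doesn't preserve batch-size" error). A sketch under the same assumed names:

import ray.experimental.serve as serve
from ray.experimental.serve import BackendConfig


class BatchedScorer:
    # Hypothetical batched backend; names and threshold are illustrative.
    @serve.accept_batch
    def __call__(self, flask_request, value=None):
        # in the Python context, value arrives as a list whose length
        # is serve.context.batch_size
        return [v * 2 for v in value]

    def __predicate__(self, result_list):
        # must mirror the batched __call__: list in, list of booleans
        # out, preserving the batch size
        return [r > 1.0 for r in result_list]


serve.create_backend(
    BatchedScorer,
    "scorer:batched",
    backend_config=BackendConfig(enable_predicate=True, max_batch_size=4))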
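
Finally, conditional chaining from the handle side: predicate_condition may be a ray.ObjectID, which Ray resolves to its value before the router actor executes enqueue_request, and default_value=("kwargs", "relay") makes a skipped stage pass its "relay" kwarg through unchanged. A sketch assuming the scorer endpoint above plus a hypothetical downstream "postprocess" endpoint:

import ray
import ray.experimental.serve as serve

scorer = serve.get_handle("scorer")
postprocess = serve.get_handle("postprocess")  # assumed downstream endpoint

result_oid = ray.ObjectID.from_random()
predicate_oid = ray.ObjectID.from_random()
scorer.remote(
    value=0.3,
    return_object_ids={
        serve.RESULT_KEY: result_oid,
        serve.PREDICATE_KEY: predicate_oid
    })

final_oid = ray.ObjectID.from_random()
# Runs only if predicate_oid resolves to True; otherwise the router
# short-circuits and puts kwargs["relay"] (the upstream result)
# directly into final_oid.
postprocess.remote(
    relay=result_oid,
    predicate_condition=predicate_oid,
    default_value=("kwargs", "relay"),
    return_object_ids={serve.RESULT_KEY: final_oid})
print(ray.get(final_oid))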