diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py
index c821c5c1d0936..9e00c8e77a98f 100644
--- a/jina/serve/executors/__init__.py
+++ b/jina/serve/executors/__init__.py
@@ -655,9 +655,22 @@ def _validate_sagemaker(self):
         return
 
     def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]):
+        import collections.abc
+
+        def deep_update(source, overrides):
+            for key, value in overrides.items():
+                if isinstance(value, collections.abc.Mapping) and value:
+                    returned = deep_update(source.get(key, {}), value)
+                    source[key] = returned
+                else:
+                    source[key] = overrides[key]
+            return source
+
         if _dynamic_batching:
             self.dynamic_batching = getattr(self, 'dynamic_batching', {})
-            self.dynamic_batching.update(_dynamic_batching)
+            self.dynamic_batching = deep_update(
+                self.dynamic_batching, _dynamic_batching
+            )
 
     def _add_metas(self, _metas: Optional[Dict]):
         from jina.serve.executors.metas import get_default_metas
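
Note: the deep merge matters because `@dynamic_batching` may already have stored keys such as `custom_metric` for a function at class-definition time, and a `uses_dynamic_batching` override passed at deployment time should only touch the keys it names (the integration test below relies on exactly this). A minimal standalone sketch of the merge semantics, with hypothetical config values:

    from collections.abc import Mapping

    def deep_update(source, overrides):
        # Recursively merge `overrides` into `source`, keeping keys that the
        # override does not mention.
        for key, value in overrides.items():
            if isinstance(value, Mapping) and value:
                source[key] = deep_update(source.get(key, {}), value)
            else:
                source[key] = value
        return source

    decorator_cfg = {'foo': {'preferred_batch_size': 10, 'custom_metric': len}}
    yaml_override = {'foo': {'timeout': 2000}}

    merged = deep_update(decorator_cfg, yaml_override)
    # A flat dict.update() would have replaced the whole 'foo' entry and
    # silently dropped 'custom_metric'; the deep merge keeps it:
    assert merged['foo'] == {'preferred_batch_size': 10, 'custom_metric': len, 'timeout': 2000}
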
diff --git a/jina/serve/executors/decorators.py b/jina/serve/executors/decorators.py
index 7c7a6e4031bcf..b9072929cbed7 100644
--- a/jina/serve/executors/decorators.py
+++ b/jina/serve/executors/decorators.py
@@ -416,7 +416,9 @@ def dynamic_batching(
     *,
     preferred_batch_size: Optional[int] = None,
     timeout: Optional[float] = 10_000,
-    flush_all: bool = False
+    flush_all: bool = False,
+    custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None,
+    use_custom_metric: bool = False,
 ):
     """
     `@dynamic_batching` defines the dynamic batching behavior of an Executor.
@@ -434,6 +436,8 @@ def dynamic_batching(
         Default is 10_000ms (10 seconds).
     :param flush_all: Determines whether, once a batch is triggered by timeout or `preferred_batch_size`, the function receives everything the batcher has accumulated. If true, `preferred_batch_size` acts only as a trigger mechanism.
+    :param custom_metric: Optional callable (e.g. a lambda) that measures the "weight" of each Document; the accumulated weight, rather than the document count, is compared against `preferred_batch_size`.
+    :param use_custom_metric: Determines whether `custom_metric` is used when checking `preferred_batch_size`.
     :return: decorated function
     """
@@ -480,6 +484,8 @@ def _inject_owner_attrs(self, owner, name):
                 ] = preferred_batch_size
                 owner.dynamic_batching[fn_name]['timeout'] = timeout
                 owner.dynamic_batching[fn_name]['flush_all'] = flush_all
+                owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric
+                owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric
             setattr(owner, name, self.fn)
 
     def __set_name__(self, owner, name):
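
Note: taken together, `custom_metric` and `use_custom_metric` let an Executor trigger batches on accumulated weight instead of document count. A hedged usage sketch (Executor name and endpoint are hypothetical); note that despite the `DocumentArray` annotation, the batch queue applies the metric to one Document at a time, as the `push` hunk below shows:

    from jina import Executor, dynamic_batching, requests

    class TokenBudgetExecutor(Executor):
        # Flush once the summed text length reaches 512 characters,
        # instead of after a fixed number of Documents.
        @dynamic_batching(
            preferred_batch_size=512,
            timeout=2_000,
            custom_metric=lambda doc: len(doc.text),
            use_custom_metric=True,
        )
        @requests(on='/encode')
        def encode(self, docs, **kwargs):
            ...
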
diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py
index 0419e35414a46..31bac588d5efd 100644
--- a/jina/serve/runtimes/worker/batch_queue.py
+++ b/jina/serve/runtimes/worker/batch_queue.py
@@ -1,9 +1,10 @@
 import asyncio
 import copy
 from asyncio import Event, Task
-from typing import Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
 from jina._docarray import docarray_v2
 import contextlib
+
 if not docarray_v2:
     from docarray import DocumentArray
 else:
@@ -18,16 +19,18 @@ class BatchQueue:
     """A batch queue that holds the data request and the callable to batch requests to."""
 
     def __init__(
-        self,
-        func: Callable,
-        request_docarray_cls,
-        response_docarray_cls,
-        output_array_type: Optional[str] = None,
-        params: Optional[Dict] = None,
-        allow_concurrent: bool = False,
-        flush_all: bool = False,
-        preferred_batch_size: int = 4,
-        timeout: int = 10_000,
+        self,
+        func: Callable,
+        request_docarray_cls,
+        response_docarray_cls,
+        output_array_type: Optional[str] = None,
+        params: Optional[Dict] = None,
+        allow_concurrent: bool = False,
+        flush_all: bool = False,
+        preferred_batch_size: int = 4,
+        timeout: int = 10_000,
+        custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
+        use_custom_metric: bool = False,
     ) -> None:
         # To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
         if allow_concurrent and flush_all:
@@ -44,6 +47,8 @@ def __init__(
         self._response_docarray_cls = response_docarray_cls
         self._flush_all = flush_all
         self._preferred_batch_size: int = preferred_batch_size
+        self._custom_metric = None if not use_custom_metric else custom_metric
+        self._metric_value = 0
         self._timeout: int = timeout
         self._reset()
         self._flush_trigger: Event = Event()
@@ -62,20 +67,22 @@ def _reset(self) -> None:
         # a list of every request ID
         self._request_idxs: List[int] = []
         self._request_lens: List[int] = []
+        self._docs_metrics: List[int] = []
         self._requests_completed: List[asyncio.Queue] = []
         if not docarray_v2:
             self._big_doc: DocumentArray = DocumentArray.empty()
         else:
             self._big_doc = self._request_docarray_cls()
+        self._metric_value = 0
         self._flush_task: Optional[Task] = None
         self._flush_trigger: Event = Event()
 
     def _cancel_timer_if_pending(self):
         if (
-            self._timer_task
-            and not self._timer_task.done()
-            and not self._timer_task.cancelled()
+            self._timer_task
+            and not self._timer_task.done()
+            and not self._timer_task.cancelled()
         ):
             self._timer_finished = False
             self._timer_task.cancel()
@@ -91,7 +98,7 @@ async def _sleep_then_set(self):
         self._flush_trigger.set()
         self._timer_finished = True
 
-    async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
+    async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
         """Append request to the list of requests to be processed.
 
         This method creates an asyncio Queue for that request and keeps track of it.
         It returns
@@ -116,12 +123,18 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
             self._big_doc.extend(docs)
             next_req_idx = len(self._requests)
             num_docs = len(docs)
+            metric_value = num_docs
+            if self._custom_metric is not None:
+                metrics = [self._custom_metric(doc) for doc in docs]
+                metric_value = sum(metrics)
+                self._docs_metrics.extend(metrics)
+            self._metric_value += metric_value
             self._request_idxs.extend([next_req_idx] * num_docs)
-            self._request_lens.append(len(docs))
+            self._request_lens.append(num_docs)
             self._requests.append(request)
             queue = asyncio.Queue()
             self._requests_completed.append(queue)
-            if len(self._big_doc) >= self._preferred_batch_size:
+            if self._metric_value >= self._preferred_batch_size:
                 self._flush_trigger.set()
 
         return queue
@@ -132,10 +145,10 @@ async def _await_then_flush(self, http=False) -> None:
         """
 
         def _get_docs_groups_completed_request_indexes(
-            non_assigned_docs,
-            non_assigned_docs_reqs_idx,
-            sum_from_previous_mini_batch_in_first_req_idx,
-            requests_lens_in_batch,
+            non_assigned_docs,
+            non_assigned_docs_reqs_idx,
+            sum_from_previous_mini_batch_in_first_req_idx,
+            requests_lens_in_batch,
         ):
             """
             This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
@@ -160,9 +173,9 @@ def _get_docs_groups_completed_request_indexes(
                     )
                     if req_idx > min_involved_req_idx:
                         request_bucket = non_assigned_docs[
-                            num_distributed_docs : num_distributed_docs
-                            + num_docs_in_req_idx
-                        ]
+                            num_distributed_docs: num_distributed_docs
+                            + num_docs_in_req_idx
+                        ]
                         num_distributed_docs += num_docs_in_req_idx
                         completed_req_idx.append(min_involved_req_idx)
                         min_involved_req_idx = req_idx
@@ -171,25 +184,25 @@ def _get_docs_groups_completed_request_indexes(
                 num_docs_in_req_idx += 1
 
             if (
-                req_idx not in completed_req_idx
-                and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
-                == requests_lens_in_batch[req_idx]
+                req_idx not in completed_req_idx
+                and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
+                == requests_lens_in_batch[req_idx]
             ):
                 completed_req_idx.append(req_idx)
                 request_bucket = non_assigned_docs[
-                    num_distributed_docs : num_distributed_docs + num_docs_in_req_idx
-                ]
+                    num_distributed_docs: num_distributed_docs + num_docs_in_req_idx
+                ]
                 distributed_requests.append(request_bucket)
 
             return distributed_requests, completed_req_idx
 
         async def _assign_results(
-            non_assigned_docs,
-            non_assigned_docs_reqs_idx,
-            sum_from_previous_mini_batch_in_first_req_idx,
-            requests_lens_in_batch,
-            requests_in_batch,
-            requests_completed_in_batch,
+            non_assigned_docs,
+            non_assigned_docs_reqs_idx,
+            sum_from_previous_mini_batch_in_first_req_idx,
+            requests_lens_in_batch,
+            requests_in_batch,
+            requests_completed_in_batch,
         ):
             """
             This method aims to assign to the corresponding request objects the resulting documents from the mini batches.
@@ -220,7 +233,7 @@ async def _assign_results(
                 request = requests_in_batch[request_idx]
                 request_completed = requests_completed_in_batch[request_idx]
                 if http is False or self._output_array_type is not None:
-                    request.direct_docs = None # batch queue will work in place, therefore result will need to read from data.
+                    request.direct_docs = None  # batch queue will work in place, therefore result will need to be read from data.
                 request.data.set_docs_convert_arrays(
                     docs_group, ndarray_type=self._output_array_type
                 )
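
Note on the `push` hunk above: with a custom metric, `self._metric_value` accumulates the per-Document weights and the flush fires once the sum reaches `preferred_batch_size`; without one, it degenerates to the old document-count behavior. A small standalone sketch of that arithmetic (plain strings stand in for Documents):

    preferred_batch_size = 10
    weigh = len                              # stands in for self._custom_metric

    metric_value = 0
    for text in ['aaaaa', 'aaaaa']:          # two 5-character payloads
        metric_value += weigh(text)

    assert metric_value >= preferred_batch_size   # 10 >= 10 -> flush is triggered
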
@@ -230,15 +243,31 @@ async def _assign_results(
 
             return num_assigned_docs
 
-        def batch(iterable_1, iterable_2, n:Optional[int] = 1):
+        def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Optional[List[Union[int, float]]] = None):
             if n is None:
                 yield iterable_1, iterable_2
                 return
-            items = len(iterable_1)
-            for ndx in range(0, items, n):
-                yield iterable_1[ndx : min(ndx + n, items)], iterable_2[
-                    ndx : min(ndx + n, items)
-                ]
+            elif iterable_metrics is None:
+                items = len(iterable_1)
+                for ndx in range(0, items, n):
+                    yield iterable_1[ndx: min(ndx + n, items)], iterable_2[
+                        ndx: min(ndx + n, items)
+                    ]
+            else:
+                batch_idx = 0
+                batch_weight = 0
+
+                for i, (item, weight) in enumerate(zip(iterable_1, iterable_metrics)):
+                    batch_weight += weight
+
+                    if batch_weight >= n:
+                        yield iterable_1[batch_idx: i + 1], iterable_2[batch_idx: i + 1]
+                        batch_idx = i + 1
+                        batch_weight = 0
+
+                # Yield any remaining items
+                if batch_weight > 0:
+                    yield iterable_1[batch_idx: len(iterable_1)], iterable_2[batch_idx: len(iterable_1)]
 
         await self._flush_trigger.wait()
         # writes to shared data between tasks need to be mutually exclusive
@@ -246,6 +275,7 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
             big_doc_in_batch = copy.copy(self._big_doc)
             requests_idxs_in_batch = copy.copy(self._request_idxs)
             requests_lens_in_batch = copy.copy(self._request_lens)
+            docs_metrics_in_batch = copy.copy(self._docs_metrics)
             requests_in_batch = copy.copy(self._requests)
             requests_completed_in_batch = copy.copy(self._requests_completed)
 
@@ -263,7 +293,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
         non_assigned_to_response_request_idxs = []
         sum_from_previous_first_req_idx = 0
         for docs_inner_batch, req_idxs in batch(
-            big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None
+            big_doc_in_batch, requests_idxs_in_batch,
+            self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
         ):
             involved_requests_min_indx = req_idxs[0]
             involved_requests_max_indx = req_idxs[-1]
@@ -278,8 +309,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
                 )
                 # Output validation
                 if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
-                    not docarray_v2
-                    and isinstance(batch_res_docs, DocumentArray)
+                    not docarray_v2
+                    and isinstance(batch_res_docs, DocumentArray)
                 ):
                     if not len(batch_res_docs) == input_len_before_call:
                         raise ValueError(
@@ -301,8 +332,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
             except Exception as exc:
                 # All the requests containing docs in this Exception should be raising it
                 for request_full in requests_completed_in_batch[
-                    involved_requests_min_indx : involved_requests_max_indx + 1
-                ]:
+                    involved_requests_min_indx: involved_requests_max_indx + 1
+                ]:
                     await request_full.put(exc)
             else:
                 # We need to attribute the docs to their requests
@@ -320,11 +351,11 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
                 )
 
                 sum_from_previous_first_req_idx = (
-                    len(non_assigned_to_response_docs) - num_assigned_docs
+                    len(non_assigned_to_response_docs) - num_assigned_docs
                 )
                 non_assigned_to_response_docs = non_assigned_to_response_docs[
-                    num_assigned_docs:
-                ]
+                    num_assigned_docs:
+                ]
                 non_assigned_to_response_request_idxs = (
                     non_assigned_to_response_request_idxs[num_assigned_docs:]
                 )
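
Note: the `iterable_metrics` branch of `batch(...)` closes each mini-batch as soon as its summed weight reaches `n`, and emits a final, possibly undersized, remainder. A standalone sketch of that slicing logic for a single iterable:

    from typing import Iterator, List, Sequence

    def weighted_batch(items: Sequence, weights: List[int], n: int) -> Iterator[Sequence]:
        # Mirrors the iterable_metrics branch above for one iterable.
        start, acc = 0, 0
        for i, w in enumerate(weights):
            acc += w
            if acc >= n:
                yield items[start : i + 1]
                start, acc = i + 1, 0
        if acc > 0:                          # leftover that never reached the threshold
            yield items[start:]

    assert list(weighted_batch(['a', 'bb', 'ccc', 'dddd'], [1, 2, 3, 4], n=5)) == [
        ['a', 'bb', 'ccc'],                  # 1 + 2 + 3 = 6 >= 5 closes the batch
        ['dddd'],                            # remainder
    ]
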
diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py
index 87e98455317bb..b55e8415c0aae 100644
--- a/tests/integration/dynamic_batching/test_dynamic_batching.py
+++ b/tests/integration/dynamic_batching/test_dynamic_batching.py
@@ -736,3 +736,67 @@ def foo(self, docs, **kwargs):
 
     assert smaller_than_5 == (1 if allow_concurrent else 0)
     assert larger_than_5 > 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('use_custom_metric', [True, False])
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
+    class DynCustomBatchProcessor(Executor):
+
+        @dynamic_batching(preferred_batch_size=10, custom_metric=lambda x: len(x.text))
+        @requests(on='/foo')
+        def foo(self, docs, **kwargs):
+            time.sleep(0.5)
+            total_len = sum([len(doc.text) for doc in docs])
+            for doc in docs:
+                doc.text = f"{total_len}"
+
+    depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric, "flush_all": flush_all}})
+    da = DocumentArray([Document(text='aaaaa') for _ in range(50)])
+    with depl:
+        cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
+        res = []
+        async for r in cl.post(
+            on='/foo',
+            inputs=da,
+            request_size=1,
+            continue_on_error=True,
+            results_in_order=True,
+        ):
+            res.extend(r)
+    assert len(res) == 50  # 1 request per input
+
+    # custom metric, no flush_all: every mini-batch holds exactly 2 docs (2 * 5 == 10)
+    if use_custom_metric and not flush_all:
+        for doc in res:
+            assert doc.text == "10"
+
+    elif not use_custom_metric and not flush_all:
+        for doc in res:
+            assert doc.text == "50"
+
+    elif use_custom_metric and flush_all:
+        # There will be 2 "10" and the rest will be "240"
+        num_10 = 0
+        num_240 = 0
+        for doc in res:
+            if doc.text == "10":
+                num_10 += 1
+            elif doc.text == "240":
+                num_240 += 1
+
+        assert num_10 == 2
+        assert num_240 == 48
+    elif not use_custom_metric and flush_all:
+        # There will be 10 "50" and the rest will be "200"
+        num_50 = 0
+        num_200 = 0
+        for doc in res:
+            if doc.text == "50":
+                num_50 += 1
+            elif doc.text == "200":
+                num_200 += 1
+
+        assert num_50 == 10
+        assert num_200 == 40
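
Note: the expected texts follow from 50 single-doc requests with `text='aaaaa'` (weight 5) and `preferred_batch_size=10`: the first flush fires as soon as the threshold is crossed (2 docs with the custom metric, 10 docs without), and with `flush_all` the next flush takes everything that piled up during the 0.5 s call. As a quick check of the arithmetic:

    doc_len = 5                  # len('aaaaa')
    assert 2 * doc_len == 10     # custom metric, no flush_all: mini-batches of 2
    assert 10 * doc_len == 50    # doc-count batching, no flush_all: mini-batches of 10
    assert 48 * doc_len == 240   # custom metric + flush_all: 48-doc remainder batch
    assert 40 * doc_len == 200   # doc-count + flush_all: 40-doc remainder batch
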
diff --git a/tests/unit/serve/executors/test_executor.py b/tests/unit/serve/executors/test_executor.py
index a6d902421ae83..344ebcaab7254 100644
--- a/tests/unit/serve/executors/test_executor.py
+++ b/tests/unit/serve/executors/test_executor.py
@@ -614,15 +614,15 @@ class C(B):
     [
         (
             dict(preferred_batch_size=4, timeout=5_000),
-            dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None),
         ),
         (
             dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
-            dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None),
         ),
         (
             dict(preferred_batch_size=4),
-            dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
+            dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None),
         ),
     ],
 )
@@ -641,15 +641,15 @@ def foo(self, docs, **kwargs):
     [
         (
             dict(preferred_batch_size=4, timeout=5_000),
-            dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None),
         ),
         (
             dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
-            dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None),
         ),
         (
             dict(preferred_batch_size=4),
-            dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
+            dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None),
         ),
     ],
 )
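
Note: these parametrized cases pin down the default-filling behavior of `_inject_owner_attrs`: whatever subset of arguments is passed, the stored per-function config is completed with `timeout=10_000`, `flush_all=False`, `use_custom_metric=False`, and `custom_metric=None`. A sketch mirroring the third case above (class name hypothetical):

    from jina import Executor, dynamic_batching, requests

    class MyExec(Executor):
        @dynamic_batching(preferred_batch_size=4)
        @requests(on='/bar')
        def bar(self, docs, **kwargs):
            ...

    assert MyExec().dynamic_batching['bar'] == dict(
        preferred_batch_size=4,
        timeout=10_000,
        flush_all=False,
        use_custom_metric=False,
        custom_metric=None,
    )
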