diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py
index 4ec6a6de4..1e20d25ee 100644
--- a/integration/test_batch_v4.py
+++ b/integration/test_batch_v4.py
@@ -553,9 +553,9 @@ def test_add_1000_tenant_objects_with_async_indexing_and_wait_for_only_one(
         lambda client: client.batch.rate_limit(1000),
     ],
     ids=[
-        "test_add_one_hundred_objects_and_references_between_all_dynamic",
-        "test_add_one_hundred_objects_and_references_between_all_fixed_size",
-        "test_add_one_hundred_objects_and_references_between_all_rate_limit",
+        "test_add_one_object_and_a_self_reference_dynamic",
+        "test_add_one_object_and_a_self_reference_fixed_size",
+        "test_add_one_object_and_a_self_reference_rate_limit",
     ],
 )
 def test_add_one_object_and_a_self_reference(
@@ -596,8 +596,8 @@ def batch_insert(batch: BatchClient) -> None:
     with concurrent.futures.ThreadPoolExecutor() as executor:
         with client.batch.dynamic() as batch:
             futures = [executor.submit(batch_insert, batch) for _ in range(nr_threads)]
-        for future in concurrent.futures.as_completed(futures):
-            future.result()
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
 
     objs = client.collections.get(name).query.fetch_objects(limit=nr_objects * nr_threads).objects
     assert len(objs) == nr_objects * nr_threads
diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py
index 207c8926a..60d50501f 100644
--- a/weaviate/collections/batch/base.py
+++ b/weaviate/collections/batch/base.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import math
 import threading
 import time
@@ -67,18 +68,16 @@ def __len__(self) -> int:
 
     def add(self, item: TBatchInput) -> None:
         """Add an item to the BatchRequest."""
-        self._lock.acquire()
-        self._items.append(item)
-        self._lock.release()
+        with self._lock:
+            self._items.append(item)
 
     def prepend(self, item: List[TBatchInput]) -> None:
         """Add items to the front of the BatchRequest.
 
         This is intended to be used when objects should be retried, e.g. after a temporary error.
         """
-        self._lock.acquire()
-        self._items = item + self._items
-        self._lock.release()
+        with self._lock:
+            self._items = item + self._items
 
 
 class ReferencesBatchRequest(BatchRequest[_BatchReference, BatchReferenceReturn]):
@@ -92,15 +91,14 @@ def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[_BatchRefere
         """
         ret: List[_BatchReference] = []
         i = 0
-        self._lock.acquire()
-        while len(ret) < pop_amount and len(self._items) > 0 and i < len(self._items):
-            if self._items[i].from_uuid not in uuid_lookup and (
-                self._items[i].to_uuid is None or self._items[i].to_uuid not in uuid_lookup
-            ):
-                ret.append(self._items.pop(i))
-            else:
-                i += 1
-        self._lock.release()
+        with self._lock:
+            while len(ret) < pop_amount and len(self._items) > 0 and i < len(self._items):
+                if self._items[i].from_uuid not in uuid_lookup and (
+                    self._items[i].to_uuid is None or self._items[i].to_uuid not in uuid_lookup
+                ):
+                    ret.append(self._items.pop(i))
+                else:
+                    i += 1
         return ret
 
 
@@ -113,15 +111,14 @@ def pop_items(self, pop_amount: int) -> List[_BatchObject]:
 
         Returns `List[_BatchObject]` items from the BatchRequest.
""" - self._lock.acquire() - if pop_amount >= len(self._items): - ret = copy(self._items) - self._items.clear() - else: - ret = copy(self._items[:pop_amount]) - self._items = self._items[pop_amount:] + with self._lock: + if pop_amount >= len(self._items): + ret = copy(self._items) + self._items.clear() + else: + ret = copy(self._items[:pop_amount]) + self._items = self._items[pop_amount:] - self._lock.release() return ret @@ -149,8 +146,23 @@ class _RateLimitedBatching: requests_per_minute: int +T = TypeVar("T") + + +@dataclass +class _SendReturn(Generic[T]): + n: int + response: T + start: float + + _BatchMode: TypeAlias = Union[_DynamicBatching, _FixedSizeBatching, _RateLimitedBatching] +_BatchFuture = Union[ + concurrent.futures.Future[_SendReturn[BatchObjectReturn]], + concurrent.futures.Future[_SendReturn[BatchReferenceReturn]], +] + class _BatchBase: def __init__( @@ -188,6 +200,7 @@ def __init__( self.__batching_mode: _BatchMode = batch_mode self.__max_batch_size: int = 1000 + self.__flushing = False self.__loop = event_loop self.__objs_count = 0 self.__objs_logs_count = 0 @@ -221,7 +234,6 @@ def __init__( self.__recommended_num_refs: int = 50 self.__active_requests = 0 - self.__active_requests_lock = threading.Lock() # dynamic batching self.__time_last_scale_up: float = 0 @@ -253,64 +265,125 @@ def _shutdown(self) -> None: time.sleep(0.01) # copy the results to the public results - self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results - self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects - self.__results_for_wrapper_backup.failed_references = ( - self.__results_for_wrapper.failed_references - ) - self.__results_for_wrapper_backup.imported_shards = ( - self.__results_for_wrapper.imported_shards - ) + with self.__results_lock: + self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results + self.__results_for_wrapper_backup.failed_objects = ( + self.__results_for_wrapper.failed_objects + ) + self.__results_for_wrapper_backup.failed_references = ( + self.__results_for_wrapper.failed_references + ) + self.__results_for_wrapper_backup.imported_shards = ( + self.__results_for_wrapper.imported_shards + ) def __batch_send(self) -> None: refresh_time: float = 0.01 + futures: list[_BatchFuture] = [] + max_wait = 5 + now = time.time() while ( self.__shut_background_thread_down is not None and not self.__shut_background_thread_down.is_set() ): - if isinstance(self.__batching_mode, _RateLimitedBatching): - if ( - time.time() - self.__time_stamp_last_request - < self.__fix_rate_batching_base_time // self.__concurrent_requests - ): - time.sleep(1) - continue - refresh_time = 0 - elif isinstance(self.__batching_mode, _DynamicBatching) and self.__vectorizer_batching: - if self.__dynamic_batching_sleep_time > 0: + while ( + len(self.__batch_objects) < self.__recommended_num_objects + and len(self.__batch_references) < self.__recommended_num_refs + and time.time() - now < max_wait + and not self.__flushing + ): + time.sleep(refresh_time) + + if len(futures) < self.__concurrent_requests and len(self.__batch_objects) > 0: + if isinstance(self.__batching_mode, _RateLimitedBatching): if ( time.time() - self.__time_stamp_last_request - < self.__dynamic_batching_sleep_time + < self.__fix_rate_batching_base_time // self.__concurrent_requests ): + logger.error( + { + "message": "Rate limit reached. 
+                            }
+                        )
+                        time.sleep(1)
+                        continue
+                    refresh_time = 0
+                elif (
+                    isinstance(self.__batching_mode, _DynamicBatching)
+                    and self.__vectorizer_batching
+                ):
+                    if self.__dynamic_batching_sleep_time > 0:
+                        if (
+                            time.time() - self.__time_stamp_last_request
+                            < self.__dynamic_batching_sleep_time
+                        ):
+                            time.sleep(1)
+                            continue
 
-            if (
-                self.__active_requests < self.__concurrent_requests
-                and len(self.__batch_objects) + len(self.__batch_references) > 0
-            ):
                 self.__time_stamp_last_request = time.time()
-                self._batch_send = True
-                self.__active_requests_lock.acquire()
-                self.__active_requests += 1
-                self.__active_requests_lock.release()
                 objs = self.__batch_objects.pop_items(self.__recommended_num_objects)
-                self.__uuid_lookup_lock.acquire()
-                refs = self.__batch_references.pop_items(
-                    self.__recommended_num_refs, uuid_lookup=self.__uuid_lookup
-                )
-                self.__uuid_lookup_lock.release()
-                # do not block the thread - the results are written to a central (locked) list and we want to have multiple concurrent batch-requests
-                self.__loop.schedule(
-                    self.__send_batch,
-                    objs,
-                    refs,
-                    readd_rate_limit=isinstance(self.__batching_mode, _RateLimitedBatching),
-                )
-
-            time.sleep(refresh_time)
+                if len(objs) > 0:
+                    futures.append(
+                        self.__loop.schedule(
+                            self.__send_objs,
+                            objs,
+                        )
+                    )
+                    self.__active_requests += 1
+            elif len(futures) < self.__concurrent_requests and len(self.__batch_references) > 0:
+                with self.__uuid_lookup_lock:
+                    refs = self.__batch_references.pop_items(
+                        self.__recommended_num_refs, uuid_lookup=self.__uuid_lookup
+                    )
+                if len(refs) > 0:
+                    futures.append(self.__loop.schedule(self.__send_refs, refs))
+                    self.__active_requests += 1
+            elif len(futures) > 0:
+                # wait for at least one of the futures to be done because len(futures) == self.__concurrent_requests
+                while True:
+                    if self.__handle_futures(futures, break_early=True):
+                        break
+            elif self.__flushing:
+                # wait for all futures to be done
+                while True:
+                    self.__handle_futures(futures, break_early=False)
+                    if len(futures) == 0:
+                        self.__flushing = False
+                        break
+
+            # check if any of the futures are done before looping again
+            if len(futures) > 0:
+                self.__handle_futures(futures, break_early=False)
+            else:
+                time.sleep(refresh_time)
+
+        # if thread is shutting down, wait for all futures to return
+        start = time.time()
+        if len(futures) > 0:
+            while True:
+                self.__handle_futures(futures, break_early=False)
+                if len(futures) == 0:
+                    break
+                if time.time() - start > 10:  # wait for 10 seconds max
+                    break
+
+    def __handle_futures(self, futures: List[_BatchFuture], break_early: bool) -> bool:
+        for i, future in enumerate(futures):
+            if future.done():
+                self.__active_requests -= 1
+                ret = futures.pop(i).result()
+                if isinstance(ret.response, BatchObjectReturn):
+                    self.__handle_objs(
+                        cast(_SendReturn[BatchObjectReturn], ret),
+                        readd_rate_limit=isinstance(self.__batching_mode, _RateLimitedBatching),
+                    )
+                elif isinstance(ret.response, BatchReferenceReturn):
+                    self.__handle_refs(cast(_SendReturn[BatchReferenceReturn], ret))
+                else:
+                    raise Exception("Unknown response type")
+                if break_early:
+                    return True
+        return False
 
     def __dynamic_batch_rate_loop(self) -> None:
         refresh_time = 1
@@ -336,6 +409,12 @@ def dynamic_batch_rate_wrapper() -> None:
             try:
                 self.__dynamic_batch_rate_loop()
             except Exception as e:
+                logger.error(
+                    {
+                        "message": "There was an error in the dynamic batch rate loop!",
+                        "exception": e,
+                    }
+                )
                 self.__bg_thread_exception = e
 
         demonDynamic = threading.Thread(
@@ -349,6 +428,12 @@ def batch_send_wrapper() -> None:
             try:
                 self.__batch_send()
             except Exception as e:
+                logger.error(
+                    {
+                        "message": "There was an error in the batching loop!",
+                        "exception": e,
+                    }
+                )
                 self.__bg_thread_exception = e
 
         demonBatchSend = threading.Thread(
@@ -450,167 +535,188 @@ def __dynamic_batching(self) -> None:
             self.__recommended_num_objects = 0
             self.__concurrent_requests = 2
 
-    async def __send_batch(
-        self, objs: List[_BatchObject], refs: List[_BatchReference], readd_rate_limit: bool
-    ) -> None:
-        if (n_objs := len(objs)) > 0:
-            start = time.time()
-            try:
-                response_obj = await self.__batch_grpc.objects(
-                    objects=objs, timeout=DEFAULT_REQUEST_TIMEOUT
-                )
-            except Exception as e:
-                errors_obj = {
-                    idx: ErrorObject(message=repr(e), object_=obj) for idx, obj in enumerate(objs)
-                }
-                response_obj = BatchObjectReturn(
+    async def __send_objs(self, objs: List[_BatchObject]) -> _SendReturn[BatchObjectReturn]:
+        n = len(objs)
+        start = time.time()
+        try:
+            res = await self.__batch_grpc.objects(objects=objs, timeout=DEFAULT_REQUEST_TIMEOUT)
+            return _SendReturn(
+                n=n,
+                response=res,
+                start=start,
+            )
+        except Exception as e:
+            errors_obj = {
+                idx: ErrorObject(message=repr(e), object_=obj) for idx, obj in enumerate(objs)
+            }
+            return _SendReturn(
+                n=n,
+                response=BatchObjectReturn(
                     _all_responses=list(errors_obj.values()),
                     elapsed_seconds=time.time() - start,
                     errors=errors_obj,
                     has_errors=True,
-                )
+                ),
+                start=start,
+            )
 
-        readded_uuids = set()
-        readded_objects = []
-        highest_retry_count = 0
-        for i, err in response_obj.errors.items():
-            if (
-                (
-                    "support@cohere.com" in err.message
-                    and (
-                        "rate limit" in err.message
-                        or "500 error: internal server error" in err.message
-                    )
-                )
+    def __handle_objs(
+        self, return_: _SendReturn[BatchObjectReturn], readd_rate_limit: bool
+    ) -> None:
+        readded_uuids = set[str]()
+        readded_objects = []
+        highest_retry_count = 0
+        response_obj = return_.response
+        n_objs = return_.n
+        for i, err in response_obj.errors.items():
+            if (
+                (
+                    "support@cohere.com" in err.message
+                    and (
+                        "rate limit" in err.message
+                        or "500 error: internal server error" in err.message
+                    )
+                )
                 or (
                     "OpenAI" in err.message
                     and (
                         "Rate limit reached" in err.message
                         or "on tokens per min (TPM)" in err.message
                         or "503 error: Service Unavailable." in err.message
                         or "500 error: The server had an error while processing your request."
                         in err.message
                     )
                 )
                 or ("failed with status: 503 error" in err.message)  # huggingface
             ):
                 if err.object_.retry_count > highest_retry_count:
                     highest_retry_count = err.object_.retry_count
 
                 if err.object_.retry_count > 5:
                     continue  # too many retries, give up
                 err.object_.retry_count += 1
                 readded_objects.append(i)
 
         if len(readded_objects) > 0:
             _Warnings.batch_rate_limit_reached(
                 response_obj.errors[readded_objects[0]].message,
                 self.__fix_rate_batching_base_time * (highest_retry_count + 1),
             )
 
             readd_objects = [
                 err.object_ for i, err in response_obj.errors.items() if i in readded_objects
             ]
             readded_uuids = {obj.uuid for obj in readd_objects}
 
             self.__batch_objects.prepend(readd_objects)
 
             new_errors = {
                 i: err for i, err in response_obj.errors.items() if i not in readded_objects
             }
             response_obj = BatchObjectReturn(
-                uuids={
-                    i: uid for i, uid in response_obj.uuids.items() if i not in readded_objects
-                },
+                uuids={i: uid for i, uid in response_obj.uuids.items() if i not in readded_objects},
                 errors=new_errors,
                 has_errors=len(new_errors) > 0,
                 _all_responses=[
                     err
                     for i, err in enumerate(response_obj.all_responses)
                     if i not in readded_objects
                 ],
                 elapsed_seconds=response_obj.elapsed_seconds,
             )
             if readd_rate_limit:
                 # for rate limited batching the timing is handled by the outer loop => no sleep here
                 self.__time_stamp_last_request = (
                     time.time() + self.__fix_rate_batching_base_time * (highest_retry_count + 1)
                 )  # skip a full minute to recover from the rate limit
                 self.__fix_rate_batching_base_time += (
                     1  # increase the base time as the current one is too low
                 )
             else:
                 # sleep a bit to recover from the rate limit in other cases
                 time.sleep(2**highest_retry_count)
 
-            self.__uuid_lookup_lock.acquire()
+        with self.__uuid_lookup_lock:
+            self.__uuid_lookup.difference_update(
+                u for uuid in response_obj.uuids.values() if (u := str(uuid)) not in readded_uuids
+            )
             self.__uuid_lookup.difference_update(
-                obj.uuid for obj in objs if obj.uuid not in readded_uuids
+                u
+                for err in response_obj.errors.values()
+                if err is not None and (u := str(err.original_uuid)) not in readded_uuids
             )
-            self.__uuid_lookup_lock.release()
-            if (n_obj_errs := len(response_obj.errors)) > 0 and self.__objs_logs_count < 30:
-                logger.error(
-                    {
-                        "message": f"Failed to send {n_obj_errs} objects in a batch of {n_objs}. Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects.",
-                    }
-                )
-                self.__objs_logs_count += 1
-            if self.__objs_logs_count > 30:
-                logger.error(
-                    {
-                        "message": "There have been more than 30 failed object batches. Further errors will not be logged.",
-                    }
-                )
-            self.__results_lock.acquire()
+        if (n_obj_errs := len(response_obj.errors)) > 0 and self.__objs_logs_count < 30:
+            logger.error(
+                {
+                    "message": f"Failed to send {n_obj_errs} objects in a batch of {n_objs}. Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects.",
+                }
+            )
+            self.__objs_logs_count += 1
+        if self.__objs_logs_count > 30:
+            logger.error(
+                {
+                    "message": "There have been more than 30 failed object batches. Further errors will not be logged.",
+                }
+            )
+        with self.__results_lock:
             self.__results_for_wrapper.results.objs += response_obj
             self.__results_for_wrapper.failed_objects.extend(response_obj.errors.values())
-            self.__results_lock.release()
-            self.__took_queue.append(time.time() - start)
+        self.__took_queue.append(time.time() - return_.start)
 
-        if (n_refs := len(refs)) > 0:
-            start = time.time()
-            try:
-                response_ref = await self.__batch_rest.references(references=refs)
-            except Exception as e:
-                errors_ref = {
-                    idx: ErrorReference(message=repr(e), reference=ref)
-                    for idx, ref in enumerate(refs)
-                }
-                response_ref = BatchReferenceReturn(
+    async def __send_refs(self, refs: List[_BatchReference]) -> _SendReturn[BatchReferenceReturn]:
+        n = len(refs)
+        start = time.time()
+        try:
+            return _SendReturn(
+                n=n,
+                response=await self.__batch_rest.references(references=refs),
+                start=start,
+            )
+        except Exception as e:
+            errors_ref = {
+                idx: ErrorReference(message=repr(e), reference=ref) for idx, ref in enumerate(refs)
+            }
+            return _SendReturn(
+                n=n,
+                response=BatchReferenceReturn(
                     elapsed_seconds=time.time() - start,
                     errors=errors_ref,
                     has_errors=True,
-                )
-            if (n_ref_errs := len(response_ref.errors)) > 0 and self.__refs_logs_count < 30:
-                logger.error(
-                    {
-                        "message": f"Failed to send {n_ref_errs} references in a batch of {n_refs}. Please inspect client.batch.failed_references or collection.batch.failed_references for the failed references.",
-                        "errors": response_ref.errors,
-                    }
-                )
-                self.__refs_logs_count += 1
-            if self.__refs_logs_count > 30:
-                logger.error(
-                    {
-                        "message": "There have been more than 30 failed reference batches. Further errors will not be logged.",
-                    }
-                )
-            self.__results_lock.acquire()
+                ),
+                start=start,
+            )
+
+    def __handle_refs(self, return_: _SendReturn[BatchReferenceReturn]) -> None:
+        response_ref = return_.response
+        n_refs = return_.n
+        if (n_ref_errs := len(response_ref.errors)) > 0 and self.__refs_logs_count < 30:
+            logger.error(
+                {
+                    "message": f"Failed to send {n_ref_errs} references in a batch of {n_refs}. Please inspect client.batch.failed_references or collection.batch.failed_references for the failed references.",
+                    "errors": response_ref.errors,
+                }
+            )
+            self.__refs_logs_count += 1
+        if self.__refs_logs_count > 30:
+            logger.error(
+                {
+                    "message": "There have been more than 30 failed reference batches. Further errors will not be logged.",
+                }
+            )
+        with self.__results_lock:
             self.__results_for_wrapper.results.refs += response_ref
             self.__results_for_wrapper.failed_references.extend(response_ref.errors.values())
-            self.__results_lock.release()
-
-        self.__active_requests_lock.acquire()
-        self.__active_requests -= 1
-        self.__active_requests_lock.release()
 
     def flush(self) -> None:
         """Flush the batch queue and wait for all requests to be finished."""
         # bg thread is sending objs+refs automatically, so simply wait for everything to be done
+        self.__flushing = True
         while (
             self.__active_requests > 0
             or len(self.__batch_objects) > 0
             or len(self.__batch_references) > 0
+            or self.__flushing
         ):
             time.sleep(0.01)
             self.__check_bg_thread_alive()
@@ -641,9 +747,8 @@ def _add_object(
             )
         except ValidationError as e:
             raise WeaviateBatchValidationError(repr(e))
-        self.__uuid_lookup_lock.acquire()
-        self.__uuid_lookup.add(str(batch_object.uuid))
-        self.__uuid_lookup_lock.release()
+        with self.__uuid_lookup_lock:
+            self.__uuid_lookup.add(str(batch_object.uuid))
         self.__batch_objects.add(batch_object._to_internal())
        # block if queue gets too long or weaviate is overloaded - reading files is faster than sending them so we do
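The core of the new design is easier to see outside the diff: the sender thread keeps a list of in-flight `concurrent.futures.Future` objects and repeatedly drains the completed ones, with `result()` re-raising any worker exception and `break_early` letting the caller stop as soon as one slot frees up. Below is a minimal, self-contained sketch of that draining pattern; it is not part of the patch, and `handle_futures`, `pool`, and `in_flight` are illustrative names only.

```python
import concurrent.futures
import time
from typing import List


def handle_futures(futures: List[concurrent.futures.Future], break_early: bool) -> bool:
    """Pop and handle any completed futures; return True once one was handled."""
    for i, future in enumerate(futures):
        if future.done():
            result = futures.pop(i).result()  # .result() re-raises exceptions from the worker
            print("handled:", result)
            if break_early:
                return True
    # popping while iterating can skip the element that shifts into slot i;
    # it is picked up on the next call, so callers poll in a loop
    return False


if __name__ == "__main__":
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        in_flight = [pool.submit(pow, 2, n) for n in range(5)]
        # poll until the in-flight list is drained, mirroring the batching loop
        while in_flight:
            handle_futures(in_flight, break_early=False)
            time.sleep(0.01)
```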