diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py index 64ef3c7c7711..6cb39659a83a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py @@ -5,13 +5,14 @@ import time import queue import logging -from threading import RLock, Condition, Semaphore +from threading import RLock from concurrent.futures import ThreadPoolExecutor from typing import Optional, Callable, TYPE_CHECKING from .._producer import EventHubProducer from .._common import EventDataBatch from ..exceptions import OperationTimeoutError + if TYPE_CHECKING: from .._producer_client import SendEventTypes @@ -21,28 +22,23 @@ class BufferedProducer: # pylint: disable=too-many-instance-attributes def __init__( - self, - producer: EventHubProducer, - partition_id: str, - on_success: Callable[["SendEventTypes", Optional[str]], None], - on_error: Callable[["SendEventTypes", Optional[str], Exception], None], - max_message_size_on_link: int, - executor: ThreadPoolExecutor, - *, - max_wait_time: float = 1, - max_concurrent_sends: int = 1, - max_buffer_length: int = 10 + self, + producer: EventHubProducer, + partition_id: str, + on_success: Callable[["SendEventTypes", Optional[str]], None], + on_error: Callable[["SendEventTypes", Optional[str], Exception], None], + max_message_size_on_link: int, + executor: ThreadPoolExecutor, + *, + max_wait_time: float = 1, + max_buffer_length: int ): self._buffered_queue: queue.Queue = queue.Queue() + self._max_buffer_len = max_buffer_length self._cur_buffered_len = 0 self._executor: ThreadPoolExecutor = executor self._producer: EventHubProducer = producer self._lock = RLock() - self._not_empty = Condition(self._lock) - self._not_full = Condition(self._lock) - self._max_buffer_len = max_buffer_length - self._max_concurrent_sends = max_concurrent_sends - self._max_concurrent_sends_semaphore = Semaphore(self._max_concurrent_sends) self._max_wait_time = max_wait_time self._on_success = self.failsafe_callback(on_success) self._on_error = self.failsafe_callback(on_error) @@ -62,72 +58,61 @@ def start(self): self._check_max_wait_time_future = self._executor.submit(self.check_max_wait_time_worker) def stop(self, flush=True, timeout_time=None, raise_error=False): + self._running = False if flush: - self.flush(timeout_time=timeout_time, raise_error=raise_error) + with self._lock: + self.flush(timeout_time=timeout_time, raise_error=raise_error) else: if self._cur_buffered_len: _LOGGER.warning( "Shutting down Partition %r. There are still %r events in the buffer which will be lost", self.partition_id, - self._cur_buffered_len + self._cur_buffered_len, ) if self._check_max_wait_time_future: remain_timeout = timeout_time - time.time() if timeout_time else None try: - with self._not_empty: - # in the stop procedure, calling notify to give check_max_wait_time_future a chance to stop - # as it is waiting for Condition self._not_empty - self._not_empty.notify() self._check_max_wait_time_future.result(remain_timeout) except Exception as exc: # pylint: disable=broad-except - _LOGGER.warning( - "Partition %r stopped with error %r", - self.partition_id, - exc - ) + _LOGGER.warning("Partition %r stopped with error %r", self.partition_id, exc) self._producer.close() def put_events(self, events, timeout_time=None): # Put single event or EventDataBatch into the queue. 
# This method would raise OperationTimeout if the queue does not have enough space for the input and # flush cannot finish in timeout. - with self._not_full: - try: - new_events_len = len(events) - except TypeError: - new_events_len = 1 - - if self._max_buffer_len - self._cur_buffered_len < new_events_len: - _LOGGER.info( - "The buffer for partition %r is full. Attempting to flush before adding %r events.", - self.partition_id, - new_events_len - ) - # flush the buffer + try: + new_events_len = len(events) + except TypeError: + new_events_len = 1 + if self._max_buffer_len - self._cur_buffered_len < new_events_len: + _LOGGER.info( + "The buffer for partition %r is full. Attempting to flush before adding %r events.", + self.partition_id, + new_events_len, + ) + # flush the buffer + with self._lock: self.flush(timeout_time=timeout_time) - - if timeout_time and time.time() > timeout_time: - raise OperationTimeoutError("Failed to enqueue events into buffer due to timeout.") - - try: - # add single event into current batch - self._cur_batch.add(events) - except AttributeError: # if the input events is a EventDataBatch, put the whole into the buffer - # if there are events in cur_batch, enqueue cur_batch to the buffer - if self._cur_batch: - self._buffered_queue.put(self._cur_batch) - self._buffered_queue.put(events) - # create a new batch for incoming events - self._cur_batch = EventDataBatch(self._max_message_size_on_link) - except ValueError: - # add single event exceeds the cur batch size, create new batch + if timeout_time and time.time() > timeout_time: + raise OperationTimeoutError("Failed to enqueue events into buffer due to timeout.") + try: + # add single event into current batch + self._cur_batch.add(events) + except AttributeError: # if the input events is a EventDataBatch, put the whole into the buffer + # if there are events in cur_batch, enqueue cur_batch to the buffer + if self._cur_batch: self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) - self._cur_batch.add(events) - self._cur_buffered_len += new_events_len - # notify the max_wait_time worker - self._not_empty.notify() + self._buffered_queue.put(events) + # create a new batch for incoming events + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + except ValueError: + # add single event exceeds the cur batch size, create new batch + self._buffered_queue.put(self._cur_batch) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch.add(events) + self._cur_buffered_len += new_events_len def failsafe_callback(self, callback): def wrapper_callback(*args, **kwargs): @@ -135,10 +120,7 @@ def wrapper_callback(*args, **kwargs): callback(*args, **kwargs) except Exception as exc: # pylint: disable=broad-except _LOGGER.warning( - "On partition %r, callback %r encountered exception %r", - callback.__name__, - exc, - self.partition_id + "On partition %r, callback %r encountered exception %r", callback.__name__, exc, self.partition_id ) return wrapper_callback @@ -147,50 +129,35 @@ def flush(self, timeout_time=None, raise_error=True): # pylint: disable=protected-access # try flushing all the buffered batch within given time _LOGGER.info("Partition: %r started flushing.", self.partition_id) - with self._not_empty: - if self._cur_batch: # if there is batch, enqueue it to the buffer first - self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) - while self._cur_buffered_len: - 
remaining_time = timeout_time - time.time() if timeout_time else None - # If flush could get the semaphore, perform sending - if ((remaining_time and remaining_time > 0) or remaining_time is None) and \ - self._max_concurrent_sends_semaphore.acquire(timeout=remaining_time): - batch = self._buffered_queue.get() - self._buffered_queue.task_done() - try: - _LOGGER.info("Partition %r is sending.", self.partition_id) - self._producer.send( - batch, - timeout=timeout_time - time.time() if timeout_time else None - ) - _LOGGER.info( - "Partition %r sending %r events succeeded.", - self.partition_id, - len(batch) - ) - self._on_success(batch._internal_events, self.partition_id) - except Exception as exc: # pylint: disable=broad-except - _LOGGER.info( - "Partition %r sending %r events failed due to exception: %r ", - self.partition_id, - len(batch), - exc - ) - self._on_error(batch._internal_events, self.partition_id, exc) - finally: - self._cur_buffered_len -= len(batch) - self._max_concurrent_sends_semaphore.release() - self._not_full.notify() - # If flush could not get the semaphore, we log and raise error if wanted - else: + if self._cur_batch: # if there is batch, enqueue it to the buffer first + self._buffered_queue.put(self._cur_batch) + while self._cur_buffered_len: + remaining_time = timeout_time - time.time() if timeout_time else None + if (remaining_time and remaining_time > 0) or remaining_time is None: + batch = self._buffered_queue.get() + self._buffered_queue.task_done() + try: + _LOGGER.info("Partition %r is sending.", self.partition_id) + self._producer.send(batch, timeout=timeout_time - time.time() if timeout_time else None) + _LOGGER.info("Partition %r sending %r events succeeded.", self.partition_id, len(batch)) + self._on_success(batch._internal_events, self.partition_id) + except Exception as exc: # pylint: disable=broad-except _LOGGER.info( - "Partition %r fails to flush due to timeout.", - self.partition_id + "Partition %r sending %r events failed due to exception: %r ", + self.partition_id, + len(batch), + exc, + ) + self._on_error(batch._internal_events, self.partition_id, exc) + finally: + self._cur_buffered_len -= len(batch) + else: + _LOGGER.info("Partition %r fails to flush due to timeout.", self.partition_id) + if raise_error: + raise OperationTimeoutError( + "Failed to flush {!r} within {}".format(self.partition_id, timeout_time) ) - if raise_error: - raise OperationTimeoutError("Failed to flush {!r}".format(self.partition_id)) - break + break # after finishing flushing, reset cur batch and put it into the buffer self._last_send_time = time.time() self._cur_batch = EventDataBatch(self._max_message_size_on_link) @@ -198,15 +165,17 @@ def flush(self, timeout_time=None, raise_error=True): def check_max_wait_time_worker(self): while self._running: - with self._not_empty: - if not self._cur_buffered_len: - _LOGGER.info("Partition %r worker is awaiting data.", self.partition_id) - self._not_empty.wait() + if self._cur_buffered_len > 0: now_time = time.time() _LOGGER.info("Partition %r worker is checking max_wait_time.", self.partition_id) - if now_time - self._last_send_time > self._max_wait_time and self._running: + # flush the partition if the producer is running beyond the waiting time + # or the buffer is at max capacity + if (now_time - self._last_send_time > self._max_wait_time) or ( + self._cur_buffered_len >= self._max_buffer_len + ): # in the worker, not raising error for flush, users can not handle this - self.flush(raise_error=False) + with self._lock: + 
self.flush(raise_error=False) time.sleep(min(self._max_wait_time, 5)) @property diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py index 7c8777ce16fb..3f58dc0f70f3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py @@ -5,7 +5,7 @@ import logging from threading import Lock from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional, List, Callable, TYPE_CHECKING +from typing import Dict, Optional, List, Callable, Union, TYPE_CHECKING from ._partition_resolver import PartitionResolver from ._buffered_producer import BufferedProducer @@ -31,9 +31,7 @@ def __init__( *, max_buffer_length: int = 1500, max_wait_time: float = 1, - max_concurrent_sends: int = 1, - executor: Optional[ThreadPoolExecutor] = None, - max_worker: Optional[int] = None + executor: Optional[Union[ThreadPoolExecutor, int]] = None ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions @@ -46,9 +44,15 @@ def __init__( self._partition_resolver = PartitionResolver(self._partition_ids) self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length - self._max_concurrent_sends = max_concurrent_sends - self._existing_executor = bool(executor) - self._executor = executor or ThreadPoolExecutor(max_worker) + self._existing_executor = False + + if not executor: + self._executor = ThreadPoolExecutor() + elif isinstance(executor, ThreadPoolExecutor): + self._existing_executor = True + self._executor = executor + elif isinstance(executor, int): + self._executor = ThreadPoolExecutor(executor) def _get_partition_id(self, partition_id, partition_key): if partition_id: @@ -77,7 +81,6 @@ def enqueue_events(self, events, *, partition_id=None, partition_key=None, timeo self._max_message_size_on_link, executor=self._executor, max_wait_time=self._max_wait_time, - max_concurrent_sends=self._max_concurrent_sends, max_buffer_length=self._max_buffer_length ) buffered_producer.start() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 3ccdcbd76a43..aa9ada4764af 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from concurrent.futures import ThreadPoolExecutor import logging import threading import time @@ -37,7 +38,7 @@ _LOGGER = logging.getLogger(__name__) -class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api-version-keyword +class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api-version-keyword # pylint: disable=too-many-instance-attributes """The EventHubProducerClient class defines a high level interface for sending events to the Azure Event Hubs service. 
@@ -53,6 +54,10 @@ class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api or ~azure.core.credentials.AzureNamedKeyCredential :keyword bool buffered_mode: If True, the producer client will collect events in a buffer, efficiently batch, then publish. Default is False. + :keyword Union[ThreadPoolExecutor, int] buffer_concurrency: The ThreadPoolExecutor to be used for publishing events + or the number of workers for the ThreadPoolExecutor. + Default is none and a ThreadPoolExecutor with the default number of workers will be created + per https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor :keyword on_success: The callback to be called once a batch has been successfully published. The callback takes two parameters: - `events`: The list of events that have been successfully published @@ -147,6 +152,7 @@ def __init__( credential: "CredentialTypes", *, buffered_mode: Literal[True], + buffer_concurrency: Union[ThreadPoolExecutor, int] = None, on_error: Callable[[SendEventTypes, Optional[str], Exception], None], on_success: Callable[[SendEventTypes, Optional[str]], None], max_buffer_length: int = 1500, @@ -175,9 +181,7 @@ def __init__( network_tracing=kwargs.get("logging_enable"), **kwargs ) - self._producers = { - ALL_PARTITIONS: self._create_producer() - } # type: Dict[str, Optional[EventHubProducer]] + self._producers = {ALL_PARTITIONS: self._create_producer()} # type: Dict[str, Optional[EventHubProducer]] self._max_message_size_on_link = 0 self._partition_ids = None # Optional[List[str]] self._lock = threading.Lock() @@ -187,10 +191,7 @@ def __init__( self._buffered_producer_dispatcher = None self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length - # the following two parameters are not part of the public api yet - # which could be exposed in the future if needed - self._executor = kwargs.get("executor") - self._max_worker = kwargs.get("max_worker") + self._executor = kwargs.get("buffer_concurrency") if self._buffered_mode: setattr(self, "send_batch", self._buffered_send_batch) @@ -211,6 +212,8 @@ def __init__( self._max_buffer_length = 1500 if self._max_buffer_length <= 0: raise ValueError("'max_buffer_length' must be an integer greater than 0 in buffered mode") + if isinstance(self._executor, int) and self._executor <= 0: + raise ValueError("'buffer_concurrency' must be an integer greater than 0 in buffered mode") def __enter__(self): return self @@ -233,9 +236,8 @@ def _buffered_send(self, events, **kwargs): self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, - executor=self._executor, - max_worker=self._max_worker - ) + executor=self._executor + ) self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) def _batch_preparer(self, event_data_batch, **kwargs): @@ -250,12 +252,8 @@ def _batch_preparer(self, event_data_batch, **kwargs): ) to_send_batch = event_data_batch else: - to_send_batch = self.create_batch( - partition_id=partition_id, partition_key=partition_key - ) - to_send_batch._load_events( # pylint:disable=protected-access - event_data_batch - ) + to_send_batch = self.create_batch(partition_id=partition_id, partition_key=partition_key) + to_send_batch._load_events(event_data_batch) # pylint:disable=protected-access return to_send_batch, to_send_batch._partition_id, partition_key # pylint:disable=protected-access @@ -268,12 +266,7 @@ def _buffered_send_batch(self, event_data_batch, **kwargs): timeout = kwargs.get("timeout") 
timeout_time = time.time() + timeout if timeout else None - self._buffered_send( - event_data_batch, - partition_id=pid, - partition_key=pkey, - timeout_time=timeout_time - ) + self._buffered_send(event_data_batch, partition_id=pid, partition_key=pkey, timeout_time=timeout_time) def _buffered_send_event(self, event, **kwargs): partition_key = kwargs.get("partition_key") @@ -281,10 +274,7 @@ def _buffered_send_event(self, event, **kwargs): timeout = kwargs.get("timeout") timeout_time = time.time() + timeout if timeout else None self._buffered_send( - event, - partition_id=kwargs.get("partition_id"), - partition_key=partition_key, - timeout_time=timeout_time + event, partition_id=kwargs.get("partition_id"), partition_key=partition_key, timeout_time=timeout_time ) def _get_partitions(self): @@ -299,13 +289,9 @@ def _get_max_message_size(self): # pylint: disable=protected-access,line-too-long with self._lock: if not self._max_message_size_on_link: - cast( - EventHubProducer, self._producers[ALL_PARTITIONS] - )._open_with_retry() + cast(EventHubProducer, self._producers[ALL_PARTITIONS])._open_with_retry() self._max_message_size_on_link = ( - self._producers[ # type: ignore - ALL_PARTITIONS - ]._handler.message_handler._link.peer_max_message_size + self._producers[ALL_PARTITIONS]._handler.message_handler._link.peer_max_message_size # type: ignore or constants.MAX_MESSAGE_LENGTH_BYTES ) @@ -313,33 +299,21 @@ def _start_producer(self, partition_id, send_timeout): # type: (str, Optional[Union[int, float]]) -> None with self._lock: self._get_partitions() - if ( - partition_id not in cast(List[str], self._partition_ids) - and partition_id != ALL_PARTITIONS - ): + if partition_id not in cast(List[str], self._partition_ids) and partition_id != ALL_PARTITIONS: raise ConnectError( - "Invalid partition {} for the event hub {}".format( - partition_id, self.eventhub_name - ) + "Invalid partition {} for the event hub {}".format(partition_id, self.eventhub_name) ) - if ( - not self._producers[partition_id] - or cast(EventHubProducer, self._producers[partition_id]).closed - ): + if not self._producers[partition_id] or cast(EventHubProducer, self._producers[partition_id]).closed: self._producers[partition_id] = self._create_producer( - partition_id=( - None if partition_id == ALL_PARTITIONS else partition_id - ), + partition_id=(None if partition_id == ALL_PARTITIONS else partition_id), send_timeout=send_timeout, ) def _create_producer(self, partition_id=None, send_timeout=None): # type: (Optional[str], Optional[Union[int, float]]) -> EventHubProducer target = "amqps://{}{}".format(self._address.hostname, self._address.path) - send_timeout = ( - self._config.send_timeout if send_timeout is None else send_timeout - ) + send_timeout = self._config.send_timeout if send_timeout is None else send_timeout handler = EventHubProducer( self, @@ -746,9 +720,7 @@ def get_partition_properties(self, partition_id): :rtype: Dict[str, Any] :raises: :class:`EventHubError` """ - return super(EventHubProducerClient, self)._get_partition_properties( - partition_id - ) + return super(EventHubProducerClient, self)._get_partition_properties(partition_id) def flush(self, **kwargs: Any) -> None: """ @@ -765,12 +737,7 @@ def flush(self, **kwargs: Any) -> None: timeout_time = time.time() + timeout if timeout else None self._buffered_producer_dispatcher.flush(timeout_time=timeout_time) - def close( - self, - *, - flush: bool = True, - **kwargs: Any - ) -> None: + def close(self, *, flush: bool = True, **kwargs: Any) -> None: """Close the 
Producer client underlying AMQP connection and links. :keyword bool flush: Buffered mode only. If set to True, events in the buffer will be sent @@ -816,10 +783,9 @@ def get_buffered_event_count(self, partition_id: str) -> Optional[int]: return None try: - return cast( - BufferedProducerDispatcher, - self._buffered_producer_dispatcher - ).get_buffered_event_count(partition_id) + return cast(BufferedProducerDispatcher, self._buffered_producer_dispatcher).get_buffered_event_count( + partition_id + ) except AttributeError: return 0 @@ -835,9 +801,6 @@ def total_buffered_event_count(self) -> Optional[int]: return None try: - return cast( - BufferedProducerDispatcher, - self._buffered_producer_dispatcher - ).total_buffered_event_count + return cast(BufferedProducerDispatcher, self._buffered_producer_dispatcher).total_buffered_event_count except AttributeError: return 0 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py index b170fbdad140..93efa215bcfb 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py @@ -6,10 +6,9 @@ import logging import queue import time -from asyncio import Lock, Condition, Semaphore +from asyncio import Lock from typing import Optional, Callable, Awaitable, TYPE_CHECKING -from .._async_utils import semaphore_acquire_with_timeout from .._producer_async import EventHubProducer from ..._common import EventDataBatch from ...exceptions import OperationTimeoutError @@ -23,26 +22,21 @@ class BufferedProducer: # pylint: disable=too-many-instance-attributes def __init__( - self, - producer: EventHubProducer, - partition_id: str, - on_success: Callable[["SendEventTypes", Optional[str]], Awaitable[None]], - on_error: Callable[["SendEventTypes", Optional[str], Exception], Awaitable[None]], - max_message_size_on_link: int, - *, - max_wait_time: float = 1, - max_concurrent_sends: int = 1, - max_buffer_length: int = 10 + self, + producer: EventHubProducer, + partition_id: str, + on_success: Callable[["SendEventTypes", Optional[str]], Awaitable[None]], + on_error: Callable[["SendEventTypes", Optional[str], Exception], Awaitable[None]], + max_message_size_on_link: int, + *, + max_wait_time: float = 1, + max_buffer_length: int ): self._buffered_queue: queue.Queue = queue.Queue() + self._max_buffer_len = max_buffer_length self._cur_buffered_len = 0 self._producer: EventHubProducer = producer self._lock = Lock() - self._not_empty = Condition(self._lock) - self._not_full = Condition(self._lock) - self._max_buffer_len = max_buffer_length - self._max_concurrent_sends = max_concurrent_sends - self._max_concurrent_sends_semaphore = Semaphore(self._max_concurrent_sends) self._max_wait_time = max_wait_time self._on_success = self.failsafe_callback(on_success) self._on_error = self.failsafe_callback(on_error) @@ -64,28 +58,20 @@ async def start(self): async def stop(self, flush=True, timeout_time=None, raise_error=False): self._running = False if flush: - await self.flush(timeout_time=timeout_time, raise_error=raise_error) + async with self._lock: + await self.flush(timeout_time=timeout_time, raise_error=raise_error) else: if self._cur_buffered_len: _LOGGER.warning( - "Shutting down Partition %r." - " There are still %r events in the buffer which will be lost", + "Shutting down Partition %r." 
" There are still %r events in the buffer which will be lost", self.partition_id, - self._cur_buffered_len + self._cur_buffered_len, ) if self._check_max_wait_time_future: try: - async with self._not_empty: - # in the stop procedure, calling notify to give check_max_wait_time_future a chance to stop - # as it is waiting for Condition self._not_empty - self._not_empty.notify() await self._check_max_wait_time_future except Exception as exc: # pylint: disable=broad-except - _LOGGER.warning( - "Partition %r stopped with error %r", - self.partition_id, - exc - ) + _LOGGER.warning("Partition %r stopped with error %r", self.partition_id, exc) await self._producer.close() async def put_events(self, events, timeout_time=None): @@ -101,33 +87,30 @@ async def put_events(self, events, timeout_time=None): _LOGGER.info( "The buffer for partition %r is full. Attempting to flush before adding %r events.", self.partition_id, - new_events_len + new_events_len, ) # flush the buffer - await self.flush(timeout_time=timeout_time) - - async with self._not_full: - if timeout_time and time.time() > timeout_time: - raise OperationTimeoutError("Failed to enqueue events into buffer due to timeout.") + async with self._lock: + await self.flush(timeout_time=timeout_time) - try: - # add single event into current batch - self._cur_batch.add(events) - except AttributeError: # if the input events is a EventDataBatch, put the whole into the buffer - # if there are events in cur_batch, enqueue cur_batch to the buffer - if self._cur_batch: - self._buffered_queue.put(self._cur_batch) - self._buffered_queue.put(events) - # create a new batch for incoming events - self._cur_batch = EventDataBatch(self._max_message_size_on_link) - except ValueError: - # add single event exceeds the cur batch size, create new batch + if timeout_time and time.time() > timeout_time: + raise OperationTimeoutError("Failed to enqueue events into buffer due to timeout.") + try: + # add single event into current batch + self._cur_batch.add(events) + except AttributeError: # if the input events is a EventDataBatch, put the whole into the buffer + # if there are events in cur_batch, enqueue cur_batch to the buffer + if self._cur_batch: self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) - self._cur_batch.add(events) - self._cur_buffered_len += new_events_len - # notify the max_wait_time worker - self._not_empty.notify() + self._buffered_queue.put(events) + # create a new batch for incoming events + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + except ValueError: + # add single event exceeds the cur batch size, create new batch + self._buffered_queue.put(self._cur_batch) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch.add(events) + self._cur_buffered_len += new_events_len def failsafe_callback(self, callback): async def wrapper_callback(*args, **kwargs): @@ -135,10 +118,7 @@ async def wrapper_callback(*args, **kwargs): await callback(*args, **kwargs) except Exception as exc: # pylint: disable=broad-except _LOGGER.warning( - "On partition %r, callback %r encountered exception %r", - callback.__name__, - exc, - self.partition_id + "On partition %r, callback %r encountered exception %r", callback.__name__, exc, self.partition_id ) return wrapper_callback @@ -147,53 +127,34 @@ async def flush(self, timeout_time=None, raise_error=True): # pylint: disable=protected-access # try flushing all the buffered batch within given time 
_LOGGER.info("Partition: %r started flushing.", self.partition_id) - async with self._not_empty: - if self._cur_batch: # if there is batch, enqueue it to the buffer first - self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) - while self._cur_buffered_len: - remaining_time = timeout_time - time.time() if timeout_time else None - # If flush could get the semaphore, perform sending - if ((remaining_time and remaining_time > 0) or remaining_time is None) and \ - await semaphore_acquire_with_timeout( - self._max_concurrent_sends_semaphore, - timeout=remaining_time - ): - batch = self._buffered_queue.get() - self._buffered_queue.task_done() - try: - _LOGGER.info("Partition %r is sending.", self.partition_id) - await self._producer.send( - batch, - timeout=timeout_time - time.time() if timeout_time else None - ) - _LOGGER.info( - "Partition %r sending %r events succeeded.", - self.partition_id, - len(batch) - ) - await self._on_success(batch._internal_events, self.partition_id) - except Exception as exc: # pylint: disable=broad-except - _LOGGER.info( - "Partition %r sending %r events failed due to exception: %r", - self.partition_id, - len(batch), - exc - ) - await self._on_error(batch._internal_events, self.partition_id, exc) - finally: - self._cur_buffered_len -= len(batch) - self._max_concurrent_sends_semaphore.release() - self._not_full.notify() - # If flush could not get the semaphore, we log and raise error if wanted - else: + if self._cur_batch: # if there is batch, enqueue it to the buffer first + self._buffered_queue.put(self._cur_batch) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + while self._cur_buffered_len: + remaining_time = timeout_time - time.time() if timeout_time else None + if (remaining_time and remaining_time > 0) or remaining_time is None: + batch = self._buffered_queue.get() + self._buffered_queue.task_done() + try: + _LOGGER.info("Partition %r is sending.", self.partition_id) + await self._producer.send(batch, timeout=timeout_time - time.time() if timeout_time else None) + _LOGGER.info("Partition %r sending %r events succeeded.", self.partition_id, len(batch)) + await self._on_success(batch._internal_events, self.partition_id) + except Exception as exc: # pylint: disable=broad-except _LOGGER.info( - "Partition %r fails to flush due to timeout.", - self.partition_id + "Partition %r sending %r events failed due to exception: %r", self.partition_id, len(batch), exc + ) + await self._on_error(batch._internal_events, self.partition_id, exc) + finally: + self._cur_buffered_len -= len(batch) + # If flush could not get the semaphore, we log and raise error if wanted + else: + _LOGGER.info("Partition %r fails to flush due to timeout.", self.partition_id) + if raise_error: + raise OperationTimeoutError( + "Failed to flush {!r} within {}".format(self.partition_id, timeout_time) ) - if raise_error: - raise OperationTimeoutError("Failed to flush {!r}".format(self.partition_id)) - break + break # after finishing flushing, reset cur batch and put it into the buffer self._last_send_time = time.time() self._cur_batch = EventDataBatch(self._max_message_size_on_link) @@ -201,15 +162,16 @@ async def flush(self, timeout_time=None, raise_error=True): async def check_max_wait_time_worker(self): while self._running: - async with self._not_empty: - if not self._cur_buffered_len: - _LOGGER.info("Partition %r worker is awaiting data.", self.partition_id) - await self._not_empty.wait() - now_time = time.time() - 
_LOGGER.info("Partition %r worker is checking max_wait_time.", self.partition_id) - if now_time - self._last_send_time > self._max_wait_time and self._running: - # in the worker, not raising error for flush, users can not handle this - await self.flush(raise_error=False) + if self._max_buffer_len > 0: + now_time = time.time() + _LOGGER.info("Partition %r worker is checking max_wait_time.", self.partition_id) + # flush the partition if its beyond the waiting time or the buffer is at max capacity + if (now_time - self._last_send_time > self._max_wait_time and self._running) or ( + self._cur_buffered_len >= self._max_buffer_len and self._running + ): + # in the worker, not raising error for flush, users can not handle this + async with self._lock: + await self.flush(raise_error=False) await asyncio.sleep(min(self._max_wait_time, 5)) @property diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py index 619c28f5f2a9..6d67d2fd8ab0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py @@ -31,7 +31,6 @@ def __init__( *, max_buffer_length: int = 1500, max_wait_time: float = 1, - max_concurrent_sends: int = 1 ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions @@ -44,7 +43,6 @@ def __init__( self._partition_resolver = PartitionResolver(self._partition_ids) self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length - self._max_concurrent_sends = max_concurrent_sends async def _get_partition_id(self, partition_id, partition_key): if partition_id: @@ -72,7 +70,6 @@ async def enqueue_events(self, events, *, partition_id=None, partition_key=None, self._on_error, self._max_message_size_on_link, max_wait_time=self._max_wait_time, - max_concurrent_sends=self._max_concurrent_sends, max_buffer_length=self._max_buffer_length ) await buffered_producer.start()