From 9a34e81c12f6e07c0fa614372df918b50d9883b4 Mon Sep 17 00:00:00 2001 From: swathipil <76007337+swathipil@users.noreply.github.com> Date: Thu, 2 Jun 2022 08:58:54 -0700 Subject: [PATCH] [EventHubs] merge Buffered Producer into main (#24653) * update buffered producer changelog and version (#24210) * [EventHubs] Buffered Producer (#24362) * clean up, remove conditions, semaphores * minor fix * remove semaphores, conditions * minor fixes * minor changs on queue length * expose buffer_concurrency * remove max_concurrent_sends * make buffer size reqd * remove comment * add locks around flush * use the right counter to track q size * use the correct count for the q * locks and right q size var for async * clean imports * lock for bg worker * formatting fixes for pylin * final review * fix pylint issues * lint + version * remove semaphore tests * skip tests that flush then close * fix for lock issue * unskip tests * more async updates Co-authored-by: Kashif Khan <361477+kashifkhan@users.noreply.github.com> Co-authored-by: Kashif Khan --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 22 +- .../eventhub/_buffered_producer/__init__.py | 13 + .../_buffered_producer/_buffered_producer.py | 183 ++++++ .../_buffered_producer_dispatcher.py | 157 +++++ .../_buffered_producer/_partition_resolver.py | 197 +++++++ .../azure-eventhub/azure/eventhub/_common.py | 7 +- .../azure/eventhub/_producer.py | 13 +- .../azure/eventhub/_producer_client.py | 536 +++++++++++++++--- .../azure-eventhub/azure/eventhub/_utils.py | 24 +- .../azure-eventhub/azure/eventhub/_version.py | 2 +- .../azure/eventhub/aio/_async_utils.py | 11 +- .../aio/_buffered_producer/__init__.py | 13 + .../_buffered_producer_async.py | 183 ++++++ .../_buffered_producer_dispatcher_async.py | 140 +++++ .../_partition_resolver_async.py | 31 + .../azure/eventhub/aio/_producer_async.py | 7 +- .../eventhub/aio/_producer_client_async.py | 500 ++++++++++++++-- sdk/eventhub/azure-eventhub/samples/README.md | 5 + .../async_samples/send_buffered_mode_async.py | 61 ++ .../sync_samples/send_buffered_mode.py | 56 ++ .../test_buffered_producer_async.py | 492 ++++++++++++++++ .../asynctests/test_negative_async.py | 20 + .../livetest/asynctests/test_send_async.py | 163 +++++- .../synctests/test_buffered_producer.py | 498 ++++++++++++++++ .../tests/livetest/synctests/test_negative.py | 20 + .../tests/livetest/synctests/test_send.py | 182 +++++- .../test_partition_resolver_async.py | 44 ++ .../tests/unittest/test_partition_resolver.py | 61 ++ 28 files changed, 3442 insertions(+), 199 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_partition_resolver.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_partition_resolver_async.py create mode 100644 sdk/eventhub/azure-eventhub/samples/async_samples/send_buffered_mode_async.py create mode 100644 
sdk/eventhub/azure-eventhub/samples/sync_samples/send_buffered_mode.py create mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py create mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py create mode 100644 sdk/eventhub/azure-eventhub/tests/unittest/asynctests/test_partition_resolver_async.py create mode 100644 sdk/eventhub/azure-eventhub/tests/unittest/test_partition_resolver.py diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index ddc71cc7acfe..a0f1716105c3 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,14 +1,24 @@ # Release History -## 5.9.1 (Unreleased) +## 5.10.0 (2022-06-07) ### Features Added -### Breaking Changes - -### Bugs Fixed - -### Other Changes +- Includes the following features related to buffered sending of events: + - A new method `send_event` to `EventHubProducerClient` which allows sending single `EventData` or `AmqpAnnotatedMessage`. + - Buffered mode sending to `EventHubProducerClient` which is intended to allow for efficient publishing of events + without having to explicitly manage batches in the application. + - The constructor of `EventHubProducerClient` and `from_connection_string` method takes the following new keyword arguments + for configuration: + - `buffered_mode`: The flag to enable/disable buffered mode sending. + - `on_success`: The callback to be called once events have been successfully published. + - `on_error`: The callback to be called once events have failed to be published. + - `max_buffer_length`: The total number of events per partition that can be buffered before a flush will be triggered. + - `max_wait_time`: The amount of time to wait for a batch to be built with events in the buffer before publishing. + - A new method `EventHubProducerClient.flush` which flushes events in the buffer to be sent immediately. + - A new method `EventHubProducerClient.get_buffered_event_count` which returns the number of events that are buffered and waiting to be published for a given partition. + - A new property `EventHubProducerClient.total_buffered_event_count` which returns the total number of events that are currently buffered and waiting to be published, across all partitions. + - A new boolean keyword argument `flush` to `EventHubProducerClient.close` which indicates whether to flush the buffer or not while closing. ## 5.9.0 (2022-05-10) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/__init__.py new file mode 100644 index 000000000000..bfee862537a1 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/__init__.py @@ -0,0 +1,13 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- +from ._buffered_producer import BufferedProducer +from ._partition_resolver import PartitionResolver +from ._buffered_producer_dispatcher import BufferedProducerDispatcher + +__all__ = [ + "BufferedProducer", + "PartitionResolver", + "BufferedProducerDispatcher", +] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py new file mode 100644 index 000000000000..679867e4109b --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py @@ -0,0 +1,183 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import time +import queue +import logging +from threading import RLock +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Callable, TYPE_CHECKING + +from .._producer import EventHubProducer +from .._common import EventDataBatch +from ..exceptions import OperationTimeoutError + +if TYPE_CHECKING: + from .._producer_client import SendEventTypes + +_LOGGER = logging.getLogger(__name__) + + +class BufferedProducer: + # pylint: disable=too-many-instance-attributes + def __init__( + self, + producer: EventHubProducer, + partition_id: str, + on_success: Callable[["SendEventTypes", Optional[str]], None], + on_error: Callable[["SendEventTypes", Optional[str], Exception], None], + max_message_size_on_link: int, + executor: ThreadPoolExecutor, + *, + max_wait_time: float = 1, + max_buffer_length: int + ): + self._buffered_queue: queue.Queue = queue.Queue() + self._max_buffer_len = max_buffer_length + self._cur_buffered_len = 0 + self._executor: ThreadPoolExecutor = executor + self._producer: EventHubProducer = producer + self._lock = RLock() + self._max_wait_time = max_wait_time + self._on_success = self.failsafe_callback(on_success) + self._on_error = self.failsafe_callback(on_error) + self._last_send_time = None + self._running = False + self._cur_batch: Optional[EventDataBatch] = None + self._max_message_size_on_link = max_message_size_on_link + self._check_max_wait_time_future = None + self.partition_id = partition_id + + def start(self): + with self._lock: + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._running = True + if self._max_wait_time: + self._last_send_time = time.time() + self._check_max_wait_time_future = self._executor.submit(self.check_max_wait_time_worker) + + def stop(self, flush=True, timeout_time=None, raise_error=False): + self._running = False + if flush: + with self._lock: + self.flush(timeout_time=timeout_time, raise_error=raise_error) + else: + if self._cur_buffered_len: + _LOGGER.warning( + "Shutting down Partition %r. 
There are still %r events in the buffer which will be lost", + self.partition_id, + self._cur_buffered_len, + ) + if self._check_max_wait_time_future: + remain_timeout = timeout_time - time.time() if timeout_time else None + try: + self._check_max_wait_time_future.result(remain_timeout) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.warning("Partition %r stopped with error %r", self.partition_id, exc) + self._producer.close() + + def put_events(self, events, timeout_time=None): + # Put single event or EventDataBatch into the queue. + # This method would raise OperationTimeout if the queue does not have enough space for the input and + # flush cannot finish in timeout. + try: + new_events_len = len(events) + except TypeError: + new_events_len = 1 + if self._max_buffer_len - self._cur_buffered_len < new_events_len: + _LOGGER.info( + "The buffer for partition %r is full. Attempting to flush before adding %r events.", + self.partition_id, + new_events_len, + ) + # flush the buffer + with self._lock: + self.flush(timeout_time=timeout_time) + if timeout_time and time.time() > timeout_time: + raise OperationTimeoutError("Failed to enqueue events into buffer due to timeout.") + try: + # add single event into current batch + self._cur_batch.add(events) + except AttributeError: # if the input events is a EventDataBatch, put the whole into the buffer + # if there are events in cur_batch, enqueue cur_batch to the buffer + if self._cur_batch: + self._buffered_queue.put(self._cur_batch) + self._buffered_queue.put(events) + # create a new batch for incoming events + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + except ValueError: + # add single event exceeds the cur batch size, create new batch + self._buffered_queue.put(self._cur_batch) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch.add(events) + self._cur_buffered_len += new_events_len + + def failsafe_callback(self, callback): + def wrapper_callback(*args, **kwargs): + try: + callback(*args, **kwargs) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.warning( + "On partition %r, callback %r encountered exception %r", callback.__name__, exc, self.partition_id + ) + + return wrapper_callback + + def flush(self, timeout_time=None, raise_error=True): + # pylint: disable=protected-access + # try flushing all the buffered batch within given time + with self._lock: + _LOGGER.info("Partition: %r started flushing.", self.partition_id) + if self._cur_batch: # if there is batch, enqueue it to the buffer first + self._buffered_queue.put(self._cur_batch) + while self._cur_buffered_len: + remaining_time = timeout_time - time.time() if timeout_time else None + if (remaining_time and remaining_time > 0) or remaining_time is None: + batch = self._buffered_queue.get() + self._buffered_queue.task_done() + try: + _LOGGER.info("Partition %r is sending.", self.partition_id) + self._producer.send(batch, timeout=timeout_time - time.time() if timeout_time else None) + _LOGGER.info("Partition %r sending %r events succeeded.", self.partition_id, len(batch)) + self._on_success(batch._internal_events, self.partition_id) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.info( + "Partition %r sending %r events failed due to exception: %r ", + self.partition_id, + len(batch), + exc, + ) + self._on_error(batch._internal_events, self.partition_id, exc) + finally: + self._cur_buffered_len -= len(batch) + else: + _LOGGER.info("Partition %r fails to flush due to timeout.", 
self.partition_id) + if raise_error: + raise OperationTimeoutError( + "Failed to flush {!r} within {}".format(self.partition_id, timeout_time) + ) + break + # after finishing flushing, reset cur batch and put it into the buffer + self._last_send_time = time.time() + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + _LOGGER.info("Partition %r finished flushing.", self.partition_id) + + def check_max_wait_time_worker(self): + while self._running: + if self._cur_buffered_len > 0: + now_time = time.time() + _LOGGER.info("Partition %r worker is checking max_wait_time.", self.partition_id) + # flush the partition if the producer is running beyond the waiting time + # or the buffer is at max capacity + if (now_time - self._last_send_time > self._max_wait_time) or ( + self._cur_buffered_len >= self._max_buffer_len + ): + # in the worker, not raising error for flush, users can not handle this + with self._lock: + self.flush(raise_error=False) + time.sleep(min(self._max_wait_time, 5)) + + @property + def buffered_event_count(self): + return self._cur_buffered_len diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py new file mode 100644 index 000000000000..39ab2a326c42 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py @@ -0,0 +1,157 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import logging +from threading import Lock +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, Optional, List, Callable, Union, TYPE_CHECKING + +from ._partition_resolver import PartitionResolver +from ._buffered_producer import BufferedProducer +from .._producer import EventHubProducer +from ..exceptions import EventDataSendError, ConnectError, EventHubError + +if TYPE_CHECKING: + from .._producer_client import SendEventTypes + +_LOGGER = logging.getLogger(__name__) + + +class BufferedProducerDispatcher: + # pylint: disable=too-many-instance-attributes + def __init__( + self, + partitions: List[str], + on_success: Callable[["SendEventTypes", Optional[str]], None], + on_error: Callable[["SendEventTypes", Optional[str], Exception], None], + create_producer: Callable[..., EventHubProducer], + eventhub_name: str, + max_message_size_on_link: int, + *, + max_buffer_length: int = 1500, + max_wait_time: float = 1, + executor: Optional[Union[ThreadPoolExecutor, int]] = None + ): + self._buffered_producers: Dict[str, BufferedProducer] = {} + self._partition_ids: List[str] = partitions + self._lock = Lock() + self._on_success = on_success + self._on_error = on_error + self._create_producer = create_producer + self._eventhub_name = eventhub_name + self._max_message_size_on_link = max_message_size_on_link + self._partition_resolver = PartitionResolver(self._partition_ids) + self._max_wait_time = max_wait_time + self._max_buffer_length = max_buffer_length + self._existing_executor = False + + if not executor: + self._executor = ThreadPoolExecutor() + elif isinstance(executor, ThreadPoolExecutor): + self._existing_executor = True + self._executor = executor + elif isinstance(executor, 
int): + self._executor = ThreadPoolExecutor(executor) + + def _get_partition_id(self, partition_id, partition_key): + if partition_id: + if partition_id not in self._partition_ids: + raise ConnectError( + "Invalid partition {} for the event hub {}".format( + partition_id, self._eventhub_name + ) + ) + return partition_id + if isinstance(partition_key, str): + return self._partition_resolver.get_partition_id_by_partition_key(partition_key) + return self._partition_resolver.get_next_partition_id() + + def enqueue_events(self, events, *, partition_id=None, partition_key=None, timeout_time=None): + pid = self._get_partition_id(partition_id, partition_key) + with self._lock: + try: + self._buffered_producers[pid].put_events(events, timeout_time) + except KeyError: + buffered_producer = BufferedProducer( + self._create_producer(pid), + pid, + self._on_success, + self._on_error, + self._max_message_size_on_link, + executor=self._executor, + max_wait_time=self._max_wait_time, + max_buffer_length=self._max_buffer_length + ) + buffered_producer.start() + self._buffered_producers[pid] = buffered_producer + buffered_producer.put_events(events, timeout_time) + + def flush(self, timeout_time=None): + # flush all the buffered producer, the method will block until finishes or times out + with self._lock: + futures = [] + for pid, producer in self._buffered_producers.items(): + # call each producer's flush method + futures.append((pid, self._executor.submit(producer.flush, timeout_time=timeout_time))) + + # gather results + exc_results = {} + for pid, future in futures: + try: + future.result() + except Exception as exc: # pylint: disable=broad-except + exc_results[pid] = exc + + if not exc_results: + _LOGGER.info("Flushing all partitions succeeded") + return + + _LOGGER.warning('Flushing all partitions partially failed with result %r.', exc_results) + raise EventDataSendError( + message="Flushing all partitions partially failed, failed partitions are {!r}" + " Exception details are {!r}".format(exc_results.keys(), exc_results) + ) + + def close(self, *, flush=True, timeout_time=None, raise_error=False): + + with self._lock: + + futures = [] + # stop all buffered producers + for pid, producer in self._buffered_producers.items(): + futures.append((pid, self._executor.submit( + producer.stop, + flush=flush, + timeout_time=timeout_time, + raise_error=raise_error + ))) + + exc_results = {} + # gather results + for pid, future in futures: + try: + future.result() + except Exception as exc: # pylint: disable=broad-except + exc_results[pid] = exc + + if exc_results: + _LOGGER.warning('Stopping all partitions partially failed with result %r.', exc_results) + if raise_error: + raise EventHubError( + message="Stopping all partitions partially failed, failed partitions are {!r}" + " Exception details are {!r}".format(exc_results.keys(), exc_results) + ) + + if not self._existing_executor: + self._executor.shutdown() + + def get_buffered_event_count(self, pid): + try: + return self._buffered_producers[pid].buffered_event_count + except KeyError: + return 0 + + @property + def total_buffered_event_count(self): + return sum([self.get_buffered_event_count(pid) for pid in self._buffered_producers]) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_partition_resolver.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_partition_resolver.py new file mode 100644 index 000000000000..148e454404b4 --- /dev/null +++ 
b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_partition_resolver.py @@ -0,0 +1,197 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +""" +jenkins-hash lookup3 algorithm implementation +""" + +from threading import Lock +import struct + +c_signed_short = struct.Struct(">h") + + +def rot(x, k): + return (x << k) | (x >> (32 - k)) + + +def mix(a, b, c): + a &= 0xffffffff + b &= 0xffffffff + c &= 0xffffffff + a -= c + a &= 0xffffffff + a ^= rot(c, 4) + a &= 0xffffffff + c += b + c &= 0xffffffff + b -= a + b &= 0xffffffff + b ^= rot(a, 6) + b &= 0xffffffff + a += c + a &= 0xffffffff + c -= b + c &= 0xffffffff + c ^= rot(b, 8) + c &= 0xffffffff + b += a + b &= 0xffffffff + a -= c + a &= 0xffffffff + a ^= rot(c, 16) + a &= 0xffffffff + c += b + c &= 0xffffffff + b -= a + b &= 0xffffffff + b ^= rot(a, 19) + b &= 0xffffffff + a += c + a &= 0xffffffff + c -= b + c &= 0xffffffff + c ^= rot(b, 4) + c &= 0xffffffff + b += a + b &= 0xffffffff + return a, b, c + + +def final(a, b, c): + a &= 0xffffffff + b &= 0xffffffff + c &= 0xffffffff + c ^= b + c &= 0xffffffff + c -= rot(b, 14) + c &= 0xffffffff + a ^= c + a &= 0xffffffff + a -= rot(c, 11) + a &= 0xffffffff + b ^= a + b &= 0xffffffff + b -= rot(a, 25) + b &= 0xffffffff + c ^= b + c &= 0xffffffff + c -= rot(b, 16) + c &= 0xffffffff + a ^= c + a &= 0xffffffff + a -= rot(c, 4) + a &= 0xffffffff + b ^= a + b &= 0xffffffff + b -= rot(a, 14) + b &= 0xffffffff + c ^= b + c &= 0xffffffff + c -= rot(b, 24) + c &= 0xffffffff + return a, b, c + + +def compute_hash(data, init_val=0, init_val2=0): + # pylint: disable=too-many-statements + """ + implementation by: + https://stackoverflow.com/questions/3279615/python-implementation-of-jenkins-hash + """ + length = lenpos = len(data) + + a = b = c = (0xdeadbeef + length + init_val) + + c += init_val2 + c &= 0xffffffff + + p = 0 # string offset + while lenpos > 12: + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + a &= 0xffffffff + b += (ord(data[p + 4]) + (ord(data[p + 5]) << 8) + (ord(data[p + 6]) << 16) + (ord(data[p + 7]) << 24)) + b &= 0xffffffff + c += (ord(data[p + 8]) + (ord(data[p + 9]) << 8) + (ord(data[p + 10]) << 16) + (ord(data[p + 11]) << 24)) + c &= 0xffffffff + a, b, c = mix(a, b, c) + p += 12 + lenpos -= 12 + + if lenpos == 12: + c += (ord(data[p + 8]) + (ord(data[p + 9]) << 8) + (ord(data[p + 10]) << 16) + (ord(data[p + 11]) << 24)) + b += (ord(data[p + 4]) + (ord(data[p + 5]) << 8) + (ord(data[p + 6]) << 16) + (ord(data[p + 7]) << 24)) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 11: + c += (ord(data[p + 8]) + (ord(data[p + 9]) << 8) + (ord(data[p + 10]) << 16)) + b += (ord(data[p + 4]) + (ord(data[p + 5]) << 8) + (ord(data[p + 6]) << 16) + (ord(data[p + 7]) << 24)) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 10: + c += (ord(data[p + 8]) + (ord(data[p + 9]) << 8)) + b += (ord(data[p + 4]) + (ord(data[p + 5]) << 8) + (ord(data[p + 6]) << 16) + (ord(data[p + 7]) << 24)) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + 
if lenpos == 9: + c += (ord(data[p + 8])) + b += (ord(data[p + 4]) + (ord(data[p + 5]) << 8) + (ord(data[p + 6]) << 16) + (ord(data[p + 7]) << 24)) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 8: + b += (ord(data[p + 4]) + (ord(data[p + 5]) << 8) + (ord(data[p + 6]) << 16) + (ord(data[p + 7]) << 24)) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 7: + b += (ord(data[p + 4]) + (ord(data[p + 5]) << 8) + (ord(data[p + 6]) << 16)) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 6: + b += ((ord(data[p + 5]) << 8) + ord(data[p + 4])) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 5: + b += (ord(data[p + 4])) + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 4: + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16) + (ord(data[p + 3]) << 24)) + if lenpos == 3: + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8) + (ord(data[p + 2]) << 16)) + if lenpos == 2: + a += (ord(data[p + 0]) + (ord(data[p + 1]) << 8)) + if lenpos == 1: + a += ord(data[p + 0]) + + a &= 0xffffffff + b &= 0xffffffff + c &= 0xffffffff + if lenpos == 0: + return c, b + + a, b, c = final(a, b, c) + + return c, b + + +def generate_hash_code(partition_key): + if not partition_key: + return 0 + + hash_tuple = compute_hash(partition_key, 0, 0) + hash_value = (hash_tuple[0] ^ hash_tuple[1]) & 0xffff + return c_signed_short.unpack(struct.pack('>H', hash_value))[0] + + +class PartitionResolver: + def __init__(self, partitions): + self._idx = -1 + self._partitions = partitions + self._partitions_cnt = len(self._partitions) + self._lock = Lock() + + def get_next_partition_id(self): + """ + round-robin partition assignment + """ + with self._lock: + self._idx += 1 + self._idx %= self._partitions_cnt + return self._partitions[self._idx] + + def get_partition_id_by_partition_key(self, partition_key): + hash_code = generate_hash_code(partition_key) + return self._partitions[abs(hash_code % self._partitions_cnt)] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 5e01e0cfc69e..9ab6db008543 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -549,6 +549,7 @@ def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None set_message_partition_key(self.message, self._partition_key) self._size = self.message.gather()[0].get_message_encoded_size() self._count = 0 + self._internal_events: List[Union[EventData, AmqpAnnotatedMessage]] = [] def __repr__(self): # type: () -> str @@ -567,9 +568,8 @@ def _from_batch(cls, batch_data, partition_key=None): transform_outbound_single_message(m, EventData) for m in batch_data ] batch_data_instance = cls(partition_key=partition_key) - batch_data_instance.message._body_gen = ( # pylint:disable=protected-access - outgoing_batch_data - ) + for data in outgoing_batch_data: + batch_data_instance.add(data) return batch_data_instance def _load_events(self, events): @@ -639,6 +639,7 @@ def add(self, event_data): ) ) + self._internal_events.append(event_data) self.message._body_gen.append( # pylint: disable=protected-access outgoing_event_data ) diff --git 
a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 75498fc0bf37..be1ed347ad16 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -34,6 +34,7 @@ transform_outbound_single_message, ) from ._constants import TIMEOUT_SYMBOL +from .amqp import AmqpAnnotatedMessage _LOGGER = logging.getLogger(__name__) @@ -184,12 +185,12 @@ def _on_outcome(self, outcome, condition): def _wrap_eventdata( self, - event_data, # type: Union[EventData, EventDataBatch, Iterable[EventData]] + event_data, # type: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]] span, # type: Optional[AbstractSpan] partition_key, # type: Optional[AnyStr] ): # type: (...) -> Union[EventData, EventDataBatch] - if isinstance(event_data, EventData): + if isinstance(event_data, (EventData, AmqpAnnotatedMessage)): outgoing_event_data = transform_outbound_single_message(event_data, EventData) if partition_key: set_message_partition_key(outgoing_event_data.message, partition_key) @@ -199,6 +200,8 @@ def _wrap_eventdata( if isinstance( event_data, EventDataBatch ): # The partition_key in the param will be omitted. + if not event_data: + return event_data if ( partition_key and partition_key != event_data._partition_key # pylint: disable=protected-access ): @@ -218,7 +221,7 @@ def _wrap_eventdata( def send( self, - event_data, # type: Union[EventData, EventDataBatch, Iterable[EventData]] + event_data, # type: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]] partition_key=None, # type: Optional[AnyStr] timeout=None, # type: Optional[float] ): @@ -251,6 +254,10 @@ def send( with send_context_manager() as child: self._check_closed() wrapper_event_data = self._wrap_eventdata(event_data, child, partition_key) + + if not wrapper_event_data: + return + self._unsent_events = [wrapper_event_data.message] if child: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index b8bf46cd3733..7db5f28eda7d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -2,19 +2,33 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. 
# -------------------------------------------------------------------------------------------- +from concurrent.futures import ThreadPoolExecutor import logging import threading - -from typing import Any, Union, TYPE_CHECKING, Dict, List, Optional, cast +import time +from typing import ( + Any, + Union, + TYPE_CHECKING, + Dict, + List, + Optional, + Callable, + cast, + overload, +) +from typing_extensions import Literal from uamqp import constants -from .exceptions import ConnectError, EventHubError -from .amqp import AmqpAnnotatedMessage from ._client_base import ClientBase -from ._producer import EventHubProducer -from ._constants import ALL_PARTITIONS from ._common import EventDataBatch, EventData +from ._constants import ALL_PARTITIONS +from ._producer import EventHubProducer +from ._buffered_producer import BufferedProducerDispatcher +from ._utils import set_event_partition_key +from .amqp import AmqpAnnotatedMessage +from .exceptions import ConnectError, EventHubError if TYPE_CHECKING: from ._client_base import CredentialTypes @@ -24,7 +38,8 @@ _LOGGER = logging.getLogger(__name__) -class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api-version-keyword +class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api-version-keyword + # pylint: disable=too-many-instance-attributes """The EventHubProducerClient class defines a high level interface for sending events to the Azure Event Hubs service. @@ -37,6 +52,39 @@ class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. :type credential: ~azure.core.credentials.TokenCredential or ~azure.core.credentials.AzureSasCredential or ~azure.core.credentials.AzureNamedKeyCredential + :keyword bool buffered_mode: If True, the producer client will collect events in a buffer, efficiently batch, + then publish. Default is False. + :keyword Union[ThreadPoolExecutor, int] buffer_concurrency: The ThreadPoolExecutor to be used for publishing events + or the number of workers for the ThreadPoolExecutor. + Default is none and a ThreadPoolExecutor with the default number of workers will be created + per https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor + :keyword on_success: The callback to be called once a batch has been successfully published. + The callback takes two parameters: + - `events`: The list of events that have been successfully published + - `partition_id`: The partition id that the events in the list have been published to. + The callback function should be defined like: `on_success(events, partition_id)`. + It is required when `buffered_mode` is True while optional if `buffered_mode` is False. + :paramtype on_success: Optional[Callable[[SendEventTypes, Optional[str]], None]] + :keyword on_error: The callback to be called once a batch has failed to be published. + The callback function should be defined like: `on_error(events, partition_id, error)`, where: + - `events`: The list of events that failed to be published, + - `partition_id`: The partition id that the events in the list have been tried to be published to and + - `error`: The exception related to the sending failure. 
+ If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + :paramtype on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], None]] + :keyword int max_buffer_length: Buffered mode only. + The total number of events per partition that can be buffered before a flush will be triggered. + The default value is 1500 in buffered mode. + :keyword Optional[float] max_wait_time: Buffered mode only. + The amount of time to wait for a batch to be built with events in the buffer before publishing. + The default value is 1 in buffered mode. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. :keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service. The default value is 60 seconds. If set to 0, no timeout will be enforced from the client. @@ -84,14 +132,48 @@ class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api """ + @overload + def __init__( + self, + fully_qualified_namespace: str, + eventhub_name: str, + credential: "CredentialTypes", + *, + buffered_mode: Literal[False] = False, + **kwargs: Any + ) -> None: + ... + + @overload def __init__( self, - fully_qualified_namespace, # type: str - eventhub_name, # type: str - credential, # type: CredentialTypes - **kwargs # type: Any - ): - # type:(...) -> None + fully_qualified_namespace: str, + eventhub_name: str, + credential: "CredentialTypes", + *, + buffered_mode: Literal[True], + buffer_concurrency: Union[ThreadPoolExecutor, int] = None, + on_error: Callable[[SendEventTypes, Optional[str], Exception], None], + on_success: Callable[[SendEventTypes, Optional[str]], None], + max_buffer_length: int = 1500, + max_wait_time: float = 1, + **kwargs: Any + ) -> None: + ... 
+ + def __init__( + self, + fully_qualified_namespace: str, + eventhub_name: str, + credential: "CredentialTypes", + *, + buffered_mode: bool = False, + on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], None]] = None, + on_success: Optional[Callable[[SendEventTypes, Optional[str]], None]] = None, + max_buffer_length: Optional[int] = None, + max_wait_time: Optional[float] = None, + **kwargs: Any + ) -> None: super(EventHubProducerClient, self).__init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, @@ -99,12 +181,39 @@ def __init__( network_tracing=kwargs.get("logging_enable"), **kwargs ) - self._producers = { - ALL_PARTITIONS: self._create_producer() - } # type: Dict[str, Optional[EventHubProducer]] + self._producers = {ALL_PARTITIONS: self._create_producer()} # type: Dict[str, Optional[EventHubProducer]] self._max_message_size_on_link = 0 self._partition_ids = None # Optional[List[str]] self._lock = threading.Lock() + self._buffered_mode = buffered_mode + self._on_success = on_success + self._on_error = on_error + self._buffered_producer_dispatcher = None + self._max_wait_time = max_wait_time + self._max_buffer_length = max_buffer_length + self._executor = kwargs.get("buffer_concurrency") + + if self._buffered_mode: + setattr(self, "send_batch", self._buffered_send_batch) + setattr(self, "send_event", self._buffered_send_event) + if not self._on_error: + raise TypeError( + "EventHubProducerClient in buffered mode missing 1 required keyword argument: 'on_error'" + ) + if not self._on_success: + raise TypeError( + "EventHubProducerClient in buffered mode missing 1 required keyword argument: 'on_success'" + ) + if self._max_wait_time is None: + self._max_wait_time = 1 + if self._max_wait_time <= 0: + raise ValueError("'max_wait_time' must be a float greater than 0 in buffered mode") + if self._max_buffer_length is None: + self._max_buffer_length = 1500 + if self._max_buffer_length <= 0: + raise ValueError("'max_buffer_length' must be an integer greater than 0 in buffered mode") + if isinstance(self._executor, int) and self._executor <= 0: + raise ValueError("'buffer_concurrency' must be an integer greater than 0 in buffered mode") def __enter__(self): return self @@ -112,6 +221,62 @@ def __enter__(self): def __exit__(self, *args): self.close() + def _buffered_send(self, events, **kwargs): + try: + self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) + except AttributeError: + self._get_partitions() + self._get_max_message_size() + self._buffered_producer_dispatcher = BufferedProducerDispatcher( + self._partition_ids, + self._on_success, + self._on_error, + self._create_producer, + self.eventhub_name, + self._max_message_size_on_link, + max_wait_time=self._max_wait_time, + max_buffer_length=self._max_buffer_length, + executor=self._executor + ) + self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) + + def _batch_preparer(self, event_data_batch, **kwargs): + partition_id = kwargs.pop("partition_id", None) + partition_key = kwargs.pop("partition_key", None) + + if isinstance(event_data_batch, EventDataBatch): + if partition_id or partition_key: + raise TypeError( + "partition_id and partition_key should be None when sending an EventDataBatch " + "because type EventDataBatch itself may have partition_id or partition_key" + ) + to_send_batch = event_data_batch + else: + to_send_batch = self.create_batch(partition_id=partition_id, partition_key=partition_key) + to_send_batch._load_events(event_data_batch) # 
pylint:disable=protected-access + + return to_send_batch, to_send_batch._partition_id, partition_key # pylint:disable=protected-access + + def _buffered_send_batch(self, event_data_batch, **kwargs): + batch, pid, pkey = self._batch_preparer(event_data_batch, **kwargs) + + if len(batch) == 0: + return + + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + + self._buffered_send(event_data_batch, partition_id=pid, partition_key=pkey, timeout_time=timeout_time) + + def _buffered_send_event(self, event, **kwargs): + partition_key = kwargs.get("partition_key") + set_event_partition_key(event, partition_key) + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + self._buffered_send( + event, partition_id=kwargs.get("partition_id"), partition_key=partition_key, timeout_time=timeout_time + ) + def _get_partitions(self): # type: () -> None if not self._partition_ids: @@ -119,18 +284,14 @@ def _get_partitions(self): for p_id in cast(List[str], self._partition_ids): self._producers[p_id] = None - def _get_max_mesage_size(self): + def _get_max_message_size(self): # type: () -> None # pylint: disable=protected-access,line-too-long with self._lock: if not self._max_message_size_on_link: - cast( - EventHubProducer, self._producers[ALL_PARTITIONS] - )._open_with_retry() + cast(EventHubProducer, self._producers[ALL_PARTITIONS])._open_with_retry() self._max_message_size_on_link = ( - self._producers[ # type: ignore - ALL_PARTITIONS - ]._handler.message_handler._link.peer_max_message_size + self._producers[ALL_PARTITIONS]._handler.message_handler._link.peer_max_message_size # type: ignore or constants.MAX_MESSAGE_LENGTH_BYTES ) @@ -138,33 +299,21 @@ def _start_producer(self, partition_id, send_timeout): # type: (str, Optional[Union[int, float]]) -> None with self._lock: self._get_partitions() - if ( - partition_id not in cast(List[str], self._partition_ids) - and partition_id != ALL_PARTITIONS - ): + if partition_id not in cast(List[str], self._partition_ids) and partition_id != ALL_PARTITIONS: raise ConnectError( - "Invalid partition {} for the event hub {}".format( - partition_id, self.eventhub_name - ) + "Invalid partition {} for the event hub {}".format(partition_id, self.eventhub_name) ) - if ( - not self._producers[partition_id] - or cast(EventHubProducer, self._producers[partition_id]).closed - ): + if not self._producers[partition_id] or cast(EventHubProducer, self._producers[partition_id]).closed: self._producers[partition_id] = self._create_producer( - partition_id=( - None if partition_id == ALL_PARTITIONS else partition_id - ), + partition_id=(None if partition_id == ALL_PARTITIONS else partition_id), send_timeout=send_timeout, ) def _create_producer(self, partition_id=None, send_timeout=None): # type: (Optional[str], Optional[Union[int, float]]) -> EventHubProducer target = "amqps://{}{}".format(self._address.hostname, self._address.path) - send_timeout = ( - self._config.send_timeout if send_timeout is None else send_timeout - ) + send_timeout = self._config.send_timeout if send_timeout is None else send_timeout handler = EventHubProducer( self, @@ -176,14 +325,82 @@ def _create_producer(self, partition_id=None, send_timeout=None): return handler @classmethod - def from_connection_string(cls, conn_str, **kwargs): - # type: (str, Any) -> EventHubProducerClient + @overload + def from_connection_string( + cls, + conn_str: str, + *, + eventhub_name: Optional[str] = None, + buffered_mode: Literal[False] = False, + 
**kwargs: Any + ) -> "EventHubProducerClient": + ... + + @classmethod + @overload + def from_connection_string( + cls, + conn_str: str, + *, + eventhub_name: Optional[str] = None, + buffered_mode: Literal[True], + on_error: Callable[[SendEventTypes, Optional[str], Exception], None], + on_success: Callable[[SendEventTypes, Optional[str]], None], + max_buffer_length: int = 1500, + max_wait_time: float = 1, + **kwargs: Any + ) -> "EventHubProducerClient": + ... + + @classmethod + def from_connection_string( + cls, + conn_str: str, + *, + eventhub_name: Optional[str] = None, + buffered_mode: bool = False, + on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], None]] = None, + on_success: Optional[Callable[[SendEventTypes, Optional[str]], None]] = None, + max_buffer_length: Optional[int] = None, + max_wait_time: Optional[float] = None, + **kwargs: Any + ) -> "EventHubProducerClient": """Create an EventHubProducerClient from a connection string. :param str conn_str: The connection string of an Event Hub. :keyword str eventhub_name: The path of the specific Event Hub to connect the client to. + :keyword bool buffered_mode: If True, the producer client will collect events in a buffer, efficiently batch, + then publish. Default is False. + :keyword on_success: The callback to be called once a batch has been successfully published. + The callback takes two parameters: + - `events`: The list of events that have been successfully published + - `partition_id`: The partition id that the events in the list have been published to. + The callback function should be defined like: `on_success(events, partition_id)`. + Required when `buffered_mode` is True while optional if `buffered_mode` is False. + :paramtype on_success: Optional[Callable[[SendEventTypes, Optional[str]], None]] + :keyword on_error: The callback to be called once a batch has failed to be published. + Required when in `buffered_mode` is True while optional if `buffered_mode` is False. + The callback function should be defined like: `on_error(events, partition_id, error)`, where: + - `events`: The list of events that failed to be published, + - `partition_id`: The partition id that the events in the list have been tried to be published to and + - `error`: The exception related to the sending failure. + If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + :paramtype on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], None]] + :keyword int max_buffer_length: Buffered mode only. + The total number of events per partition that can be buffered before a flush will be triggered. + The default value is 1500 in buffered mode. + :keyword Optional[float] max_wait_time: Buffered mode only. + The amount of time to wait for a batch to be built with events in the buffer before publishing. + The default value is 1 in buffered mode. 
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. - :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + :keyword dict http_proxy: HTTP proxy settings. This must be a dictionary with the following keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). Additionally the following keys may also be present: `'username', 'password'`. :keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service. @@ -228,12 +445,110 @@ def from_connection_string(cls, conn_str, **kwargs): :dedent: 4 :caption: Create a new instance of the EventHubProducerClient from connection string. """ - constructor_args = cls._from_connection_string(conn_str, **kwargs) + constructor_args = cls._from_connection_string( + conn_str, + eventhub_name=eventhub_name, + buffered_mode=buffered_mode, + on_success=on_success, + on_error=on_error, + max_buffer_length=max_buffer_length, + max_wait_time=max_wait_time, + **kwargs + ) return cls(**constructor_args) + def send_event(self, event_data, **kwargs): + # type: (Union[EventData, AmqpAnnotatedMessage], Any) -> None + """ + Sends an event data. + By default, the method will block until acknowledgement is received or operation times out. + If the `EventHubProducerClient` is configured to run in buffered mode, the method will try enqueuing + the events into buffer within the given time if specified and return. + The producer will do automatic sending in the background in buffered mode. + + If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + + :param event_data: The `EventData` object to be sent. + :type event_data: Union[~azure.eventhub.EventData, ~azure.eventhub.amqp.AmqpAnnotatedMessage] + :keyword float timeout: The maximum wait time to send the event data in non-buffered mode or the + maximum wait time to enqueue the event data into the buffer in buffered mode. + In non-buffered mode, the default wait time specified when the producer + was created will be used. In buffered mode, the default wait time is None. + :keyword str partition_id: The specific partition ID to send to. Default is None, in which case the service + will assign to all partitions using round-robin. + A `TypeError` will be raised if partition_id is specified and event_data_batch is an `EventDataBatch` because + `EventDataBatch` itself has partition_id. + :keyword str partition_key: With the given partition_key, event data will be sent to + a particular partition of the Event Hub decided by the service. + A `TypeError` will be raised if partition_key is specified and event_data_batch is an `EventDataBatch` because + `EventDataBatch` itself has partition_key. + If both partition_id and partition_key are provided, the partition_id will take precedence. 
+ **WARNING: Setting partition_key of non-string value on the events to be sent is discouraged + as the partition_key will be ignored by the Event Hub service and events will be assigned + to all partitions using round-robin. Furthermore, there are SDKs for consuming events which expect + partition_key to only be string type, they might fail to parse the non-string value.** + :rtype: None + :raises: :class:`AuthenticationError` + :class:`ConnectError` + :class:`ConnectionLostError` + :class:`EventDataError` + :class:`EventDataSendError` + :class:`EventHubError` + :raises OperationTimeoutError: If the value specified by the timeout parameter elapses before the event can be + sent in non-buffered mode or the events can be enqueued into the buffered in buffered mode. + """ + input_pid = kwargs.get("partition_id") + pid = input_pid or ALL_PARTITIONS + partition_key = kwargs.get("partition_key") + send_timeout = kwargs.get("timeout") + try: + try: + cast(EventHubProducer, self._producers[pid]).send( + event_data, partition_key=partition_key, timeout=send_timeout + ) + except (KeyError, AttributeError, EventHubError): + self._start_producer(pid, send_timeout) + cast(EventHubProducer, self._producers[pid]).send( + event_data, partition_key=partition_key, timeout=send_timeout + ) + if self._on_success: + self._on_success([event_data], input_pid) + except Exception as exc: # pylint: disable=broad-except + if self._on_error: + self._on_error([event_data], input_pid, exc) + else: + raise + def send_batch(self, event_data_batch, **kwargs): # type: (Union[EventDataBatch, SendEventTypes], Any) -> None - """Sends event data and blocks until acknowledgement is received or operation times out. + # pylint: disable=protected-access + """ + Sends a batch of event data. + By default, the method will block until acknowledgement is received or operation times out. + If the `EventHubProducerClient` is configured to run in buffered mode, the method will try enqueuing + the events into buffer within the given time if specified and return. + The producer will do automatic sending in the background in buffered mode. + + If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + + In buffered mode, sending a batch will remain intact and sent as a single unit. + The batch will not be rearranged. This may result in inefficiency of sending events. If you're sending a finite list of `EventData` or `AmqpAnnotatedMessage` and you know it's within the event hub frame size limit, you can send them with a `send_batch` call. Otherwise, use :meth:`create_batch` @@ -244,8 +559,10 @@ def send_batch(self, event_data_batch, **kwargs): All `EventData` or `AmqpAnnotatedMessage` in the list or `EventDataBatch` will land on the same partition. 
:type event_data_batch: Union[~azure.eventhub.EventDataBatch, List[Union[~azure.eventhub.EventData, ~azure.eventhub.amqp.AmqpAnnotatedMessage]] - :keyword float timeout: The maximum wait time to send the event data. - If not specified, the default wait time specified when the producer was created will be used. + :keyword float timeout: The maximum wait time to send the event data in non-buffered mode or the + maximum wait time to enqueue the event data into the buffer in buffered mode. + In non-buffered mode, the default wait time specified when the producer + was created will be used. In buffered mode, the default wait time is None. :keyword str partition_id: The specific partition ID to send to. Default is None, in which case the service will assign to all partitions using round-robin. A `TypeError` will be raised if partition_id is specified and event_data_batch is an `EventDataBatch` because @@ -268,6 +585,8 @@ def send_batch(self, event_data_batch, **kwargs): :class:`EventHubError` :class:`ValueError` :class:`TypeError` + :raises OperationTimeoutError: If the value specified by the timeout parameter elapses before the event can be + sent in non-buffered mode or the events can not be enqueued into the buffered in buffered mode. .. admonition:: Example: @@ -279,41 +598,33 @@ def send_batch(self, event_data_batch, **kwargs): :caption: Sends event data """ - partition_id = kwargs.get("partition_id") - partition_key = kwargs.get("partition_key") - - if isinstance(event_data_batch, EventDataBatch): - if partition_id or partition_key: - raise TypeError( - "partition_id and partition_key should be None when sending an EventDataBatch " - "because type EventDataBatch itself may have partition_id or partition_key" - ) - to_send_batch = event_data_batch - else: - to_send_batch = self.create_batch( - partition_id=partition_id, partition_key=partition_key - ) - to_send_batch._load_events( # pylint:disable=protected-access - event_data_batch - ) - partition_id = ( - to_send_batch._partition_id # pylint:disable=protected-access - or ALL_PARTITIONS - ) + batch, pid, pkey = self._batch_preparer(event_data_batch, **kwargs) - if len(to_send_batch) == 0: + if len(batch) == 0: return + partition_id = pid or ALL_PARTITIONS send_timeout = kwargs.pop("timeout", None) + try: - cast(EventHubProducer, self._producers[partition_id]).send( - to_send_batch, timeout=send_timeout - ) - except (KeyError, AttributeError, EventHubError): - self._start_producer(partition_id, send_timeout) - cast(EventHubProducer, self._producers[partition_id]).send( - to_send_batch, timeout=send_timeout - ) + try: + cast(EventHubProducer, self._producers[partition_id]).send( + batch, partition_key=pkey, timeout=send_timeout + ) + if self._on_success: + self._on_success(batch._internal_events, pid) + except (KeyError, AttributeError, EventHubError): + self._start_producer(partition_id, send_timeout) + cast(EventHubProducer, self._producers[partition_id]).send( + batch, partition_key=pkey, timeout=send_timeout + ) + if self._on_success: + self._on_success(batch._internal_events, pid) + except Exception as exc: # pylint: disable=broad-except + if self._on_error: + self._on_error(batch._internal_events, pid, exc) + else: + raise def create_batch(self, **kwargs): # type:(Any) -> EventDataBatch @@ -345,7 +656,7 @@ def create_batch(self, **kwargs): """ if not self._max_message_size_on_link: - self._get_max_mesage_size() + self._get_max_message_size() max_size_in_bytes = kwargs.get("max_size_in_bytes", None) partition_id = 
kwargs.get("partition_id", None) @@ -409,15 +720,34 @@ def get_partition_properties(self, partition_id): :rtype: Dict[str, Any] :raises: :class:`EventHubError` """ - return super(EventHubProducerClient, self)._get_partition_properties( - partition_id - ) + return super(EventHubProducerClient, self)._get_partition_properties(partition_id) - def close(self): - # type: () -> None + def flush(self, **kwargs: Any) -> None: + """ + Buffered mode only. + Flush events in the buffer to be sent immediately if the client is working in buffered mode. + + :keyword Optional[float] timeout: Timeout to flush the buffered events, default is None which means no timeout. + :rtype: None + :raises EventDataSendError: If the producer fails to flush the buffer within the given timeout + in buffered mode. + """ + with self._lock: + if self._buffered_mode and self._buffered_producer_dispatcher: + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + self._buffered_producer_dispatcher.flush(timeout_time=timeout_time) + + def close(self, *, flush: bool = True, **kwargs: Any) -> None: """Close the Producer client underlying AMQP connection and links. + :keyword bool flush: Buffered mode only. If set to True, events in the buffer will be sent + immediately. Default is True. + :keyword Optional[float] timeout: Buffered mode only. Timeout to close the producer. + Default is None which means no timeout. :rtype: None + :raises EventHubError: If an error occurred when flushing the buffer if `flush` is set to True or closing the + underlying AMQP connections in buffered mode. .. admonition:: Example: @@ -430,8 +760,48 @@ def close(self): """ with self._lock: + if self._buffered_mode and self._buffered_producer_dispatcher: + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + self._buffered_producer_dispatcher.close(flush=flush, timeout_time=timeout_time, raise_error=True) + self._buffered_producer_dispatcher = None + for pid in self._producers: if self._producers[pid]: self._producers[pid].close() # type: ignore self._producers[pid] = None super(EventHubProducerClient, self)._close() + + def get_buffered_event_count(self, partition_id: str) -> Optional[int]: + """ + The number of events that are buffered and waiting to be published for a given partition. + Returns None in non-buffered mode. + + :param str partition_id: The target partition ID. + :rtype: int or None + """ + if not self._buffered_mode: + return None + + try: + return cast(BufferedProducerDispatcher, self._buffered_producer_dispatcher).get_buffered_event_count( + partition_id + ) + except AttributeError: + return 0 + + @property + def total_buffered_event_count(self) -> Optional[int]: + """ + The total number of events that are currently buffered and waiting to be published, across all partitions. + Returns None in non-buffered mode. 
+ + :rtype: int or None + """ + if not self._buffered_mode: + return None + + try: + return cast(BufferedProducerDispatcher, self._buffered_producer_dispatcher).total_buffered_event_count + except AttributeError: + return 0 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index cf913f4e3335..b23e9af7c5a4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -20,7 +20,7 @@ from azure.core.settings import settings from azure.core.tracing import SpanKind, Link -from .amqp import AmqpAnnotatedMessage +from .amqp import AmqpAnnotatedMessage, AmqpMessageHeader from ._version import VERSION from ._constants import ( PROP_PARTITION_KEY_AMQP_SYMBOL, @@ -112,6 +112,28 @@ def create_properties(user_agent=None): return properties +def set_event_partition_key(event, partition_key): + # type: (Union[AmqpAnnotatedMessage, EventData], Optional[Union[bytes, str]]) -> None + if not partition_key: + return + + try: + raw_message = event.raw_amqp_message # type: ignore + except AttributeError: + raw_message = event + + annotations = raw_message.annotations + if annotations is None: + annotations = dict() + annotations[ + PROP_PARTITION_KEY_AMQP_SYMBOL + ] = partition_key # pylint:disable=protected-access + if not raw_message.header: + raw_message.header = AmqpMessageHeader(header=True) + else: + raw_message.header.durable = True + + def set_message_partition_key(message, partition_key): # type: (Message, Optional[Union[bytes, str]]) -> None """Set the partition key as an annotation on a uamqp message. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index 0c74c27abc63..f517c2385bd5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.9.1" +VERSION = "5.10.0" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_async_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_async_utils.py index 9e604a982b53..a3f0c995e3ae 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_async_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_async_utils.py @@ -3,8 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. #-------------------------------------------------------------------------- - +import asyncio import sys +from asyncio import Semaphore + def get_dict_with_loop_if_needed(loop): if sys.version_info >= (3, 10): @@ -13,3 +15,10 @@ def get_dict_with_loop_if_needed(loop): elif loop: return {'loop': loop} return {} + + +async def semaphore_acquire_with_timeout(semaphore: Semaphore, timeout=None): + try: + return await asyncio.wait_for(semaphore.acquire(), timeout=timeout) + except asyncio.TimeoutError: + return False diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/__init__.py new file mode 100644 index 000000000000..4d9cd95fb90a --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/__init__.py @@ -0,0 +1,13 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from ._buffered_producer_dispatcher_async import BufferedProducerDispatcher +from ._partition_resolver_async import PartitionResolver +from ._buffered_producer_async import BufferedProducer + +__all__ = [ + "BufferedProducerDispatcher", + "PartitionResolver", + "BufferedProducer" +] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py new file mode 100644 index 000000000000..945c5b8f2699 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py @@ -0,0 +1,183 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import asyncio +import logging +import queue +import time +from asyncio import Lock +from typing import Optional, Callable, Awaitable, TYPE_CHECKING + +from .._producer_async import EventHubProducer +from ..._common import EventDataBatch +from ...exceptions import OperationTimeoutError + +if TYPE_CHECKING: + from ..._producer_client import SendEventTypes + +_LOGGER = logging.getLogger(__name__) + + +class BufferedProducer: + # pylint: disable=too-many-instance-attributes + def __init__( + self, + producer: EventHubProducer, + partition_id: str, + on_success: Callable[["SendEventTypes", Optional[str]], Awaitable[None]], + on_error: Callable[["SendEventTypes", Optional[str], Exception], Awaitable[None]], + max_message_size_on_link: int, + *, + max_wait_time: float = 1, + max_buffer_length: int + ): + self._buffered_queue: queue.Queue = queue.Queue() + self._max_buffer_len = max_buffer_length + self._cur_buffered_len = 0 + self._producer: EventHubProducer = producer + self._lock = Lock() + self._max_wait_time = max_wait_time + self._on_success = self.failsafe_callback(on_success) + self._on_error = self.failsafe_callback(on_error) + self._last_send_time = None + self._running = False + self._cur_batch: Optional[EventDataBatch] = None + self._max_message_size_on_link = max_message_size_on_link + self._check_max_wait_time_future = None + self.partition_id = partition_id + + async def start(self): + async with self._lock: + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._running = True + if self._max_wait_time: + self._last_send_time = time.time() + self._check_max_wait_time_future = asyncio.ensure_future(self.check_max_wait_time_worker()) + + async def stop(self, flush=True, timeout_time=None, raise_error=False): + self._running = False + if flush: + async with self._lock: + await self._flush(timeout_time=timeout_time, raise_error=raise_error) + else: + if self._cur_buffered_len: + _LOGGER.warning( + "Shutting down Partition %r." 
" There are still %r events in the buffer which will be lost", + self.partition_id, + self._cur_buffered_len, + ) + if self._check_max_wait_time_future: + try: + await self._check_max_wait_time_future + except Exception as exc: # pylint: disable=broad-except + _LOGGER.warning("Partition %r stopped with error %r", self.partition_id, exc) + await self._producer.close() + + async def put_events(self, events, timeout_time=None): + # Put single event or EventDataBatch into the queue. + # This method would raise OperationTimeout if the queue does not have enough space for the input and + # flush cannot finish in timeout. + try: + new_events_len = len(events) + except TypeError: + new_events_len = 1 + + if self._max_buffer_len - self._cur_buffered_len < new_events_len: + _LOGGER.info( + "The buffer for partition %r is full. Attempting to flush before adding %r events.", + self.partition_id, + new_events_len, + ) + # flush the buffer + async with self._lock: + await self._flush(timeout_time=timeout_time) + + if timeout_time and time.time() > timeout_time: + raise OperationTimeoutError("Failed to enqueue events into buffer due to timeout.") + try: + # add single event into current batch + self._cur_batch.add(events) + except AttributeError: # if the input events is a EventDataBatch, put the whole into the buffer + # if there are events in cur_batch, enqueue cur_batch to the buffer + if self._cur_batch: + self._buffered_queue.put(self._cur_batch) + self._buffered_queue.put(events) + # create a new batch for incoming events + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + except ValueError: + # add single event exceeds the cur batch size, create new batch + self._buffered_queue.put(self._cur_batch) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch.add(events) + self._cur_buffered_len += new_events_len + + def failsafe_callback(self, callback): + async def wrapper_callback(*args, **kwargs): + try: + await callback(*args, **kwargs) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.warning( + "On partition %r, callback %r encountered exception %r", callback.__name__, exc, self.partition_id + ) + + return wrapper_callback + + async def flush(self, timeout_time=None, raise_error=True): + async with self._lock: + await self._flush(timeout_time, raise_error) + + async def _flush(self, timeout_time=None, raise_error=True): + # pylint: disable=protected-access + # try flushing all the buffered batch within given time + _LOGGER.info("Partition: %r started flushing.", self.partition_id) + if self._cur_batch: # if there is batch, enqueue it to the buffer first + self._buffered_queue.put(self._cur_batch) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + while self._cur_buffered_len: + remaining_time = timeout_time - time.time() if timeout_time else None + if (remaining_time and remaining_time > 0) or remaining_time is None: + batch = self._buffered_queue.get() + self._buffered_queue.task_done() + try: + _LOGGER.info("Partition %r is sending.", self.partition_id) + await self._producer.send(batch, timeout=timeout_time - time.time() if timeout_time else None) + _LOGGER.info("Partition %r sending %r events succeeded.", self.partition_id, len(batch)) + await self._on_success(batch._internal_events, self.partition_id) + except Exception as exc: # pylint: disable=broad-except + _LOGGER.info( + "Partition %r sending %r events failed due to exception: %r", self.partition_id, len(batch), exc + ) + await 
self._on_error(batch._internal_events, self.partition_id, exc) + finally: + self._cur_buffered_len -= len(batch) + # If flush could not get the semaphore, we log and raise error if wanted + else: + _LOGGER.info("Partition %r fails to flush due to timeout.", self.partition_id) + if raise_error: + raise OperationTimeoutError( + "Failed to flush {!r} within {}".format(self.partition_id, timeout_time) + ) + break + # after finishing flushing, reset cur batch and put it into the buffer + self._last_send_time = time.time() + self._cur_batch = EventDataBatch(self._max_message_size_on_link) + _LOGGER.info("Partition %r finished flushing.", self.partition_id) + + async def check_max_wait_time_worker(self): + while self._running: + if self._cur_buffered_len > 0: + now_time = time.time() + _LOGGER.info("Partition %r worker is checking max_wait_time.", self.partition_id) + # flush the partition if its beyond the waiting time or the buffer is at max capacity + if (now_time - self._last_send_time > self._max_wait_time) or ( + self._cur_buffered_len >= self._max_buffer_len + ): + # in the worker, not raising error for flush, users can not handle this + async with self._lock: + await self._flush(raise_error=False) + await asyncio.sleep(min(self._max_wait_time, 5)) + + @property + def buffered_event_count(self): + return self._cur_buffered_len diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py new file mode 100644 index 000000000000..25de2cc6f728 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py @@ -0,0 +1,140 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- +import asyncio +import logging +from typing import Dict, List, Callable, Optional, Awaitable, TYPE_CHECKING +from asyncio import Lock + +from ._partition_resolver_async import PartitionResolver +from ...aio._producer_async import EventHubProducer +from ._buffered_producer_async import BufferedProducer +from ...exceptions import EventDataSendError, ConnectError, EventHubError + +if TYPE_CHECKING: + from ..._producer_client import SendEventTypes + +_LOGGER = logging.getLogger(__name__) + + +class BufferedProducerDispatcher: + # pylint: disable=too-many-instance-attributes + def __init__( + self, + partitions: List[str], + on_success: Callable[["SendEventTypes", Optional[str]], Awaitable[None]], + on_error: Callable[["SendEventTypes", Optional[str], Exception], Awaitable[None]], + create_producer: Callable[..., EventHubProducer], + eventhub_name: str, + max_message_size_on_link: int, + *, + max_buffer_length: int = 1500, + max_wait_time: float = 1, + ): + self._buffered_producers: Dict[str, BufferedProducer] = {} + self._partition_ids: List[str] = partitions + self._lock = Lock() + self._on_success = on_success + self._on_error = on_error + self._create_producer = create_producer + self._eventhub_name = eventhub_name + self._max_message_size_on_link = max_message_size_on_link + self._partition_resolver = PartitionResolver(self._partition_ids) + self._max_wait_time = max_wait_time + self._max_buffer_length = max_buffer_length + + async def _get_partition_id(self, partition_id, partition_key): + if partition_id: + if partition_id not in self._partition_ids: + raise ConnectError( + "Invalid partition {} for the event hub {}".format( + partition_id, self._eventhub_name + ) + ) + return partition_id + if isinstance(partition_key, str): + return await self._partition_resolver.get_partition_id_by_partition_key(partition_key) + return await self._partition_resolver.get_next_partition_id() + + async def enqueue_events(self, events, *, partition_id=None, partition_key=None, timeout_time=None): + pid = await self._get_partition_id(partition_id, partition_key) + async with self._lock: + try: + await self._buffered_producers[pid].put_events(events, timeout_time) + except KeyError: + buffered_producer = BufferedProducer( + self._create_producer(partition_id=pid), + pid, + self._on_success, + self._on_error, + self._max_message_size_on_link, + max_wait_time=self._max_wait_time, + max_buffer_length=self._max_buffer_length + ) + await buffered_producer.start() + self._buffered_producers[pid] = buffered_producer + await buffered_producer.put_events(events, timeout_time) + + async def flush(self, timeout_time=None): + # flush all the buffered producer, the method will block until finishes or times out + async with self._lock: + futures = [] + for pid, producer in self._buffered_producers.items(): + # call each producer's flush method + futures.append((pid, asyncio.ensure_future(producer.flush(timeout_time=timeout_time)))) + + # gather results + exc_results = {} + for pid, future in futures: + try: + await future + except Exception as exc: # pylint: disable=broad-except + exc_results[pid] = exc + + if not exc_results: + _LOGGER.info("Flushing all partitions succeeded") + return + + _LOGGER.warning('Flushing all partitions partially failed with result %r.', exc_results) + raise EventDataSendError( + message="Flushing all partitions partially failed, failed partitions are {!r}" + " Exception details are 
{!r}".format(exc_results.keys(), exc_results) + ) + + async def close(self, *, flush=True, timeout_time=None, raise_error=False): + + async with self._lock: + + futures = [] + # stop all buffered producers + for pid, producer in self._buffered_producers.items(): + futures.append((pid, asyncio.ensure_future( + producer.stop(flush=flush, timeout_time=timeout_time, raise_error=raise_error) + ))) + + exc_results = {} + # gather results + for pid, future in futures: + try: + await future + except Exception as exc: # pylint: disable=broad-except + exc_results[pid] = exc + + if exc_results: + _LOGGER.warning('Stopping all partitions failed with result %r.', exc_results) + if raise_error: + raise EventHubError( + message="Stopping all partitions partially failed, failed partitions are {!r}" + " Exception details are {!r}".format(exc_results.keys(), exc_results) + ) + + def get_buffered_event_count(self, pid): + try: + return self._buffered_producers[pid].buffered_event_count + except KeyError: + return 0 + + @property + def total_buffered_event_count(self): + return sum([self.get_buffered_event_count(pid) for pid in self._buffered_producers]) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_partition_resolver_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_partition_resolver_async.py new file mode 100644 index 000000000000..5cd77a86145b --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_partition_resolver_async.py @@ -0,0 +1,31 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- +""" +jenkins-hash lookup3 algorithm implementation +""" + +from asyncio import Lock +from ..._buffered_producer._partition_resolver import generate_hash_code # pylint: disable=protected-access + + +class PartitionResolver: + def __init__(self, partitions): + self._idx = -1 + self._partitions = partitions + self._partitions_cnt = len(self._partitions) + self._lock = Lock() + + async def get_next_partition_id(self): + """ + round-robin partition assignment + """ + async with self._lock: + self._idx += 1 + self._idx %= self._partitions_cnt + return self._partitions[self._idx] + + async def get_partition_id_by_partition_key(self, partition_key): + hash_code = generate_hash_code(partition_key) + return self._partitions[abs(hash_code % self._partitions_cnt)] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 102116b82a91..622f9eea87ce 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -24,6 +24,7 @@ transform_outbound_single_message, ) from .._constants import TIMEOUT_SYMBOL +from ..amqp import AmqpAnnotatedMessage from ._client_base_async import ConsumerProducerMixin from ._async_utils import get_dict_with_loop_if_needed @@ -173,11 +174,11 @@ def _on_outcome( def _wrap_eventdata( self, - event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + event_data: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]], span: Optional[AbstractSpan], partition_key: Optional[AnyStr], ) -> Union[EventData, EventDataBatch]: - if isinstance(event_data, EventData): + if isinstance(event_data, (EventData, AmqpAnnotatedMessage)): outgoing_event_data = transform_outbound_single_message(event_data, EventData) if partition_key: set_message_partition_key(outgoing_event_data.message, partition_key) @@ -206,7 +207,7 @@ def _wrap_eventdata( async def send( self, - event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + event_data: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]], *, partition_key: Optional[AnyStr] = None, timeout: Optional[float] = None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index e628cf1817de..024a776d11e8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -4,14 +4,31 @@ # -------------------------------------------------------------------------------------------- import asyncio import logging - -from typing import Any, Union, TYPE_CHECKING, List, Optional, Dict, cast +import time + +from typing import ( + Any, + Union, + List, + Optional, + Dict, + Callable, + cast +) +from typing_extensions import ( + TYPE_CHECKING, + Literal, + Awaitable, + overload +) from uamqp import constants from ..exceptions import ConnectError, EventHubError from ..amqp import AmqpAnnotatedMessage from ._client_base_async import ClientBaseAsync from ._producer_async import EventHubProducer +from ._buffered_producer import BufferedProducerDispatcher +from .._utils import set_event_partition_key from .._constants import ALL_PARTITIONS from .._common import EventDataBatch, EventData @@ -24,7 +41,8 @@ _LOGGER = 
logging.getLogger(__name__) -class EventHubProducerClient(ClientBaseAsync): # pylint: disable=client-accepts-api-version-keyword +class EventHubProducerClient(ClientBaseAsync): # pylint: disable=client-accepts-api-version-keyword + # pylint: disable=too-many-instance-attributes """ The EventHubProducerClient class defines a high level interface for sending events to the Azure Event Hubs service. @@ -38,6 +56,36 @@ class EventHubProducerClient(ClientBaseAsync): # pylint: disable=client-accept generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. :type credential: ~azure.core.credentials_async.AsyncTokenCredential or ~azure.core.credentials.AzureSasCredential or ~azure.core.credentials.AzureNamedKeyCredential + :keyword bool buffered_mode: If True, the producer client will collect events in a buffer, efficiently batch, + then publish. Default is False. + :keyword on_success: The callback to be called once a batch has been successfully published. + The callback takes two parameters: + - `events`: The list of events that have been successfully published + - `partition_id`: The partition id that the events in the list have been published to. + The callback function should be defined like: `on_success(events, partition_id)`. + Required when `buffered_mode` is True while optional if `buffered_mode` is False. + :paramtype on_success: Optional[Callable[[SendEventTypes, Optional[str]], Awaitable[None]]] + :keyword on_error: The callback to be called once a batch has failed to be published. + Required when in `buffered_mode` is True while optional if `buffered_mode` is False. + The callback function should be defined like: `on_error(events, partition_id, error)`, where: + - `events`: The list of events that failed to be published, + - `partition_id`: The partition id that the events in the list have been tried to be published to and + - `error`: The exception related to the sending failure. + If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + :paramtype on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], Awaitable[None]]] + :keyword int max_buffer_length: Buffered mode only. + The total number of events per partition that can be buffered before a flush will be triggered. + The default value is 1500 in buffered mode. + :keyword Optional[float] max_wait_time: Buffered mode only. + The amount of time to wait for a batch to be built with events in the buffer before publishing. + The default value is 1 in buffered mode. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. :keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service. The default value is 60 seconds. If set to 0, no timeout will be enforced from the client. 
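A minimal sketch of how the buffered-mode keyword arguments documented above fit together for the async client, assuming the EVENT_HUB_CONN_STR and EVENT_HUB_NAME environment variables used by the samples added in this change; the callback bodies are placeholders, not prescribed behavior:

import asyncio
import os

from azure.eventhub import EventData
from azure.eventhub.aio import EventHubProducerClient


async def on_success(events, partition_id):
    # Called after a buffered batch has been successfully published.
    print("sent %d events to partition %s" % (len(events), partition_id))


async def on_error(events, partition_id, error):
    # Called when events fail to send after they were successfully enqueued.
    print("failed to send %d events to partition %s: %r" % (len(events), partition_id, error))


async def main():
    producer = EventHubProducerClient.from_connection_string(
        os.environ["EVENT_HUB_CONN_STR"],
        eventhub_name=os.environ["EVENT_HUB_NAME"],
        buffered_mode=True,
        on_success=on_success,      # required in buffered mode
        on_error=on_error,          # required in buffered mode
        max_buffer_length=1500,     # per-partition buffer size that triggers a flush
        max_wait_time=1,            # seconds to wait before publishing a partial batch
    )
    # Exiting the context manager calls close(flush=True), draining the buffer.
    async with producer:
        # send_event returns once the event is enqueued; batching happens in the background.
        await producer.send_event(EventData("hello"))
        # flush publishes everything currently buffered without waiting for max_wait_time.
        await producer.flush()

asyncio.run(main())
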
@@ -62,7 +110,7 @@ class EventHubProducerClient(ClientBaseAsync): # pylint: disable=client-accept If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could be used instead which uses port 443 for communication. :paramtype transport_type: ~azure.eventhub.TransportType - :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + :keyword dict http_proxy: HTTP proxy settings. This must be a dictionary with the following keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). Additionally the following keys may also be present: `'username', 'password'`. :keyword str custom_endpoint_address: The custom endpoint address to use for establishing a connection to @@ -84,12 +132,46 @@ class EventHubProducerClient(ClientBaseAsync): # pylint: disable=client-accept :caption: Create a new instance of the EventHubProducerClient. """ + @overload + def __init__( + self, + fully_qualified_namespace: str, + eventhub_name: str, + credential: "CredentialTypes", + *, + buffered_mode: Literal[False] = False, + **kwargs: Any + ) -> None: + ... + + @overload + def __init__( + self, + fully_qualified_namespace: str, + eventhub_name: str, + credential: "CredentialTypes", + *, + buffered_mode: Literal[True], + on_error: Callable[[SendEventTypes, Optional[str], Exception], Awaitable[None]], + on_success: Callable[[SendEventTypes, Optional[str]], Awaitable[None]], + max_buffer_length: int = 1500, + max_wait_time: float = 1, + **kwargs: Any + ) -> None: + ... + def __init__( self, fully_qualified_namespace: str, eventhub_name: str, credential: "CredentialTypes", - **kwargs + *, + buffered_mode: bool = False, + on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], Awaitable[None]]] = None, + on_success: Optional[Callable[[SendEventTypes, Optional[str]], Awaitable[None]]] = None, + max_buffer_length: Optional[int] = None, + max_wait_time: Optional[float] = None, + **kwargs: Any ) -> None: super(EventHubProducerClient, self).__init__( fully_qualified_namespace=fully_qualified_namespace, @@ -106,6 +188,31 @@ def __init__( ) # sync the creation of self._producers self._max_message_size_on_link = 0 self._partition_ids = None # Optional[List[str]] + self._buffered_mode = buffered_mode + self._on_success = on_success + self._on_error = on_error + self._buffered_producer_dispatcher = None + self._max_buffer_length = max_buffer_length + self._max_wait_time = max_wait_time + if self._buffered_mode: + setattr(self, "send_batch", self._buffered_send_batch) + setattr(self, "send_event", self._buffered_send_event) + if not self._on_error: + raise TypeError( + "EventHubProducerClient in buffered mode missing 1 required keyword argument: 'on_error'" + ) + if not self._on_success: + raise TypeError( + "EventHubProducerClient in buffered mode missing 1 required keyword argument: 'on_success'" + ) + if self._max_wait_time is None: + self._max_wait_time = 1 + if self._max_wait_time <= 0: + raise ValueError("'max_wait_time' must be a float greater than 0 in buffered mode") + if self._max_buffer_length is None: + self._max_buffer_length = 1500 + if self._max_buffer_length <= 0: + raise ValueError("'max_buffer_length' must be an integer greater than 0 in buffered mode") async def __aenter__(self): return self @@ -113,13 +220,79 @@ async def __aenter__(self): async def __aexit__(self, *args): await self.close() + async def _buffered_send(self, events, **kwargs): + try: + await 
self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) + except AttributeError: + await self._get_partitions() + await self._get_max_message_size() + self._buffered_producer_dispatcher = BufferedProducerDispatcher( + self._partition_ids, + self._on_success, + self._on_error, + self._create_producer, + self.eventhub_name, + self._max_message_size_on_link, + max_wait_time=self._max_wait_time, + max_buffer_length=self._max_buffer_length + ) + await self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) + + async def _batch_preparer(self, event_data_batch, **kwargs): + partition_id = kwargs.pop("partition_id", None) + partition_key = kwargs.pop("partition_key", None) + + if isinstance(event_data_batch, EventDataBatch): + if partition_id or partition_key: + raise TypeError( + "partition_id and partition_key should be None when sending an EventDataBatch " + "because type EventDataBatch itself may have partition_id or partition_key" + ) + to_send_batch = event_data_batch + else: + to_send_batch = await self.create_batch( + partition_id=partition_id, partition_key=partition_key + ) + to_send_batch._load_events( # pylint:disable=protected-access + event_data_batch + ) + + return to_send_batch, to_send_batch._partition_id, partition_key # pylint:disable=protected-access + + async def _buffered_send_batch(self, event_data_batch, **kwargs): + batch, pid, pkey = await self._batch_preparer(event_data_batch, **kwargs) + + if len(batch) == 0: + return + + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + await self._buffered_send( + event_data_batch, + partition_id=pid, + partition_key=pkey, + timeout_time=timeout_time + ) + + async def _buffered_send_event(self, event, **kwargs): + partition_key = kwargs.get("partition_key") + set_event_partition_key(event, partition_key) + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + await self._buffered_send( + event, + partition_id=kwargs.get("partition_id"), + partition_key=partition_key, + timeout_time=timeout_time + ) + async def _get_partitions(self) -> None: if not self._partition_ids: self._partition_ids = await self.get_partition_ids() # type: ignore for p_id in cast(List[str], self._partition_ids): self._producers[p_id] = None - async def _get_max_mesage_size(self) -> None: + async def _get_max_message_size(self) -> None: # pylint: disable=protected-access,line-too-long async with self._lock: if not self._max_message_size_on_link: @@ -181,11 +354,44 @@ def _create_producer( return handler @classmethod + @overload + def from_connection_string( + cls, + conn_str: str, + *, + eventhub_name: Optional[str] = None, + buffered_mode: Literal[False] = False, + **kwargs: Any + ) -> "EventHubProducerClient": + ... + + @classmethod + @overload def from_connection_string( cls, conn_str: str, *, eventhub_name: Optional[str] = None, + buffered_mode: Literal[True], + on_error: Callable[[SendEventTypes, Optional[str], Exception], Awaitable[None]], + on_success: Callable[[SendEventTypes, Optional[str]], Awaitable[None]], + max_buffer_length: int = 1500, + max_wait_time: float = 1, + **kwargs: Any + ) -> "EventHubProducerClient": + ... 
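# A minimal standalone sketch of the typing.overload + Literal pattern used by the
# from_connection_string overloads above, with hypothetical names (FakeClient, make_client):
# a static type checker treats the buffered-mode callbacks as required only when
# buffered_mode=True, while a single runtime implementation backs both overloads.
from typing import Any, Callable, Optional, overload

from typing_extensions import Literal


class FakeClient:
    """Stand-in for the real client type in this sketch."""


@overload
def make_client(*, buffered_mode: Literal[False] = False, **kwargs: Any) -> FakeClient:
    ...
@overload
def make_client(
    *,
    buffered_mode: Literal[True],
    on_error: Callable[..., Any],
    on_success: Callable[..., Any],
    **kwargs: Any
) -> FakeClient:
    ...
def make_client(
    *,
    buffered_mode: bool = False,
    on_error: Optional[Callable[..., Any]] = None,
    on_success: Optional[Callable[..., Any]] = None,
    **kwargs: Any
) -> FakeClient:
    # Runtime check mirroring the constructor: callbacks are required only in buffered mode.
    if buffered_mode and (on_error is None or on_success is None):
        raise TypeError("on_error and on_success are required when buffered_mode=True")
    return FakeClient()
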
+ + @classmethod + def from_connection_string( + cls, + conn_str: str, + *, + eventhub_name: Optional[str] = None, + buffered_mode: bool = False, + on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], Awaitable[None]]] = None, + on_success: Optional[Callable[[SendEventTypes, Optional[str]], Awaitable[None]]] = None, + max_buffer_length: Optional[int] = None, + max_wait_time: Optional[float] = None, logging_enable: bool = False, http_proxy: Optional[Dict[str, Union[str, int]]] = None, auth_timeout: float = 60, @@ -198,8 +404,37 @@ def from_connection_string( :param str conn_str: The connection string of an Event Hub. :keyword str eventhub_name: The path of the specific Event Hub to connect the client to. + :keyword bool buffered_mode: If True, the producer client will collect events in a buffer, efficiently batch, + then publish. Default is False. + :keyword on_success: The callback to be called once a batch has been successfully published. + The callback takes two parameters: + - `events`: The list of events that have been successfully published + - `partition_id`: The partition id that the events in the list have been published to. + The callback function should be defined like: `on_success(events, partition_id)`. + It is required when `buffered_mode` is True while optional if `buffered_mode` is False. + :paramtype on_success: Optional[Callable[[SendEventTypes, Optional[str]], Awaitable[None]]] + :keyword on_error: The callback to be called once a batch has failed to be published. + The callback function should be defined like: `on_error(events, partition_id, error)`, where: + - `events`: The list of events that failed to be published, + - `partition_id`: The partition id that the events in the list have been tried to be published to and + - `error`: The exception related to the sending failure. + If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + :paramtype on_error: Optional[Callable[[SendEventTypes, Optional[str], Exception], Awaitable[None]]] + :keyword int max_buffer_length: Buffered mode only. + The total number of events per partition that can be buffered before a flush will be triggered. + The default value is 1500 in buffered mode. + :keyword Optional[float] max_wait_time: Buffered mode only. + The amount of time to wait for a batch to be built with events in the buffer before publishing. + The default value is 1 in buffered mode. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. - :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + :keyword dict http_proxy: HTTP proxy settings. This must be a dictionary with the following keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). Additionally the following keys may also be present: `'username', 'password'`. 
:keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service. @@ -247,6 +482,11 @@ def from_connection_string( constructor_args = cls._from_connection_string( conn_str, eventhub_name=eventhub_name, + buffered_mode=buffered_mode, + on_success=on_success, + on_error=on_error, + max_buffer_length=max_buffer_length, + max_wait_time=max_wait_time, logging_enable=logging_enable, http_proxy=http_proxy, auth_timeout=auth_timeout, @@ -257,26 +497,116 @@ def from_connection_string( ) return cls(**constructor_args) + async def send_event( + self, + event_data: Union[EventData, AmqpAnnotatedMessage], + **kwargs: Any + ) -> None: + """ + Sends an event data. + By default, the method will block until acknowledgement is received or operation times out. + If the `EventHubProducerClient` is configured to run in buffered mode, the method will enqueue the event + into local buffer and return. The producer will do automatic batching and sending in the background. + + If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + + :param event_data: The `EventData` object to be sent. + :type event_data: Union[~azure.eventhub.EventData, ~azure.eventhub.amqp.AmqpAnnotatedMessage] + :keyword float timeout: The maximum wait time to send the event data in non-buffered mode or the + maximum wait time to enqueue the event data into the buffer in buffered mode. + In non-buffered mode, the default wait time specified when the producer + was created will be used. In buffered mode, the default wait time is None. + :keyword str partition_id: The specific partition ID to send to. Default is None, in which case the service + will assign to all partitions using round-robin. + A `TypeError` will be raised if partition_id is specified and event_data_batch is an `EventDataBatch` because + `EventDataBatch` itself has partition_id. + :keyword str partition_key: With the given partition_key, event data will be sent to + a particular partition of the Event Hub decided by the service. + A `TypeError` will be raised if partition_key is specified and event_data_batch is an `EventDataBatch` because + `EventDataBatch` itself has partition_key. + If both partition_id and partition_key are provided, the partition_id will take precedence. + **WARNING: Setting partition_key of non-string value on the events to be sent is discouraged + as the partition_key will be ignored by the Event Hub service and events will be assigned + to all partitions using round-robin. 
Furthermore, there are SDKs for consuming events which expect + partition_key to only be string type, they might fail to parse the non-string value.** + :rtype: None + :raises: :class:`AuthenticationError` + :class:`ConnectError` + :class:`ConnectionLostError` + :class:`EventDataError` + :class:`EventDataSendError` + :class:`EventHubError` + :raises OperationTimeoutError: If the value specified by the timeout parameter elapses before the event can be + sent in non-buffered mode or the events can not be enqueued into the buffered in buffered mode. + """ + input_pid = kwargs.get("partition_id") + pid = input_pid or ALL_PARTITIONS + partition_key = kwargs.get("partition_key") + timeout = kwargs.get("timeout") + try: + try: + await cast(EventHubProducer, self._producers[pid]).send( + event_data, partition_key=partition_key, timeout=timeout + ) + except (KeyError, AttributeError, EventHubError): + await self._start_producer(pid, timeout) + await cast(EventHubProducer, self._producers[pid]).send( + event_data, partition_key=partition_key, timeout=timeout + ) + if self._on_success: + await self._on_success([event_data], input_pid) + except Exception as exc: # pylint: disable=broad-except + if self._on_error: + await self._on_error([event_data], input_pid, exc) + else: + raise + async def send_batch( self, event_data_batch: Union[EventDataBatch, SendEventTypes], - *, - timeout: Optional[Union[int, float]] = None, - **kwargs + **kwargs: Any ) -> None: - """Sends event data and blocks until acknowledgement is received or operation times out. - - If you're sending a finite list of `EventData` or `AmqpAnnotatedMessage` and you know it's within the event hub - frame size limit, you can send them with a `send_batch` call. Otherwise, use :meth:`create_batch` + # pylint: disable=protected-access + """ + Sends a batch of event data. + By default, the method will block until acknowledgement is received or operation times out. + If the `EventHubProducerClient` is configured to run in buffered mode, the method will enqueue the events + into local buffer and return. The producer will do automatic sending in the background. + + If `buffered_mode` is False, `on_error` callback is optional and errors will be handled as follows: + - If an `on_error` callback is passed during the producer client instantiation, + then error information will be passed to the `on_error` callback, which will then be called. + - If an `on_error` callback is not passed in during client instantiation, + then the error will be raised by default. + + If `buffered_mode` is True, `on_error` callback is required and errors will be handled as follows: + - If events fail to enqueue within the given timeout, then an error will be directly raised. + - If events fail to send after enqueuing successfully, the `on_error` callback will be called. + + In buffered mode, sending a batch will remain intact and sent as a single unit. + The batch will not be rearranged. This may result in inefficiency of sending events. + + If you're sending a finite list of `EventData` or `AmqpAnnotatedMessage` and you know it's within the + event hub frame size limit, you can send them with a `send_batch` call. Otherwise, use :meth:`create_batch` to create `EventDataBatch` and add either `EventData` or `AmqpAnnotatedMessage` into the batch one by one until the size limit, and then call this method to send out the batch. - :param event_data_batch: The `EventDataBatch` object to be sent or a list of `EventData` to be sent - in a batch. 
All `EventData` in the list or `EventDataBatch` will land on the same partition. + :param event_data_batch: The `EventDataBatch` object to be sent or a list of `EventData` to be sent in a batch. + All `EventData` or `AmqpAnnotatedMessage` in the list or `EventDataBatch` will land on the same partition. :type event_data_batch: Union[~azure.eventhub.EventDataBatch, List[Union[~azure.eventhub.EventData, ~azure.eventhub.amqp.AmqpAnnotatedMessage]] - :keyword float timeout: The maximum wait time to send the event data. - If not specified, the default wait time specified when the producer was created will be used. + :keyword float timeout: The maximum wait time to send the event data in non-buffered mode or the + maximum wait time to enqueue the event data into the buffer in buffered mode. + In non-buffered mode, the default wait time specified when the producer + was created will be used. In buffered mode, the default wait time is None. :keyword str partition_id: The specific partition ID to send to. Default is None, in which case the service will assign to all partitions using round-robin. A `TypeError` will be raised if partition_id is specified and event_data_batch is an `EventDataBatch` because @@ -299,6 +629,8 @@ async def send_batch( :class:`EventHubError` :class:`ValueError` :class:`TypeError` + :raises OperationTimeoutError: If the value specified by the timeout parameter elapses before the event can be + sent in non-buffered mode or the events can be enqueued into the buffered in buffered mode. .. admonition:: Example: @@ -310,40 +642,33 @@ async def send_batch( :caption: Asynchronously sends event data """ - partition_id = kwargs.get("partition_id") - partition_key = kwargs.get("partition_key") - - if isinstance(event_data_batch, EventDataBatch): - if partition_id or partition_key: - raise TypeError( - "partition_id and partition_key should be None when sending an EventDataBatch " - "because type EventDataBatch itself may have partition_id or partition_key" - ) - to_send_batch = event_data_batch - else: - to_send_batch = await self.create_batch( - partition_id=partition_id, partition_key=partition_key - ) - to_send_batch._load_events( # pylint:disable=protected-access - event_data_batch - ) + batch, pid, pkey = await self._batch_preparer(event_data_batch, **kwargs) - if len(to_send_batch) == 0: + if len(batch) == 0: return - partition_id = ( - to_send_batch._partition_id # pylint:disable=protected-access - or ALL_PARTITIONS - ) + partition_id = pid or ALL_PARTITIONS + timeout = kwargs.pop("timeout", None) + try: - await cast(EventHubProducer, self._producers[partition_id]).send( - to_send_batch, timeout=timeout - ) - except (KeyError, AttributeError, EventHubError): - await self._start_producer(partition_id, timeout) - await cast(EventHubProducer, self._producers[partition_id]).send( - to_send_batch, timeout=timeout - ) + try: + await cast(EventHubProducer, self._producers[partition_id]).send( + batch, partition_key=pkey, timeout=timeout + ) + if self._on_success: + await self._on_success(batch._internal_events, pid) + except (KeyError, AttributeError, EventHubError): + await self._start_producer(partition_id, timeout) + await cast(EventHubProducer, self._producers[partition_id]).send( + batch, partition_key=pkey, timeout=timeout + ) + if self._on_success: + await self._on_success(batch._internal_events, pid) + except Exception as exc: # pylint: disable=broad-except + if self._on_error: + await self._on_error(batch._internal_events, pid, exc) + else: + raise async def create_batch( 
self, @@ -380,7 +705,7 @@ async def create_batch( """ if not self._max_message_size_on_link: - await self._get_max_mesage_size() + await self._get_max_message_size() if max_size_in_bytes and max_size_in_bytes > self._max_message_size_on_link: raise ValueError( @@ -443,10 +768,37 @@ async def get_partition_properties(self, partition_id: str) -> Dict[str, Any]: EventHubProducerClient, self )._get_partition_properties_async(partition_id) - async def close(self) -> None: + async def flush(self, **kwargs: Any) -> None: + """ + Buffered mode only. + Flush events in the buffer to be sent immediately if the client is working in buffered mode. + + :keyword Optional[float] timeout: Timeout to flush the buffered events, default is None which means no timeout. + :rtype: None + :raises EventDataSendError: If the producer fails to flush the buffer within the given timeout + in buffered mode. + """ + async with self._lock: + if self._buffered_mode and self._buffered_producer_dispatcher: + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + await self._buffered_producer_dispatcher.flush(timeout_time=timeout_time) + + async def close( + self, + *, + flush: bool = True, + **kwargs: Any + ) -> None: """Close the Producer client underlying AMQP connection and links. + :keyword bool flush: Buffered mode only. If set to True, events in the buffer will be sent + immediately. Default is True. + :keyword Optional[float] timeout: Buffered mode only. Timeout to close the producer. + Default is None which means no timeout. :rtype: None + :raises EventHubError: If an error occurred when flushing the buffer if `flush` is set to True or closing the + underlying AMQP connections in buffered mode. .. admonition:: Example: @@ -459,9 +811,57 @@ async def close(self) -> None: """ async with self._lock: + if self._buffered_mode and self._buffered_producer_dispatcher: + timeout = kwargs.get("timeout") + timeout_time = time.time() + timeout if timeout else None + await self._buffered_producer_dispatcher.close( + flush=flush, + timeout_time=timeout_time, + raise_error=True + ) + self._buffered_producer_dispatcher = None + for pid in self._producers: if self._producers[pid] is not None: await self._producers[pid].close() # type: ignore self._producers[pid] = None await super(EventHubProducerClient, self)._close_async() + + def get_buffered_event_count(self, partition_id: str) -> Optional[int]: + """ + The number of events that are buffered and waiting to be published for a given partition. + Returns None in non-buffered mode. + + :param str partition_id: The target partition ID. + :rtype: int or None + """ + if not self._buffered_mode: + return None + + try: + return cast( + BufferedProducerDispatcher, + self._buffered_producer_dispatcher + ).get_buffered_event_count(partition_id) + except AttributeError: + return 0 + + @property + def total_buffered_event_count(self) -> Optional[int]: + """ + The total number of events that are currently buffered and waiting to be published, across all partitions. + Returns None in non-buffered mode. 
+ + :rtype: int or None + """ + if not self._buffered_mode: + return None + + try: + return cast( + BufferedProducerDispatcher, + self._buffered_producer_dispatcher + ).total_buffered_event_count + except AttributeError: + return 0 diff --git a/sdk/eventhub/azure-eventhub/samples/README.md b/sdk/eventhub/azure-eventhub/samples/README.md index bc1be0ed6652..33f613cc12e9 100644 --- a/sdk/eventhub/azure-eventhub/samples/README.md +++ b/sdk/eventhub/azure-eventhub/samples/README.md @@ -80,6 +80,11 @@ Both [sync version](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/ - Send AMQPAnnotatedMessage of different body types. - Receive messages and parse the body according to the body type. +- [send_buffered_mode.py](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/eventhub/azure-eventhub/samples/sync_samples/send_buffered_mode.py) ([async_version](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/eventhub/azure-eventhub/samples/async_samples/send_buffered_mode_async.py)) - Examples to send events in buffered mode: + - Send single events, which will be automatically batched. + - Send a batch of events by enqueuing an EventDataBatch object to the buffer. + - Send events in buffer immediately by calling `flush`. + ## Prerequisites - Python 3.6 or later. - **Microsoft Azure Subscription:** To use Azure services, including Azure Event Hubs, you'll need a subscription. diff --git a/sdk/eventhub/azure-eventhub/samples/async_samples/send_buffered_mode_async.py b/sdk/eventhub/azure-eventhub/samples/async_samples/send_buffered_mode_async.py new file mode 100644 index 000000000000..33a5f4f4ad97 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/samples/async_samples/send_buffered_mode_async.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Examples to show sending events in buffered mode to an Event Hub asynchronously. 
+""" + +import time +import asyncio +import os + +from azure.eventhub.aio import EventHubProducerClient +from azure.eventhub import EventData + +CONNECTION_STR = os.environ['EVENT_HUB_CONN_STR'] +EVENTHUB_NAME = os.environ['EVENT_HUB_NAME'] + + +async def on_success(events, pid): + # sending succeeded + print(events, pid) + + +async def on_error(events, pid, error): + # sending failed + print(events, pid, error) + + +async def run(): + + producer = EventHubProducerClient.from_connection_string( + conn_str=CONNECTION_STR, + eventhub_name=EVENTHUB_NAME, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + # exiting the context manager will automatically call flush + async with producer: + # single events will be batched automatically + for i in range(10): + # the method returning indicates the event has been enqueued to the buffer + await producer.send_event(EventData('Single data {}'.format(i))) + + batch = await producer.create_batch() + for i in range(10): + batch.add(EventData('Single data in batch {}'.format(i))) + # alternatively, you can enqueue an EventDataBatch object to the buffer + await producer.send_batch(batch) + + # calling flush sends out the events in the buffer immediately + await producer.flush() + +start_time = time.time() +asyncio.run(run()) +print("Send messages in {} seconds.".format(time.time() - start_time)) diff --git a/sdk/eventhub/azure-eventhub/samples/sync_samples/send_buffered_mode.py b/sdk/eventhub/azure-eventhub/samples/sync_samples/send_buffered_mode.py new file mode 100644 index 000000000000..2939e3445b76 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/samples/sync_samples/send_buffered_mode.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python + +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Examples to show sending events in buffered mode to an Event Hub. 
+""" + +import time +import os +from azure.eventhub import EventHubProducerClient, EventData + +CONNECTION_STR = os.environ['EVENT_HUB_CONN_STR'] +EVENTHUB_NAME = os.environ['EVENT_HUB_NAME'] + + +def on_success(events, pid): + # sending succeeded + print(events, pid) + + +def on_error(events, pid, error): + # sending failed + print(events, pid, error) + + +producer = EventHubProducerClient.from_connection_string( + conn_str=CONNECTION_STR, + eventhub_name=EVENTHUB_NAME, + buffered_mode=True, + on_success=on_success, + on_error=on_error +) + +start_time = time.time() + +# exiting the context manager will automatically call flush +with producer: + # single events will be batched automatically + for i in range(10): + # the method returning indicates the event has been enqueued to the buffer + producer.send_event(EventData('Single data {}'.format(i))) + + batch = producer.create_batch() + for i in range(10): + batch.add(EventData('Single data in batch {}'.format(i))) + # alternatively, you can enqueue an EventDataBatch object to the buffer + producer.send_batch(batch) + + # calling flush sends out the events in the buffer immediately + producer.flush() + +print("Send messages in {} seconds.".format(time.time() - start_time)) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py new file mode 100644 index 000000000000..1cdfe0b62319 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py @@ -0,0 +1,492 @@ +#!/usr/bin/env python + +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- +import asyncio +from collections import defaultdict +from uuid import uuid4 + +import pytest + +from azure.eventhub import EventData +from azure.eventhub.aio import EventHubProducerClient, EventHubConsumerClient +from azure.eventhub.aio._buffered_producer import PartitionResolver +from azure.eventhub.amqp import ( + AmqpAnnotatedMessage, +) +from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError, EventHubError + + +async def random_pkey_generation(partitions): + pr = PartitionResolver(partitions) + total = len(partitions) + dic = {} + + while total: + key = str(uuid4()) + pid = await pr.get_partition_id_by_partition_key(key) + if pid in dic: + continue + else: + dic[pid] = key + total -= 1 + + return dic + + +@pytest.mark.liveTest() +@pytest.mark.asyncio +async def test_producer_client_constructor(connection_str): + async def on_success(events, pid): + pass + + async def on_error(events, error, pid): + pass + with pytest.raises(TypeError): + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True) + with pytest.raises(TypeError): + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_success=on_success) + with pytest.raises(TypeError): + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_error=on_error) + with pytest.raises(ValueError): + EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error, + max_wait_time=0 + ) + with pytest.raises(ValueError): + EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error, + max_buffer_length=0 + ) + + +@pytest.mark.liveTest +@pytest.mark.asyncio +@pytest.mark.parametrize( + "flush_after_sending, close_after_sending", + [ + (False, False), + (True, False), + (False, True) + ] +) +async def test_basic_send_single_events_round_robin(connection_str, flush_after_sending, close_after_sending): + received_events = defaultdict(list) + + async def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = asyncio.ensure_future(consumer.receive(on_event=on_event)) + + sent_events = defaultdict(list) + + async def on_success(events, pid): + if len(events) > 1: + on_success.batching = True + sent_events[pid].extend(events) + + async def on_error(events, pid, err): + on_error.err = err + + on_error.err = None # ensure no error + on_success.batching = False # ensure batching happened + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + async with producer: + partitions = await producer.get_partition_ids() + partitions_cnt = len(partitions) + # perform single sending round-robin + total_single_event_cnt = 100 + eventdata_set, amqpannoated_set = set(), set() + for i in range(total_single_event_cnt // 2): + event = EventData("test:{}".format(i)) + event.properties = {"event_idx": i} + await producer.send_event(event) + eventdata_set.add(i) + for i in range(total_single_event_cnt // 2, total_single_event_cnt): + event = AmqpAnnotatedMessage(data_body="test:{}".format(i)) + event.application_properties = {"event_idx": i} + amqpannoated_set.add(i) + await 
producer.send_event(event) + + for pid in partitions: + assert producer.get_buffered_event_count(pid) > 0 + assert producer.total_buffered_event_count > 0 + + if not flush_after_sending and not close_after_sending: + # ensure it's buffered sending + for pid in partitions: + assert len(sent_events[pid]) < total_single_event_cnt // partitions_cnt + assert sum([len(sent_events[pid]) for pid in partitions]) < total_single_event_cnt + else: + if flush_after_sending: + await producer.flush() + if close_after_sending: + await producer.close() + # ensure all events are sent after calling flush + assert sum([len(sent_events[pid]) for pid in partitions]) == total_single_event_cnt + + # give some time for producer to complete sending and consumer to complete receiving + await asyncio.sleep(10) + assert len(sent_events) == len(received_events) == partitions_cnt + + for pid in partitions: + assert producer.get_buffered_event_count(pid) == 0 + assert producer.total_buffered_event_count == 0 + assert not on_error.err + + # ensure all events are received in the correct partition + for pid in partitions: + assert len(sent_events[pid]) >= total_single_event_cnt // partitions_cnt + assert len(sent_events[pid]) == len(received_events[pid]) + for i in range(len(sent_events[pid])): + event = sent_events[pid][i] + try: # amqp annotated message + event_idx = event.application_properties["event_idx"] + amqpannoated_set.remove(event_idx) + except AttributeError: # event data + event_idx = event.properties["event_idx"] + eventdata_set.remove(event_idx) + assert received_events[pid][i].properties[b"event_idx"] == event_idx + assert partitions[event_idx % partitions_cnt] == pid + + assert on_success.batching + assert not eventdata_set + assert not amqpannoated_set + + await consumer.close() + await receive_thread + + +@pytest.mark.liveTest +@pytest.mark.asyncio +@pytest.mark.parametrize( + "flush_after_sending, close_after_sending", + [ + (True, False), + (False, True), + (False, False) + ] +) +async def test_basic_send_batch_events_round_robin(connection_str, flush_after_sending, close_after_sending): + received_events = defaultdict(list) + + async def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = asyncio.ensure_future(consumer.receive(on_event=on_event)) + + sent_events = defaultdict(list) + + async def on_success(events, pid): + sent_events[pid].extend(events) + + async def on_error(events, pid, err): + on_error.err = err + + on_error.err = None + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + async with producer: + partitions = await producer.get_partition_ids() + partitions_cnt = len(partitions) + # perform batch sending round-robin + total_events_cnt = 100 + batch_cnt = partitions_cnt * 2 - 1 + each_partition_cnt = total_events_cnt // batch_cnt + remain_events = total_events_cnt % batch_cnt + batches = [] + event_idx = 0 + eventdata_set, amqpannoated_set = set(), set() + for i in range(batch_cnt): + batch = await producer.create_batch() + for j in range(each_partition_cnt // 2): + event = EventData("test{}:{}".format(i, event_idx)) + event.properties = {'batch_idx': i, 'event_idx': event_idx} + batch.add(event) + eventdata_set.add(event_idx) + event_idx += 1 + for j in range(each_partition_cnt // 2, each_partition_cnt): + event = 
AmqpAnnotatedMessage(data_body="test{}:{}".format(i, event_idx)) + event.application_properties = {'batch_idx': i, 'event_idx': event_idx} + batch.add(event) + amqpannoated_set.add(event_idx) + event_idx += 1 + batches.append(batch) + + # put remain_events in the last batch + last_batch = await producer.create_batch() + for i in range(remain_events): + event = EventData("test:{}:{}".format(len(batches), event_idx)) + event.properties = {'batch_idx': len(batches), 'event_idx': event_idx} + last_batch.add(event) + eventdata_set.add(event_idx) + event_idx += 1 + batches.append(last_batch) + + for batch in batches: + await producer.send_batch(batch) + + if not flush_after_sending and not close_after_sending: + # ensure it's buffered sending + for pid in partitions: + assert len(sent_events[pid]) < each_partition_cnt + assert sum([len(sent_events[pid]) for pid in partitions]) < total_events_cnt + # give some time for producer to complete sending and consumer to complete receiving + else: + if flush_after_sending: + await producer.flush() + if close_after_sending: + await producer.close() + # ensure all events are sent + assert sum([len(sent_events[pid]) for pid in partitions]) == total_events_cnt + + await asyncio.sleep(10) + assert len(sent_events) == len(received_events) == partitions_cnt + + # ensure all events are received in the correct partition + for pid in partitions: + assert len(sent_events[pid]) > 0 + assert len(sent_events[pid]) == len(received_events[pid]) + for i in range(len(sent_events[pid])): + event = sent_events[pid][i] + try: # amqp annotated message + event_idx = event.application_properties["event_idx"] + amqpannoated_set.remove(event_idx) + except AttributeError: # event data + event_idx = event.properties["event_idx"] + eventdata_set.remove(event_idx) + assert received_events[pid][i].properties[b"event_idx"] == event_idx + + assert not amqpannoated_set + assert not eventdata_set + assert not on_error.err + + await consumer.close() + await receive_thread + + +@pytest.mark.liveTest +@pytest.mark.asyncio +async def test_send_with_hybrid_partition_assignment(connection_str): + received_events = defaultdict(list) + + async def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = asyncio.ensure_future(consumer.receive(on_event=on_event)) + + sent_events = defaultdict(list) + + async def on_success(events, pid): + sent_events[pid].extend(events) + + async def on_error(events, pid, err): + on_error.err = err + + on_error.err = None + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + async with producer: + partitions = await producer.get_partition_ids() + partitions_cnt = len(partitions) + pid_to_pkey = await random_pkey_generation(partitions) + expected_event_idx_to_partition = {} + event_idx = 0 + # 1. send by partition_key, each partition 2 events, two single + one batch containing two + for pid in partitions: + pkey = pid_to_pkey[pid] + await producer.send_event(EventData('{}'.format(event_idx)), partition_key=pkey) + batch = await producer.create_batch(partition_key=pkey) + batch.add(EventData('{}'.format(event_idx + 1))) + await producer.send_batch(batch) + for i in range(2): + expected_event_idx_to_partition[event_idx + i] = pid + event_idx += 2 + + # 2. 
send by partition_id, each partition 2 events, two single + one batch containing two + for pid in partitions: + await producer.send_event(EventData('{}'.format(event_idx)), partition_id=pid) + batch = await producer.create_batch(partition_id=pid) + batch.add(EventData('{}'.format(event_idx + 1))) + await producer.send_batch(batch) + for i in range(2): + expected_event_idx_to_partition[event_idx + i] = pid + event_idx += 2 + + # 3. send without partition, each partition 2 events, two single + one batch containing two + for _ in partitions: + await producer.send_event(EventData('{}'.format(event_idx))) + batch = await producer.create_batch() + batch.add(EventData('{}'.format(event_idx + 1))) + await producer.send_batch(batch) + event_idx += 2 + + await producer.flush() + assert len(sent_events) == partitions_cnt + + await asyncio.sleep(10) + + visited = set() + for pid in partitions: + assert len(sent_events[pid]) == 2 * 3 + + for sent_event in sent_events[pid]: + if int(sent_event.body_as_str()) in expected_event_idx_to_partition: + assert expected_event_idx_to_partition[int(sent_event.body_as_str())] == pid + + for recv_event in received_events[pid]: + if int(sent_event.body_as_str()) in expected_event_idx_to_partition: + assert expected_event_idx_to_partition[int(sent_event.body_as_str())] == pid + + assert recv_event.body_as_str() not in visited + visited.add(recv_event.body_as_str()) + + assert len(visited) == 2 * 3 * len(partitions) + + assert not on_error.err + await consumer.close() + await receive_thread + + +@pytest.mark.liveTest +@pytest.mark.asyncio +async def test_send_with_timing_configuration(connection_str): + received_events = defaultdict(list) + + async def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = asyncio.ensure_future(consumer.receive(on_event=on_event)) + + sent_events = defaultdict(list) + + async def on_success(events, pid): + sent_events[pid].extend(events) + + async def on_error(events, pid, err): + on_error.err = err + + on_error.err = None + + # test max_wait_time + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + max_wait_time=10, + on_success=on_success, + on_error=on_error + ) + + async with producer: + partitions = await producer.get_partition_ids() + await producer.send_event(EventData('data')) + await asyncio.sleep(5) + assert not sent_events + await asyncio.sleep(10) + assert sum([len(sent_events[pid]) for pid in partitions]) == 1 + + assert not on_error.err + + # test max_buffer_length per partition + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + max_wait_time=1000, + max_buffer_length=10, + on_success=on_success, + on_error=on_error + ) + + sent_events.clear() + received_events.clear() + async with producer: + partitions = await producer.get_partition_ids() + for i in range(7): + await producer.send_event(EventData('data'), partition_id="0") + assert not sent_events + batch = await producer.create_batch(partition_id="0") + for i in range(9): + batch.add(EventData('9')) + await producer.send_batch(batch) # will flush 7 events and put the batch in buffer + assert sum([len(sent_events[pid]) for pid in partitions]) == 7 + for i in range(5): + await producer.send_event(EventData('data'), partition_id="0") # will flush batch (9 events) + 1 event, leaving 4 in buffer + assert 
sum([len(sent_events[pid]) for pid in partitions]) == 17 + await producer.flush() + assert sum([len(sent_events[pid]) for pid in partitions]) == 21 + + await asyncio.sleep(5) + assert sum([len(received_events[pid]) for pid in partitions]) == 21 + assert not on_error.err + await consumer.close() + await receive_thread + + +@pytest.mark.liveTest +@pytest.mark.asyncio +async def test_long_sleep(connection_str): + received_events = defaultdict(list) + + async def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + + receive_thread = asyncio.ensure_future(consumer.receive(on_event=on_event)) + + sent_events = defaultdict(list) + + async def on_success(events, pid): + sent_events[pid].extend(events) + + async def on_error(events, pid, err): + on_error.err = err + + on_error.err = None # ensure no error + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + async with producer: + await producer.send_event(EventData("test"), partition_id="0") + await asyncio.sleep(220) + await producer.send_event(EventData("test"), partition_id="0") + await asyncio.sleep(5) + + assert not on_error.err + assert len(sent_events["0"]) == 2 + assert len(received_events["0"]) == 2 + + await consumer.close() + await receive_thread diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_negative_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_negative_async.py index eb1fc0beb663..4d5880d9d2dc 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_negative_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_negative_async.py @@ -35,6 +35,26 @@ async def test_send_with_invalid_hostname_async(invalid_hostname, connstr_receiv batch.add(EventData("test data")) await client.send_batch(batch) + # test setting callback + async def on_error(events, pid, err): + assert len(events) == 1 + assert not pid + on_error.err = err + + on_error.err = None + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + async with client: + batch = EventDataBatch() + batch.add(EventData("test data")) + await client.send_batch(batch) + assert isinstance(on_error.err, ConnectError) + + on_error.err = None + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + async with client: + await client.send_event(EventData("test data")) + assert isinstance(on_error.err, ConnectError) + @pytest.mark.parametrize("invalid_place", ["hostname", "key_name", "access_key", "event_hub", "partition"]) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py index e2c281d3a32a..efdab757f92e 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py @@ -10,10 +10,11 @@ import pytest import time import json +import uamqp from azure.eventhub import EventData, TransportType, EventDataBatch from azure.eventhub.aio import EventHubProducerClient, EventHubConsumerClient -from azure.eventhub.exceptions import EventDataSendError +from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError from azure.eventhub.amqp import ( AmqpMessageHeader, 
AmqpMessageBodyType, @@ -77,6 +78,10 @@ async def test_send_amqp_annotated_message(connstr_receivers): batch.add(event_data) await client.send_batch(batch) await client.send_batch([data_message, value_message, sequence_message, event_data]) + await client.send_event(data_message) + await client.send_event(value_message) + await client.send_event(sequence_message) + await client.send_event(event_data) received_count = {} received_count["data_msg"] = 0 @@ -130,23 +135,23 @@ async def on_event(partition_context, event): await task - assert len(on_event.received) == 8 - assert received_count["data_msg"] == 2 - assert received_count["seq_msg"] == 2 - assert received_count["value_msg"] == 2 - assert received_count["normal_msg"] == 2 + assert len(on_event.received) == 12 + assert received_count["data_msg"] == 3 + assert received_count["seq_msg"] == 3 + assert received_count["value_msg"] == 3 + assert received_count["normal_msg"] == 3 @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_with_partition_key_async(connstr_receivers): +async def test_send_with_partition_key_async(connstr_receivers, live_eventhub): connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str) async with client: data_val = 0 for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: partition_key = b"test_partition_" + partition - for i in range(50): + for i in range(10): batch = await client.create_batch(partition_key=partition_key) batch.add(EventData(str(data_val))) data_val += 1 @@ -154,16 +159,61 @@ async def test_send_with_partition_key_async(connstr_receivers): await client.send_batch(await client.create_batch()) + for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: + partition_key = b"test_partition_" + partition + for i in range(10): + event_data = EventData(str(data_val)) + event_data.properties = {'is_single': True} + data_val += 1 + await client.send_event(event_data, partition_key=partition_key) + + batch_cnt = 0 + single_cnt = 0 found_partition_keys = {} + reconnect_receivers = [] for index, partition in enumerate(receivers): - received = partition.receive_message_batch(timeout=5000) - for message in received: + retry_total = 0 + while retry_total < 3: + timeout = 5000 + retry_total * 1000 try: - event_data = EventData._from_message(message) - existing = found_partition_keys[event_data.partition_key] - assert existing == index - except KeyError: - found_partition_keys[event_data.partition_key] = index + received = partition.receive_message_batch(timeout=timeout) + for message in received: + try: + event_data = EventData._from_message(message) + if event_data.properties and event_data.properties[b'is_single']: + single_cnt += 1 + else: + batch_cnt += 1 + existing = found_partition_keys[event_data.partition_key] + assert existing == index + except KeyError: + found_partition_keys[event_data.partition_key] = index + if received: + break + retry_total += 1 + except uamqp.errors.ConnectionClose: + for r in reconnect_receivers: + r.close() + uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) + sas_auth = uamqp.authentication.SASTokenAuth.from_shared_access_key( + uri, live_eventhub['key_name'], live_eventhub['access_key']) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + live_eventhub['hostname'], + live_eventhub['event_hub'], + live_eventhub['consumer_group'], + index) + partition = uamqp.ReceiveClient(source, auth=sas_auth, debug=True, timeout=0, prefetch=500) + 
reconnect_receivers.append(partition) + retry_total += 1 + if retry_total == 3: + raise OperationTimeoutError(f"Exhausted retries for receiving from {live_eventhub['hostname']}.") + for r in reconnect_receivers: + r.close() + + assert single_cnt == 60 + assert batch_cnt == 60 + assert len(found_partition_keys) == 6 @pytest.mark.parametrize("payload", [b"", b"A single event"]) @@ -176,12 +226,14 @@ async def test_send_and_receive_small_body_async(connstr_receivers, payload): batch = await client.create_batch() batch.add(EventData(payload)) await client.send_batch(batch) + await client.send_event(EventData(payload)) received = [] for r in receivers: received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=5000)]) - assert len(received) == 1 + assert len(received) == 2 assert list(received[0].body)[0] == payload + assert list(received[1].body)[0] == payload @pytest.mark.liveTest @@ -194,29 +246,35 @@ async def test_send_partition_async(connstr_receivers): batch = await client.create_batch() batch.add(EventData(b"Data")) await client.send_batch(batch) + await client.send_event(EventData(b"Data")) async with client: batch = await client.create_batch(partition_id="1") batch.add(EventData(b"Data")) await client.send_batch(batch) + await client.send_event(EventData(b"Data"), partition_id="1") partition_0 = receivers[0].receive_message_batch(timeout=5000) partition_1 = receivers[1].receive_message_batch(timeout=5000) - assert len(partition_0) + len(partition_1) == 2 + assert len(partition_1) >= 2 + assert len(partition_0) + len(partition_1) == 4 async with client: batch = await client.create_batch() batch.add(EventData(b"Data")) await client.send_batch(batch) + await client.send_event(EventData(b"Data")) async with client: - batch = await client.create_batch(partition_id="1") + batch = await client.create_batch(partition_id="0") batch.add(EventData(b"Data")) await client.send_batch(batch) + await client.send_event(EventData(b"Data"), partition_id="0") partition_0 = receivers[0].receive_message_batch(timeout=5000) partition_1 = receivers[1].receive_message_batch(timeout=5000) - assert len(partition_0) + len(partition_1) == 2 + assert len(partition_0) >= 2 + assert len(partition_0) + len(partition_1) == 4 @pytest.mark.liveTest @@ -229,6 +287,8 @@ async def test_send_non_ascii_async(connstr_receivers): batch.add(EventData(u"é,è,à,ù,â,ê,î,ô,û")) batch.add(EventData(json.dumps({"foo": u"漢字"}))) await client.send_batch(batch) + await client.send_event(EventData(u"é,è,à,ù,â,ê,î,ô,û"), partition_id="0") + await client.send_event(EventData(json.dumps({"foo": u"漢字"})), partition_id="0") await asyncio.sleep(1) # receive_message_batch() returns immediately once it receives any messages before the max_batch_size # and timeout reach. Could be 1, 2, or any number between 1 and max_batch_size. 
@@ -236,9 +296,11 @@ async def test_send_non_ascii_async(connstr_receivers): partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + \ [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] - assert len(partition_0) == 2 + assert len(partition_0) == 4 assert partition_0[0].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û" assert partition_0[1].body_as_json() == {"foo": u"漢字"} + assert partition_0[2].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û" + assert partition_0[3].body_as_json() == {"foo": u"漢字"} @pytest.mark.liveTest @@ -255,19 +317,23 @@ async def test_send_multiple_partition_with_app_prop_async(connstr_receivers): batch = await client.create_batch(partition_id="0") batch.add(ed0) await client.send_batch(batch) + await client.send_event(ed0, partition_id="0") ed1 = EventData(b"Message 1") ed1.properties = app_prop batch = await client.create_batch(partition_id="1") batch.add(ed1) await client.send_batch(batch) + await client.send_event(ed1, partition_id="1") partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] - assert len(partition_0) == 1 + assert len(partition_0) == 2 assert partition_0[0].properties[b"raw_prop"] == b"raw_value" + assert partition_0[1].properties[b"raw_prop"] == b"raw_value" partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=5000)] - assert len(partition_1) == 1 + assert len(partition_1) == 2 assert partition_1[0].properties[b"raw_prop"] == b"raw_value" + assert partition_0[1].properties[b"raw_prop"] == b"raw_value" @pytest.mark.liveTest @@ -281,11 +347,12 @@ async def test_send_over_websocket_async(connstr_receivers): batch = await client.create_batch(partition_id="0") batch.add(EventData("Event Data")) await client.send_batch(batch) + await client.send_event(EventData("Event Data"), partition_id="0") time.sleep(1) received = [] received.extend(receivers[0].receive_message_batch(max_batch_size=5, timeout=10000)) - assert len(received) == 1 + assert len(received) == 2 @pytest.mark.liveTest @@ -314,7 +381,6 @@ async def test_send_with_create_event_batch_async(connstr_receivers): assert EventData._from_message(received[0]).properties[b"raw_prop"] == b"raw_value" - @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_list_async(connstr_receivers): @@ -367,3 +433,52 @@ async def test_send_batch_pid_pk_async(invalid_hostname, partition_id, partition async with client: with pytest.raises(TypeError): await client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) + + +@pytest.mark.liveTest +@pytest.mark.asyncio +async def test_send_with_callback_async(connstr_receivers): + + async def on_error(events, pid, err): + on_error.err = err + + async def on_success(events, pid): + sent_events.append((events, pid)) + + sent_events = [] + on_error.err = None + connection_str, receivers = connstr_receivers + client = EventHubProducerClient.from_connection_string(connection_str, on_success=on_success, on_error=on_error) + + async with client: + batch = await client.create_batch() + batch.add(EventData(b"Data")) + batch.add(EventData(b"Data")) + await client.send_batch(batch) + assert len(sent_events[-1][0]) == 2 + assert not sent_events[-1][1] + await client.send_event(EventData(b"Data")) + assert len(sent_events[-1][0]) == 1 + assert not sent_events[-1][1] + + batch = await client.create_batch(partition_key='key') + batch.add(EventData(b"Data")) + batch.add(EventData(b"Data")) + await 
client.send_batch(batch) + assert len(sent_events[-1][0]) == 2 + assert not sent_events[-1][1] + await client.send_event(EventData(b"Data"), partition_key='key') + assert len(sent_events[-1][0]) == 1 + assert not sent_events[-1][1] + + batch = await client.create_batch(partition_id="0") + batch.add(EventData(b"Data")) + await client.send_batch(batch) + batch.add(EventData(b"Data")) + assert len(sent_events[-1][0]) == 2 + assert sent_events[-1][1] == "0" + await client.send_event(EventData(b"Data"), partition_id="0") + assert len(sent_events[-1][0]) == 1 + assert sent_events[-1][1] == "0" + + assert not on_error.err diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py new file mode 100644 index 000000000000..1360e699b8bf --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py @@ -0,0 +1,498 @@ +#!/usr/bin/env python + +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import time +from collections import defaultdict +from threading import Thread +from uuid import uuid4 +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from azure.eventhub import EventData +from azure.eventhub import EventHubProducerClient, EventHubConsumerClient +from azure.eventhub._buffered_producer import PartitionResolver +from azure.eventhub.amqp import ( + AmqpAnnotatedMessage, +) +from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError, EventHubError + + +def random_pkey_generation(partitions): + pr = PartitionResolver(partitions) + total = len(partitions) + dic = {} + + while total: + key = str(uuid4()) + pid = pr.get_partition_id_by_partition_key(key) + if pid in dic: + continue + else: + dic[pid] = key + total -= 1 + + return dic + + +@pytest.mark.liveTest() +def test_producer_client_constructor(connection_str): + def on_success(events, pid): + pass + + def on_error(events, error, pid): + pass + with pytest.raises(TypeError): + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True) + with pytest.raises(TypeError): + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_success=on_success) + with pytest.raises(TypeError): + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_error=on_error) + with pytest.raises(ValueError): + EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error, + max_wait_time=0 + ) + with pytest.raises(ValueError): + EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error, + max_buffer_length=0 + ) + + +@pytest.mark.liveTest +@pytest.mark.parametrize( + "flush_after_sending, close_after_sending", + [ + (False, False), + (True, False), + (False, True) + ] +) +@pytest.mark.liveTest +def test_basic_send_single_events_round_robin(connection_str, flush_after_sending, close_after_sending): + received_events = defaultdict(list) + + def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = 
EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = Thread(target=consumer.receive, args=(on_event,)) + receive_thread.daemon = True + receive_thread.start() + + sent_events = defaultdict(list) + + def on_success(events, pid): + if len(events) > 1: + on_success.batching = True + sent_events[pid].extend(events) + + def on_error(events, pid, err): + on_error.err = err + + on_error.err = None # ensure no error + on_success.batching = False # ensure batching happened + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + with producer: + partitions = producer.get_partition_ids() + partitions_cnt = len(partitions) + # perform single sending round-robin + total_single_event_cnt = 100 + eventdata_set, amqpannoated_set = set(), set() + for i in range(total_single_event_cnt // 2): + event = EventData("test:{}".format(i)) + event.properties = {"event_idx": i} + producer.send_event(event) + eventdata_set.add(i) + for i in range(total_single_event_cnt // 2, total_single_event_cnt): + event = AmqpAnnotatedMessage(data_body="test:{}".format(i)) + event.application_properties = {"event_idx": i} + amqpannoated_set.add(i) + producer.send_event(event) + + for pid in partitions: + assert producer.get_buffered_event_count(pid) > 0 + assert producer.total_buffered_event_count > 0 + + if not flush_after_sending and not close_after_sending: + # ensure it's buffered sending + for pid in partitions: + assert len(sent_events[pid]) < total_single_event_cnt // partitions_cnt + assert sum([len(sent_events[pid]) for pid in partitions]) < total_single_event_cnt + else: + if flush_after_sending: + producer.flush() + if close_after_sending: + producer.close() + # ensure all events are sent after calling flush + assert sum([len(sent_events[pid]) for pid in partitions]) == total_single_event_cnt + + # give some time for producer to complete sending and consumer to complete receiving + time.sleep(10) + assert len(sent_events) == len(received_events) == partitions_cnt + + for pid in partitions: + assert producer.get_buffered_event_count(pid) == 0 + assert producer.total_buffered_event_count == 0 + assert not on_error.err + + # ensure all events are received in the correct partition + for pid in partitions: + assert len(sent_events[pid]) >= total_single_event_cnt // partitions_cnt + assert len(sent_events[pid]) == len(received_events[pid]) + for i in range(len(sent_events[pid])): + event = sent_events[pid][i] + try: # amqp annotated message + event_idx = event.application_properties["event_idx"] + amqpannoated_set.remove(event_idx) + except AttributeError: # event data + event_idx = event.properties["event_idx"] + eventdata_set.remove(event_idx) + assert received_events[pid][i].properties[b"event_idx"] == event_idx + assert partitions[event_idx % partitions_cnt] == pid + + assert on_success.batching + assert not eventdata_set + assert not amqpannoated_set + + consumer.close() + receive_thread.join() + + +@pytest.mark.liveTest +@pytest.mark.parametrize( + "flush_after_sending, close_after_sending", + [ + (True, False), + (False, True), + (False, False) + ] +) +def test_basic_send_batch_events_round_robin(connection_str, flush_after_sending, close_after_sending): + received_events = defaultdict(list) + + def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = 
EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = Thread(target=consumer.receive, args=(on_event,)) + receive_thread.daemon = True + receive_thread.start() + + sent_events = defaultdict(list) + + def on_success(events, pid): + sent_events[pid].extend(events) + + def on_error(events, pid, err): + on_error.err = err + + on_error.err = None + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + with producer: + partitions = producer.get_partition_ids() + partitions_cnt = len(partitions) + # perform batch sending round-robin + total_events_cnt = 100 + batch_cnt = partitions_cnt * 2 - 1 + each_partition_cnt = total_events_cnt // batch_cnt + remain_events = total_events_cnt % batch_cnt + batches = [] + event_idx = 0 + eventdata_set, amqpannoated_set = set(), set() + for i in range(batch_cnt): + batch = producer.create_batch() + for j in range(each_partition_cnt // 2): + event = EventData("test{}:{}".format(i, event_idx)) + event.properties = {'batch_idx': i, 'event_idx': event_idx} + batch.add(event) + eventdata_set.add(event_idx) + event_idx += 1 + for j in range(each_partition_cnt // 2, each_partition_cnt): + event = AmqpAnnotatedMessage(data_body="test{}:{}".format(i, event_idx)) + event.application_properties = {'batch_idx': i, 'event_idx': event_idx} + batch.add(event) + amqpannoated_set.add(event_idx) + event_idx += 1 + batches.append(batch) + + # put remain_events in the last batch + last_batch = producer.create_batch() + for i in range(remain_events): + event = EventData("test:{}:{}".format(len(batches), event_idx)) + event.properties = {'batch_idx': len(batches), 'event_idx': event_idx} + last_batch.add(event) + eventdata_set.add(event_idx) + event_idx += 1 + batches.append(last_batch) + + for batch in batches: + producer.send_batch(batch) + + if not flush_after_sending and not close_after_sending: + # ensure it's buffered sending + for pid in partitions: + assert len(sent_events[pid]) < each_partition_cnt + assert sum([len(sent_events[pid]) for pid in partitions]) < total_events_cnt + # give some time for producer to complete sending and consumer to complete receiving + else: + if flush_after_sending: + producer.flush() + if close_after_sending: + producer.close() + # ensure all events are sent + assert sum([len(sent_events[pid]) for pid in partitions]) == total_events_cnt + + time.sleep(10) + assert len(sent_events) == len(received_events) == partitions_cnt + + # ensure all events are received in the correct partition + for pid in partitions: + assert len(sent_events[pid]) > 0 + assert len(sent_events[pid]) == len(received_events[pid]) + for i in range(len(sent_events[pid])): + event = sent_events[pid][i] + try: # amqp annotated message + event_idx = event.application_properties["event_idx"] + amqpannoated_set.remove(event_idx) + except AttributeError: # event data + event_idx = event.properties["event_idx"] + eventdata_set.remove(event_idx) + assert received_events[pid][i].properties[b"event_idx"] == event_idx + + assert not amqpannoated_set + assert not eventdata_set + assert not on_error.err + + consumer.close() + receive_thread.join() + + +@pytest.mark.liveTest +def test_send_with_hybrid_partition_assignment(connection_str): + received_events = defaultdict(list) + + def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = 
EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = Thread(target=consumer.receive, args=(on_event,)) + receive_thread.daemon = True + receive_thread.start() + + sent_events = defaultdict(list) + + def on_success(events, pid): + sent_events[pid].extend(events) + + def on_error(events, pid, err): + on_error.err = err + + on_error.err = None + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + with producer: + partitions = producer.get_partition_ids() + partitions_cnt = len(partitions) + pid_to_pkey = random_pkey_generation(partitions) + expected_event_idx_to_partition = {} + event_idx = 0 + # 1. send by partition_key, each partition 2 events, two single + one batch containing two + for pid in partitions: + pkey = pid_to_pkey[pid] + producer.send_event(EventData('{}'.format(event_idx)), partition_key=pkey) + batch = producer.create_batch(partition_key=pkey) + batch.add(EventData('{}'.format(event_idx + 1))) + producer.send_batch(batch) + for i in range(2): + expected_event_idx_to_partition[event_idx + i] = pid + event_idx += 2 + + # 2. send by partition_id, each partition 2 events, two single + one batch containing two + for pid in partitions: + producer.send_event(EventData('{}'.format(event_idx)), partition_id=pid) + batch = producer.create_batch(partition_id=pid) + batch.add(EventData('{}'.format(event_idx + 1))) + producer.send_batch(batch) + for i in range(2): + expected_event_idx_to_partition[event_idx + i] = pid + event_idx += 2 + + # 3. send without partition, each partition 2 events, two single + one batch containing two + for _ in partitions: + producer.send_event(EventData('{}'.format(event_idx))) + batch = producer.create_batch() + batch.add(EventData('{}'.format(event_idx + 1))) + producer.send_batch(batch) + event_idx += 2 + + producer.flush() + assert len(sent_events) == partitions_cnt + + time.sleep(10) + + visited = set() + for pid in partitions: + assert len(sent_events[pid]) == 2 * 3 + + for sent_event in sent_events[pid]: + if int(sent_event.body_as_str()) in expected_event_idx_to_partition: + assert expected_event_idx_to_partition[int(sent_event.body_as_str())] == pid + + for recv_event in received_events[pid]: + if int(sent_event.body_as_str()) in expected_event_idx_to_partition: + assert expected_event_idx_to_partition[int(sent_event.body_as_str())] == pid + + assert recv_event.body_as_str() not in visited + visited.add(recv_event.body_as_str()) + + assert len(visited) == 2 * 3 * len(partitions) + + assert not on_error.err + consumer.close() + receive_thread.join() + + +def test_send_with_timing_configuration(connection_str): + received_events = defaultdict(list) + + def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = Thread(target=consumer.receive, args=(on_event,)) + receive_thread.daemon = True + receive_thread.start() + + sent_events = defaultdict(list) + + def on_success(events, pid): + sent_events[pid].extend(events) + + def on_error(events, pid, err): + on_error.err = err + + on_error.err = None + + # test max_wait_time + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + max_wait_time=10, + on_success=on_success, + on_error=on_error + ) + + with producer: + partitions = 
producer.get_partition_ids() + producer.send_event(EventData('data')) + time.sleep(5) + assert not sent_events + time.sleep(10) + assert sum([len(sent_events[pid]) for pid in partitions]) == 1 + + assert not on_error.err + + # test max_buffer_length per partition + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + max_wait_time=1000, + max_buffer_length=10, + on_success=on_success, + on_error=on_error + ) + + sent_events.clear() + received_events.clear() + with producer: + partitions = producer.get_partition_ids() + for i in range(7): + producer.send_event(EventData('data'), partition_id="0") + assert not sent_events + batch = producer.create_batch(partition_id="0") + for i in range(9): + batch.add(EventData('9')) + producer.send_batch(batch) # will flush 7 events and put the batch in buffer + assert sum([len(sent_events[pid]) for pid in partitions]) == 7 + for i in range(5): + producer.send_event(EventData('data'), partition_id="0") # will flush batch (9 events) + 1 event, leaving 4 in buffer + assert sum([len(sent_events[pid]) for pid in partitions]) == 17 + producer.flush() + assert sum([len(sent_events[pid]) for pid in partitions]) == 21 + + time.sleep(5) + assert sum([len(received_events[pid]) for pid in partitions]) == 21 + assert not on_error.err + consumer.close() + receive_thread.join() + + +@pytest.mark.liveTest +def test_long_sleep(connection_str): + received_events = defaultdict(list) + + def on_event(partition_context, event): + received_events[partition_context.partition_id].append(event) + + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + receive_thread = Thread(target=consumer.receive, args=(on_event,)) + receive_thread.daemon = True + receive_thread.start() + + sent_events = defaultdict(list) + + def on_success(events, pid): + sent_events[pid].extend(events) + + def on_error(events, pid, err): + on_error.err = err + + on_error.err = None # ensure no error + producer = EventHubProducerClient.from_connection_string( + connection_str, + buffered_mode=True, + on_success=on_success, + on_error=on_error + ) + + with producer: + producer.send_event(EventData("test"), partition_id="0") + time.sleep(220) + producer.send_event(EventData("test"), partition_id="0") + time.sleep(5) + + assert not on_error.err + assert len(sent_events["0"]) == 2 + assert len(received_events["0"]) == 2 + + consumer.close() + receive_thread.join() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index d3b775381796..3b7249c2cace 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -33,6 +33,26 @@ def test_send_batch_with_invalid_hostname(invalid_hostname): batch.add(EventData("test data")) client.send_batch(batch) + # test setting callback + def on_error(events, pid, err): + assert len(events) == 1 + assert not pid + on_error.err = err + + on_error.err = None + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + with client: + batch = EventDataBatch() + batch.add(EventData("test data")) + client.send_batch(batch) + assert isinstance(on_error.err, ConnectError) + + on_error.err = None + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + with client: + client.send_event(EventData("test data")) + assert 
isinstance(on_error.err, ConnectError) + @pytest.mark.liveTest def test_receive_with_invalid_hostname_sync(invalid_hostname): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index 0872b55ae85f..4624bcf93f9e 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -11,9 +11,10 @@ import json import sys +import uamqp from azure.eventhub import EventData, TransportType, EventDataBatch from azure.eventhub import EventHubProducerClient, EventHubConsumerClient -from azure.eventhub.exceptions import EventDataSendError +from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError from azure.eventhub.amqp import ( AmqpMessageHeader, AmqpMessageBodyType, @@ -21,15 +22,16 @@ AmqpMessageProperties, ) + @pytest.mark.liveTest -def test_send_with_partition_key(connstr_receivers): +def test_send_with_partition_key(connstr_receivers, live_eventhub): connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str) with client: data_val = 0 for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: partition_key = b"test_partition_" + partition - for i in range(50): + for i in range(10): batch = client.create_batch(partition_key=partition_key) batch.add(EventData(str(data_val))) data_val += 1 @@ -37,16 +39,62 @@ def test_send_with_partition_key(connstr_receivers): client.send_batch(client.create_batch()) + for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: + partition_key = b"test_partition_" + partition + for i in range(10): + event_data = EventData(str(data_val)) + event_data.properties = {'is_single': True} + data_val += 1 + client.send_event(event_data, partition_key=partition_key) + + batch_cnt = 0 + single_cnt = 0 found_partition_keys = {} + reconnect_receivers = [] for index, partition in enumerate(receivers): - received = partition.receive_message_batch(timeout=5000) - for message in received: + retry_total = 0 + while retry_total < 3: + timeout = 5000 + retry_total * 1000 try: - event_data = EventData._from_message(message) - existing = found_partition_keys[event_data.partition_key] - assert existing == index - except KeyError: - found_partition_keys[event_data.partition_key] = index + received = partition.receive_message_batch(timeout=timeout) + for message in received: + try: + event_data = EventData._from_message(message) + if event_data.properties and event_data.properties[b'is_single']: + single_cnt += 1 + else: + batch_cnt += 1 + existing = found_partition_keys[event_data.partition_key] + assert existing == index + except KeyError: + found_partition_keys[event_data.partition_key] = index + if received: + break + retry_total += 1 + except uamqp.errors.ConnectionClose: + for r in reconnect_receivers: + r.close() + uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) + sas_auth = uamqp.authentication.SASTokenAuth.from_shared_access_key( + uri, live_eventhub['key_name'], live_eventhub['access_key']) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + live_eventhub['hostname'], + live_eventhub['event_hub'], + live_eventhub['consumer_group'], + index) + partition = uamqp.ReceiveClient(source, auth=sas_auth, debug=True, timeout=0, prefetch=500) + reconnect_receivers.append(partition) + retry_total += 1 + if retry_total == 3: + raise OperationTimeoutError(f"Exhausted retries for receiving from 
{live_eventhub['hostname']}.") + + for r in reconnect_receivers: + r.close() + + assert single_cnt == 60 + assert batch_cnt == 60 + assert len(found_partition_keys) == 6 @pytest.mark.liveTest @@ -60,13 +108,31 @@ def test_send_and_receive_large_body_size(connstr_receivers): batch = client.create_batch() batch.add(EventData("A" * payload)) client.send_batch(batch) + client.send_event(EventData("A" * payload)) received = [] for r in receivers: received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) - assert len(received) == 1 + assert len(received) == 2 + assert len(list(received[0].body)[0]) == payload + assert len(list(received[1].body)[0]) == payload + + client = EventHubProducerClient.from_connection_string(connection_str) + with client: + payload = 250 * 1024 + batch = client.create_batch() + batch.add(EventData("A" * payload)) + client.send_batch(batch) + client.send_event(EventData("A" * payload)) + + received = [] + for r in receivers: + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) + + assert len(received) == 2 assert len(list(received[0].body)[0]) == payload + assert len(list(received[1].body)[0]) == payload @pytest.mark.liveTest @@ -124,6 +190,10 @@ def test_send_amqp_annotated_message(connstr_receivers): batch.add(event_data) client.send_batch(batch) client.send_batch([data_message, value_message, sequence_message, event_data]) + client.send_event(data_message) + client.send_event(value_message) + client.send_event(sequence_message) + client.send_event(event_data) received_count = {} received_count["data_msg"] = 0 @@ -177,11 +247,11 @@ def on_event(partition_context, event): for event in on_event.received: check_values(event) - assert len(on_event.received) == 8 - assert received_count["data_msg"] == 2 - assert received_count["seq_msg"] == 2 - assert received_count["value_msg"] == 2 - assert received_count["normal_msg"] == 2 + assert len(on_event.received) == 12 + assert received_count["data_msg"] == 3 + assert received_count["seq_msg"] == 3 + assert received_count["value_msg"] == 3 + assert received_count["normal_msg"] == 3 @pytest.mark.parametrize("payload", @@ -194,12 +264,14 @@ def test_send_and_receive_small_body(connstr_receivers, payload): batch = client.create_batch() batch.add(EventData(payload)) client.send_batch(batch) + client.send_event(EventData(payload)) received = [] for r in receivers: received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=5000)]) - assert len(received) == 1 + assert len(received) == 2 assert list(received[0].body)[0] == payload + assert list(received[1].body)[0] == payload @pytest.mark.liveTest @@ -211,29 +283,35 @@ def test_send_partition(connstr_receivers): batch = client.create_batch() batch.add(EventData(b"Data")) client.send_batch(batch) + client.send_event(EventData(b"Data")) with client: batch = client.create_batch(partition_id="1") batch.add(EventData(b"Data")) client.send_batch(batch) + client.send_event(EventData(b"Data"), partition_id="1") partition_0 = receivers[0].receive_message_batch(timeout=5000) partition_1 = receivers[1].receive_message_batch(timeout=5000) - assert len(partition_0) + len(partition_1) == 2 + assert len(partition_1) >= 2 + assert len(partition_0) + len(partition_1) == 4 with client: batch = client.create_batch() batch.add(EventData(b"Data")) client.send_batch(batch) + client.send_event(EventData(b"Data")) with client: - batch = client.create_batch(partition_id="1") + batch = 
client.create_batch(partition_id="0") batch.add(EventData(b"Data")) client.send_batch(batch) + client.send_event(EventData(b"Data"), partition_id="0") partition_0 = receivers[0].receive_message_batch(timeout=5000) partition_1 = receivers[1].receive_message_batch(timeout=5000) - assert len(partition_0) + len(partition_1) == 2 + assert len(partition_0) >= 2 + assert len(partition_0) + len(partition_1) == 4 @pytest.mark.liveTest @@ -245,15 +323,19 @@ def test_send_non_ascii(connstr_receivers): batch.add(EventData(u"é,è,à,ù,â,ê,î,ô,û")) batch.add(EventData(json.dumps({"foo": u"漢字"}))) client.send_batch(batch) + client.send_event(EventData(u"é,è,à,ù,â,ê,î,ô,û"), partition_id="0") + client.send_event(EventData(json.dumps({"foo": u"漢字"})), partition_id="0") time.sleep(1) # receive_message_batch() returns immediately once it receives any messages before the max_batch_size # and timeout reach. Could be 1, 2, or any number between 1 and max_batch_size. # So call it twice to ensure the two events are received. partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + \ [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] - assert len(partition_0) == 2 + assert len(partition_0) == 4 assert partition_0[0].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û" assert partition_0[1].body_as_json() == {"foo": u"漢字"} + assert partition_0[2].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û" + assert partition_0[3].body_as_json() == {"foo": u"漢字"} @pytest.mark.liveTest @@ -269,19 +351,23 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers): batch = client.create_batch(partition_id="0") batch.add(ed0) client.send_batch(batch) + client.send_event(ed0, partition_id="0") ed1 = EventData(b"Message 1") ed1.properties = app_prop batch = client.create_batch(partition_id="1") batch.add(ed1) client.send_batch(batch) + client.send_event(ed1, partition_id="1") partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] - assert len(partition_0) == 1 + assert len(partition_0) == 2 assert partition_0[0].properties[b"raw_prop"] == b"raw_value" + assert partition_0[1].properties[b"raw_prop"] == b"raw_value" partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=5000)] - assert len(partition_1) == 1 + assert len(partition_1) == 2 assert partition_1[0].properties[b"raw_prop"] == b"raw_value" + assert partition_1[1].properties[b"raw_prop"] == b"raw_value" @pytest.mark.liveTest @@ -293,11 +379,12 @@ def test_send_over_websocket_sync(connstr_receivers): batch = client.create_batch(partition_id="0") batch.add(EventData("Event Data")) client.send_batch(batch) + client.send_event(EventData("Event Data"), partition_id="0") time.sleep(1) received = [] received.extend(receivers[0].receive_message_batch(max_batch_size=5, timeout=10000)) - assert len(received) == 1 + assert len(received) == 2 @pytest.mark.liveTest @@ -371,3 +458,50 @@ def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key): with client: with pytest.raises(TypeError): client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) + + +def test_send_with_callback(connstr_receivers): + + def on_error(events, pid, err): + on_error.err = err + + def on_success(events, pid): + sent_events.append((events, pid)) + + sent_events = [] + on_error.err = None + connection_str, receivers = connstr_receivers + client = EventHubProducerClient.from_connection_string(connection_str, on_success=on_success, 
on_error=on_error) + + with client: + batch = client.create_batch() + batch.add(EventData(b"Data")) + batch.add(EventData(b"Data")) + client.send_batch(batch) + assert len(sent_events[-1][0]) == 2 + assert not sent_events[-1][1] + client.send_event(EventData(b"Data")) + assert len(sent_events[-1][0]) == 1 + assert not sent_events[-1][1] + + batch = client.create_batch(partition_key='key') + batch.add(EventData(b"Data")) + batch.add(EventData(b"Data")) + client.send_batch(batch) + assert len(sent_events[-1][0]) == 2 + assert not sent_events[-1][1] + client.send_event(EventData(b"Data"), partition_key='key') + assert len(sent_events[-1][0]) == 1 + assert not sent_events[-1][1] + + batch = client.create_batch(partition_id="0") + batch.add(EventData(b"Data")) + client.send_batch(batch) + batch.add(EventData(b"Data")) + assert len(sent_events[-1][0]) == 2 + assert sent_events[-1][1] == "0" + client.send_event(EventData(b"Data"), partition_id="0") + assert len(sent_events[-1][0]) == 1 + assert sent_events[-1][1] == "0" + + assert not on_error.err diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/asynctests/test_partition_resolver_async.py b/sdk/eventhub/azure-eventhub/tests/unittest/asynctests/test_partition_resolver_async.py new file mode 100644 index 000000000000..195ee9e01315 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/unittest/asynctests/test_partition_resolver_async.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import asyncio +from collections import defaultdict +import pytest +from azure.eventhub.aio._buffered_producer._partition_resolver_async import PartitionResolver + + +class TestPartitionResolver: + + @pytest.mark.asyncio + @pytest.mark.parametrize("partition_cnt", [1, 2, 16, 32, 256]) + async def test_basic_round_robin(self, partition_cnt): + partitions = [str(i) for i in range(partition_cnt)] + pr = PartitionResolver(partitions) + for i in range(2*partition_cnt): + expected = str(i % partition_cnt) + real = await pr.get_next_partition_id() + assert expected == real + + @pytest.mark.asyncio + @pytest.mark.parametrize("partition_cnt", [1, 2, 16, 32, 256]) + async def test_concurrent_round_robin_fairly(self, partition_cnt): + partitions = [str(i) for i in range(partition_cnt)] + pr = PartitionResolver(partitions) + dic = defaultdict(int) + lock = asyncio.Lock() + + async def gen_pid(): + pid = await pr.get_next_partition_id() + async with lock: + dic[pid] += 1 + + futures = [asyncio.ensure_future(gen_pid()) for _ in range(5*partition_cnt)] + + for future in futures: + await future + + assert len(dic) == partition_cnt + for i in range(partition_cnt): + assert dic[str(i)] == 5 diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_partition_resolver.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_partition_resolver.py new file mode 100644 index 000000000000..1ff803127823 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_partition_resolver.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +import pytest +from threading import Lock +from azure.eventhub._buffered_producer._partition_resolver import generate_hash_code, PartitionResolver + + +class TestPartitionResolver: + def test_partition_key(self): + input = { + "7": -15263, + "7149583486996073602": 12977, + "FWfAT": -22341, + "sOdeEAsyQoEuEFPGerWO": -6503, + "FAyAIctPeCgmiwLKbJcyswoHglHVjQdvtBowLACDNORsYvOcLddNJYDmhAVkbyLOrHTKLneMNcbgWVlasVywOByANjs": 5226, + "1XYM6!(7(lF5wq4k4m*e$Nc!1ezLJv*1YK1Y-C^*&B$O)lq^iUkG(TNzXG;Zi#z2Og*Qq0#^*k):vXh$3,C7We7%W0meJ;b3,rQCg^J;^twXgs5E$$hWKxqp": 23950, + "E(x;RRIaQcJs*P;D&jTPau-4K04oqr:lF6Z):ERpo&;9040qyV@G1_c9mgOs-8_8/10Fwa-7b7-yP!T-!IH&968)FWuI;(^g$2fN;)HJ^^yTn:": -29304, + "!c*_!I@1^c": 15372, + "p4*!jioeO/z-!-;w:dh": -3104, + "$0cb": 26269, + "-4189260826195535198": 453 + } + + for k, v in input.items(): + assert generate_hash_code(k) == v + + @pytest.mark.parametrize("partition_cnt", [1, 2, 16, 32, 256]) + def test_basic_round_robin(self, partition_cnt): + partitions = [str(i) for i in range(partition_cnt)] + pr = PartitionResolver(partitions) + for i in range(2*partition_cnt): + expected = str(i % partition_cnt) + real = pr.get_next_partition_id() + assert expected == real + + @pytest.mark.parametrize("partition_cnt", [1, 2, 16, 32, 256]) + def test_concurrent_round_robin_fairly(self, partition_cnt): + partitions = [str(i) for i in range(partition_cnt)] + pr = PartitionResolver(partitions) + exc = ThreadPoolExecutor() + + dic = defaultdict(int) + lock = Lock() + + def gen_pid(): + pid = pr.get_next_partition_id() + with lock: + dic[pid] += 1 + + for i in range(5*partition_cnt): + exc.submit(gen_pid) + + exc.shutdown() + assert len(dic) == partition_cnt + for i in range(partition_cnt): + assert dic[str(i)] == 5
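Editor's note: the snippet below is an illustrative sketch only, not part of the patch. It shows how the partition-resolution helpers exercised by the unit tests above can be driven directly, using only the names those tests already import (PartitionResolver and generate_hash_code from azure.eventhub._buffered_producer._partition_resolver); the specific partition count and key are arbitrary examples.

# Illustrative sketch only -- not part of the patch above.
from azure.eventhub._buffered_producer._partition_resolver import (
    PartitionResolver,
    generate_hash_code,
)

partitions = [str(i) for i in range(4)]
resolver = PartitionResolver(partitions)

# Without a partition id or key, buffered events are assigned round-robin,
# cycling "0", "1", "2", "3", "0", ... as verified by test_basic_round_robin.
round_robin_assignments = [resolver.get_next_partition_id() for _ in range(8)]
print(round_robin_assignments)

# With a partition key, the assignment is deterministic: the same key always
# resolves to the same partition, based on the key's hash code
# (test_partition_key above pins the expected hash values).
key = "test_partition_a"
print(generate_hash_code(key))
print(resolver.get_partition_id_by_partition_key(key))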