From 2805d00393731628c31c8e19cb04c04a0ff3af4d Mon Sep 17 00:00:00 2001 From: "Adam Ling (MSFT)" Date: Thu, 3 Mar 2022 23:59:39 -0800 Subject: [PATCH] [EH Pyproto] Async recv perf improvement (#23122) * stop spawning too many coroutines * improve send * async recv perf improvement * async perf improve * update version * align with sync implementation * update method name * remove redundant except catch --- .../eventhub/_pyamqp/aio/_connection_async.py | 19 ++- .../eventhub/_pyamqp/aio/_transport_async.py | 8 +- .../azure-eventhub/azure/eventhub/_version.py | 2 +- .../azure/eventhub/aio/_consumer_async.py | 108 ++++++++++-------- .../asynctests/test_reconnect_async.py | 12 +- 5 files changed, 75 insertions(+), 74 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index ffa8d271d0bb..fa5193f5e28f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -126,7 +126,8 @@ async def _set_state(self, new_state): self.state = new_state _LOGGER.info("Connection '%s' state changed: %r -> %r", self._container_id, previous_state, new_state) - await asyncio.gather(*[session._on_connection_state_change() for session in self.outgoing_endpoints.values()]) + for session in self.outgoing_endpoints.values(): + await session._on_connection_state_change() async def _connect(self): try: @@ -205,11 +206,11 @@ def _get_next_outgoing_channel(self): async def _outgoing_empty(self): if self.network_trace: - _LOGGER.info("<- empty()", extra=self.network_trace_params) + _LOGGER.info("-> empty()", extra=self.network_trace_params) try: if self._can_write(): await self.transport.write(EMPTY_FRAME) - self._last_frame_sent_time = time.time() + self.last_frame_sent_time = time.time() except (OSError, IOError, SSLError, socket.error) as exc: self._error = 
AMQPConnectionError( ErrorCondition.SocketError, @@ -421,8 +422,7 @@ async def _wait_for_response(self, wait, end_state): async def _listen_one_frame(self, **kwargs): new_frame = await self._read_frame(**kwargs) - if await self._process_incoming_frame(*new_frame): - raise ValueError("Stop") # Stop listening + return await self._process_incoming_frame(*new_frame) async def listen(self, wait=False, batch=1, **kwargs): try: @@ -450,12 +450,9 @@ async def listen(self, wait=False, batch=1, **kwargs): description="Connection was already closed." ) return - try: - tasks = [asyncio.ensure_future(self._listen_one_frame(**kwargs)) for _ in range(batch)] - await asyncio.gather(*tasks) - except ValueError: - for task in tasks: - task.cancel() + for _ in range(batch): + if await asyncio.ensure_future(self._listen_one_frame(**kwargs)): + break except (OSError, IOError, SSLError, socket.error) as exc: self._error = AMQPConnectionError( ErrorCondition.SocketError, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 61b7b72b3c55..acbdd8af8e76 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -276,13 +276,7 @@ async def _read(self, toread, initial=False, buffer=None, try: while toread: try: - # TODO: await self.reader.readexactly would not return until it has received something which - # is problematic in the case timeout is required while no frame coming in. - # asyncio.wait_for is used here for timeout control - # set socket timeout does not work, not triggering socket error maybe should be a different config? 
- # also we could consider using a low level socket instead of high level reader/writer - # https://docs.python.org/3/library/asyncio-eventloop.html - view[nbytes:nbytes + toread] = await asyncio.wait_for(self.reader.readexactly(toread), timeout=1) + view[nbytes:nbytes + toread] = await self.reader.readexactly(toread) nbytes = toread except asyncio.IncompleteReadError as exc: pbytes = len(exc.partial) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index 3d19c8d056c4..13c5ec5b035e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.8.0b3" +VERSION = "5.8.0a3" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index f38064221fc7..d5be74195636 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -125,6 +125,9 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N ) self._message_buffer = deque() # type: Deque[Message] self._last_received_event = None # type: Optional[EventData] + self._message_buffer_lock = asyncio.Lock() + self._last_callback_called_time = None + self._callback_task_run = None def _create_handler(self, auth: "JWTTokenAuthAsync") -> None: source = Source(self._source, filters={}) @@ -162,7 +165,8 @@ async def _open_with_retry(self) -> None: await self._do_retryable_operation(self._open, operation_need_param=False) async def _message_received(self, message: Message) -> None: - self._message_buffer.append(message) + async with self._message_buffer_lock: + self._message_buffer.append(message) def _next_message_in_buffer(self): # pylint:disable=protected-access @@ -171,54 +175,64 
@@ def _next_message_in_buffer(self): self._last_received_event = event_data return event_data - async def receive(self, batch=False, max_batch_size=300, max_wait_time=None) -> None: + async def _callback_task(self, batch, max_batch_size, max_wait_time): + while self._callback_task_run: + async with self._message_buffer_lock: + messages = [ + self._message_buffer.popleft() for _ in range(min(max_batch_size, len(self._message_buffer))) + ] + events = [EventData._from_message(message) for message in messages] + now_time = time.time() + if len(events) > 0: + await self._on_event_received(events if batch else events[0]) + self._last_callback_called_time = now_time + else: + if max_wait_time and (now_time - self._last_callback_called_time) > max_wait_time: + # no events received, and need to callback + await self._on_event_received([] if batch else None) + self._last_callback_called_time = now_time + # backoff a bit to avoid throttling CPU when no events are coming + await asyncio.sleep(0.05) + + async def _receive_task(self): max_retries = ( self._client._config.max_retries # pylint:disable=protected-access ) - has_not_fetched_once = True # ensure one trip when max_wait_time is very small - deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None - while len(self._message_buffer) < max_batch_size and \ - (time.time() < deadline or has_not_fetched_once): - retried_times = 0 - has_not_fetched_once = False - while retried_times <= max_retries: - try: - await self._open() - await cast(ReceiveClientAsync, self._handler).do_work_async(batch=self._prefetch) - break - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except Exception as exception: # pylint: disable=broad-except - if ( + retried_times = 0 + while retried_times <= max_retries: + try: + await self._open() + await cast(ReceiveClientAsync, self._handler).do_work_async(batch=self._prefetch) + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except 
Exception as exception: # pylint: disable=broad-except + if ( isinstance(exception, error.AMQPLinkError) and exception.condition == error.ErrorCondition.LinkStolen # pylint: disable=no-member - ): - raise await self._handle_exception(exception) - if not self.running: # exit by close - return - if self._last_received_event: - self._offset = self._last_received_event.offset - last_exception = await self._handle_exception(exception) - retried_times += 1 - if retried_times > max_retries: - _LOGGER.info( - "%r operation has exhausted retry. Last exception: %r.", - self._name, - last_exception, - ) - raise last_exception - - if self._message_buffer: - while self._message_buffer: - if batch: - events_for_callback = [] # type: List[EventData] - for _ in range(min(max_batch_size, len(self._message_buffer))): - events_for_callback.append(self._next_message_in_buffer()) - await self._on_event_received(events_for_callback) - else: - await self._on_event_received(self._next_message_in_buffer()) - elif max_wait_time: - if batch: - await self._on_event_received([]) - else: - await self._on_event_received(None) + ): + raise await self._handle_exception(exception) + if not self.running: # exit by close + return + if self._last_received_event: + self._offset = self._last_received_event.offset + last_exception = await self._handle_exception(exception) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info( + "%r operation has exhausted retry. 
Last exception: %r.", + self._name, + last_exception, + ) + raise last_exception + + async def receive(self, batch=False, max_batch_size=300, max_wait_time=None) -> None: + self._callback_task_run = True + self._last_callback_called_time = time.time() + callback_task = asyncio.ensure_future(self._callback_task(batch, max_batch_size, max_wait_time)) + receive_task = asyncio.ensure_future(self._receive_task()) + + try: + await receive_task + finally: + self._callback_task_run = False + await callback_task diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_reconnect_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_reconnect_async.py index 5c04c46f2879..38a1690817cc 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_reconnect_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_reconnect_async.py @@ -127,13 +127,9 @@ async def on_event_received(event): await consumer._handler.do_work_async() assert consumer._handler._connection.state == constants.ConnectionState.END - duration = 10 - now_time = time.time() - end_time = now_time + duration - - while now_time < end_time: - await consumer.receive() - await asyncio.sleep(0.01) - now_time = time.time() + try: + await asyncio.wait_for(consumer.receive(), timeout=10) + except asyncio.TimeoutError: + pass assert on_event_received.event.body_as_str() == "Event"