Skip to content

Commit

Permalink
[EH Pyproto] Async recv perf improvement (#23122)
Browse files Browse the repository at this point in the history
* stop spawning too much coroutines

* improve send

* async recv perf improvement

* async perf improve

* update version

* align with sync imple

* update method name

* remove redundant except catch
  • Loading branch information
yunhaoling authored Mar 4, 2022
1 parent fe47341 commit 2805d00
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ async def _set_state(self, new_state):
self.state = new_state
_LOGGER.info("Connection '%s' state changed: %r -> %r", self._container_id, previous_state, new_state)

await asyncio.gather(*[session._on_connection_state_change() for session in self.outgoing_endpoints.values()])
for session in self.outgoing_endpoints.values():
await session._on_connection_state_change()

async def _connect(self):
try:
Expand Down Expand Up @@ -205,11 +206,11 @@ def _get_next_outgoing_channel(self):

async def _outgoing_empty(self):
if self.network_trace:
_LOGGER.info("<- empty()", extra=self.network_trace_params)
_LOGGER.info("-> empty()", extra=self.network_trace_params)
try:
if self._can_write():
await self.transport.write(EMPTY_FRAME)
self._last_frame_sent_time = time.time()
self.last_frame_sent_time = time.time()
except (OSError, IOError, SSLError, socket.error) as exc:
self._error = AMQPConnectionError(
ErrorCondition.SocketError,
Expand Down Expand Up @@ -421,8 +422,7 @@ async def _wait_for_response(self, wait, end_state):

async def _listen_one_frame(self, **kwargs):
new_frame = await self._read_frame(**kwargs)
if await self._process_incoming_frame(*new_frame):
raise ValueError("Stop") # Stop listening
return await self._process_incoming_frame(*new_frame)

async def listen(self, wait=False, batch=1, **kwargs):
try:
Expand Down Expand Up @@ -450,12 +450,9 @@ async def listen(self, wait=False, batch=1, **kwargs):
description="Connection was already closed."
)
return
try:
tasks = [asyncio.ensure_future(self._listen_one_frame(**kwargs)) for _ in range(batch)]
await asyncio.gather(*tasks)
except ValueError:
for task in tasks:
task.cancel()
for _ in range(batch):
if await asyncio.ensure_future(self._listen_one_frame(**kwargs)):
break
except (OSError, IOError, SSLError, socket.error) as exc:
self._error = AMQPConnectionError(
ErrorCondition.SocketError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,7 @@ async def _read(self, toread, initial=False, buffer=None,
try:
while toread:
try:
# TODO: await self.reader.readexactly would not return until it has received something which
# is problematic in the case timeout is required while no frame coming in.
# asyncio.wait_for is used here for timeout control
# set socket timeout does not work, not triggering socket error maybe should be a different config?
# also we could consider using a low level socket instead of high level reader/writer
# https://docs.python.org/3/library/asyncio-eventloop.html
view[nbytes:nbytes + toread] = await asyncio.wait_for(self.reader.readexactly(toread), timeout=1)
view[nbytes:nbytes + toread] = await self.reader.readexactly(toread)
nbytes = toread
except asyncio.IncompleteReadError as exc:
pbytes = len(exc.partial)
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# Licensed under the MIT License.
# ------------------------------------

VERSION = "5.8.0b3"
VERSION = "5.8.0a3"
108 changes: 61 additions & 47 deletions sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N
)
self._message_buffer = deque() # type: Deque[Message]
self._last_received_event = None # type: Optional[EventData]
self._message_buffer_lock = asyncio.Lock()
self._last_callback_called_time = None
self._callback_task_run = None

def _create_handler(self, auth: "JWTTokenAuthAsync") -> None:
source = Source(self._source, filters={})
Expand Down Expand Up @@ -162,7 +165,8 @@ async def _open_with_retry(self) -> None:
await self._do_retryable_operation(self._open, operation_need_param=False)

async def _message_received(self, message: Message) -> None:
self._message_buffer.append(message)
async with self._message_buffer_lock:
self._message_buffer.append(message)

def _next_message_in_buffer(self):
# pylint:disable=protected-access
Expand All @@ -171,54 +175,64 @@ def _next_message_in_buffer(self):
self._last_received_event = event_data
return event_data

async def receive(self, batch=False, max_batch_size=300, max_wait_time=None) -> None:
async def _callback_task(self, batch, max_batch_size, max_wait_time):
while self._callback_task_run:
async with self._message_buffer_lock:
messages = [
self._message_buffer.popleft() for _ in range(min(max_batch_size, len(self._message_buffer)))
]
events = [EventData._from_message(message) for message in messages]
now_time = time.time()
if len(events) > 0:
await self._on_event_received(events if batch else events[0])
self._last_callback_called_time = now_time
else:
if max_wait_time and (now_time - self._last_callback_called_time) > max_wait_time:
# no events received, and need to callback
await self._on_event_received([] if batch else None)
self._last_callback_called_time = now_time
# backoff a bit to avoid throttling CPU when no events are coming
await asyncio.sleep(0.05)

async def _receive_task(self):
max_retries = (
self._client._config.max_retries # pylint:disable=protected-access
)
has_not_fetched_once = True # ensure one trip when max_wait_time is very small
deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None
while len(self._message_buffer) < max_batch_size and \
(time.time() < deadline or has_not_fetched_once):
retried_times = 0
has_not_fetched_once = False
while retried_times <= max_retries:
try:
await self._open()
await cast(ReceiveClientAsync, self._handler).do_work_async(batch=self._prefetch)
break
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as exception: # pylint: disable=broad-except
if (
retried_times = 0
while retried_times <= max_retries:
try:
await self._open()
await cast(ReceiveClientAsync, self._handler).do_work_async(batch=self._prefetch)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as exception: # pylint: disable=broad-except
if (
isinstance(exception, error.AMQPLinkError)
and exception.condition == error.ErrorCondition.LinkStolen # pylint: disable=no-member
):
raise await self._handle_exception(exception)
if not self.running: # exit by close
return
if self._last_received_event:
self._offset = self._last_received_event.offset
last_exception = await self._handle_exception(exception)
retried_times += 1
if retried_times > max_retries:
_LOGGER.info(
"%r operation has exhausted retry. Last exception: %r.",
self._name,
last_exception,
)
raise last_exception

if self._message_buffer:
while self._message_buffer:
if batch:
events_for_callback = [] # type: List[EventData]
for _ in range(min(max_batch_size, len(self._message_buffer))):
events_for_callback.append(self._next_message_in_buffer())
await self._on_event_received(events_for_callback)
else:
await self._on_event_received(self._next_message_in_buffer())
elif max_wait_time:
if batch:
await self._on_event_received([])
else:
await self._on_event_received(None)
):
raise await self._handle_exception(exception)
if not self.running: # exit by close
return
if self._last_received_event:
self._offset = self._last_received_event.offset
last_exception = await self._handle_exception(exception)
retried_times += 1
if retried_times > max_retries:
_LOGGER.info(
"%r operation has exhausted retry. Last exception: %r.",
self._name,
last_exception,
)
raise last_exception

async def receive(self, batch=False, max_batch_size=300, max_wait_time=None) -> None:
self._callback_task_run = True
self._last_callback_called_time = time.time()
callback_task = asyncio.ensure_future(self._callback_task(batch, max_batch_size, max_wait_time))
receive_task = asyncio.ensure_future(self._receive_task())

try:
await receive_task
finally:
self._callback_task_run = False
await callback_task
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,9 @@ async def on_event_received(event):
await consumer._handler.do_work_async()
assert consumer._handler._connection.state == constants.ConnectionState.END

duration = 10
now_time = time.time()
end_time = now_time + duration

while now_time < end_time:
await consumer.receive()
await asyncio.sleep(0.01)
now_time = time.time()
try:
await asyncio.wait_for(consumer.receive(), timeout=10)
except asyncio.TimeoutError:
pass

assert on_event_received.event.body_as_str() == "Event"

0 comments on commit 2805d00

Please sign in to comment.