Skip to content

Commit

Permalink
[EventHubs] kwargs/error testing (#27065)
Browse files Browse the repository at this point in the history
* adding tests

* add auth/connection tests + fixes

* fix connection verify error handling

* revert consumer retry change

* call ws close in sync transport

* typo

* fix ws exc import

* fix async transport

* fix link detach vendor error exception parity

* add operationtimeouterror

* add more negative tests

* annas comments + lint

* lint + tests

* add ids for uamqp vs pyamqp tests

* update tests

* skip macos tests
  • Loading branch information
swathipil authored Nov 15, 2022
1 parent 0a14a3c commit 4666e53
Show file tree
Hide file tree
Showing 23 changed files with 781 additions and 91 deletions.
12 changes: 9 additions & 3 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,14 @@ def _management_request(
return response
raise self._amqp_transport.get_error(status_code, description)
except Exception as exception: # pylint: disable=broad-except
# is_consumer=True passed in here, ALTHOUGH this method is shared by the producer and consumer.
# is_consumer will only be checked if FileNotFoundError is raised by self.mgmt_client.open() due to
# invalid/non-existent connection_verify filepath. The producer will encounter the FileNotFoundError
# when opening the SendClient, so is_consumer=True will not be passed to amqp_transport.handle_exception
# there. This is for uamqp exception parity, which raises FileNotFoundError in the consumer and
# EventHubError in the producer. TODO: Remove `is_consumer` kwarg when resolving issue #27128.
last_exception = self._amqp_transport._handle_exception( # pylint: disable=protected-access
exception, self
exception, self, is_consumer=True
)
self._backoff(
retried_times=retried_times, last_exception=last_exception
Expand Down Expand Up @@ -553,10 +559,10 @@ def _close_connection(self):
self._close_handler()
self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access

def _handle_exception(self, exception):
def _handle_exception(self, exception, *, is_consumer=False):
exception = self._amqp_transport.check_timeout_exception(self, exception)
return self._amqp_transport._handle_exception( # pylint: disable=protected-access
exception, self
exception, self, is_consumer=is_consumer
)

def _do_retryable_operation(self, operation, timeout=None, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,13 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None):
break
except Exception as exception: # pylint: disable=broad-except
self._amqp_transport.check_link_stolen(self, exception)
# TODO: below block hangs when retry_total > 0
# need to remove/refactor, issue #27137
if not self.running: # exit by close
return
if self._last_received_event:
self._offset = self._last_received_event.offset
last_exception = self._handle_exception(exception)
last_exception = self._handle_exception(exception, is_consumer=True)
retried_times += 1
if retried_times > max_retries:
_LOGGER.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def _do_receive(self, partition_id, consumer):
error,
)
self._process_error(self._partition_contexts[partition_id], error)
# TODO: close consumer if non-retryable. issue #27137
# Does OWNERSHIP_LOST make sense for all errors?
self._close_consumer(partition_id, consumer, CloseReason.OWNERSHIP_LOST)

def start(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,15 @@ def _connect(self):
else:
self._set_state(ConnectionState.HDR_SENT)
except (OSError, IOError, SSLError, socket.error) as exc:
# FileNotFoundError is being raised for exception parity with uamqp when invalid
# `connection_verify` file path is passed in. Remove later when resolving issue #27128.
if isinstance(exc, FileNotFoundError) and exc.filename and "ca_certs" in exc.filename:
raise
raise AMQPConnectionError(
ErrorCondition.SocketError,
description="Failed to initiate the connection due to exception: " + str(exc),
error=exc,
)
except Exception: # pylint:disable=try-except-raise
raise

def _disconnect(self):
# type: () -> None
Expand Down
50 changes: 44 additions & 6 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
TransportType,
AMQP_WS_SUBPROTOCOL,
)
from .error import AuthenticationException, ErrorCondition


try:
Expand Down Expand Up @@ -551,7 +552,15 @@ def _wrap_socket_sni( # pylint: disable=no-self-use
}

# TODO: We need to refactor this.
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
try:
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# Assuming that this must be ca_certs, since this is the only file path that
# users can pass in (`connection_verify` in the EH/SB clients) through opts above.
# For uamqp exception parity. Remove later when resolving issue #27128.
exc.filename = {"ca_certs": ca_certs}
raise exc
# Set SNI headers if supported
if (
(server_hostname is not None)
Expand Down Expand Up @@ -684,7 +693,12 @@ def connect(self):
if username or password:
http_proxy_auth = (username, password)
try:
from websocket import create_connection
from websocket import (
create_connection,
WebSocketAddressException,
WebSocketTimeoutException,
WebSocketConnectionClosedException
)

self.ws = create_connection(
url="wss://{}".format(self._custom_endpoint or self._host),
Expand All @@ -696,6 +710,25 @@ def connect(self):
http_proxy_port=http_proxy_port,
http_proxy_auth=http_proxy_auth,
)
except WebSocketAddressException as exc:
raise AuthenticationException(
ErrorCondition.ClientError,
description="Failed to authenticate the connection due to exception: " + str(exc),
error=exc,
)
# TODO: resolve pylance error when type: ignore is removed below, issue #22051
except (WebSocketTimeoutException, SSLError, WebSocketConnectionClosedException) as exc: # type: ignore
self.close()
if isinstance(exc, WebSocketTimeoutException):
message = f'Send timed out ({str(exc)})'
elif isinstance(exc, SSLError):
message = f'Send disconnected by SSL ({str(exc)})'
else:
message = f'Send disconnected ({str(exc)})'
raise ConnectionError(message)
except (OSError, IOError, SSLError):
self.close()
raise
except ImportError:
raise ValueError(
"Please install websocket-client library to use websocket transport."
Expand All @@ -722,7 +755,12 @@ def _read(self, n, initial=False, buffer=None, _errnos=None): # pylint: disable
n = 0
return view
except WebSocketTimeoutException as wte:
raise ConnectionError('recv timed out (%s)' % wte)
raise ConnectionError('Receive timed out (%s)' % wte)

def close(self):
if self.ws:
self._shutdown_transport()
self.ws = None

def _shutdown_transport(self):
# TODO Sync and Async close functions named differently
Expand All @@ -739,9 +777,9 @@ def _write(self, s):
try:
self.ws.send_binary(s)
except WebSocketTimeoutException as e:
raise ConnectionError('send timed out (%s)' % e)
raise ConnectionError('Send timed out (%s)' % e)
except SSLError as e:
raise ConnectionError('send disconnected by SSL (%s)' % e)
raise ConnectionError('Send disconnected by SSL (%s)' % e)
except WebSocketConnectionClosedException as e:
raise ConnectionError('send disconnected (%s)' % e)
raise ConnectionError('Send disconnected (%s)' % e)

Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ async def _connect(self):
else:
await self._set_state(ConnectionState.HDR_SENT)
except (OSError, IOError, SSLError, socket.error, asyncio.TimeoutError) as exc:
# FileNotFoundError is being raised for exception parity with uamqp when invalid
# `connection_verify` file path is passed in. Remove later when resolving issue #27128.
if isinstance(exc, FileNotFoundError) and exc.filename and "ca_certs" in exc.filename:
raise
raise AMQPConnectionError(
ErrorCondition.SocketError,
description="Failed to initiate the connection due to exception: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
AMQP_PORT,
TIMEOUT_INTERVAL,
)
from ..error import AuthenticationException, ErrorCondition


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -178,6 +179,11 @@ def _build_ssl_opts(self, sslopts):
if (certfile is not None) and (keyfile is not None):
context.load_cert_chain(certfile, keyfile)
return context
ca_certs = sslopts.get("ca_certs")
if ca_certs:
context = ssl.SSLContext(ssl_version)
context.load_verify_locations(ca_certs)
return context
return True
except TypeError:
raise TypeError(
Expand Down Expand Up @@ -220,9 +226,8 @@ def __init__(

self.connect_timeout = connect_timeout
self.socket_settings = socket_settings
self.loop = asyncio.get_running_loop()
self.socket_lock = asyncio.Lock()
self.sslopts = self._build_ssl_opts(ssl_opts)
self.sslopts = ssl_opts

async def connect(self):
try:
Expand Down Expand Up @@ -263,7 +268,7 @@ async def _connect(self, host, port, timeout):
for n, family in enumerate(addr_types):
# first, resolve the address for a single address family
try:
entries = await self.loop.getaddrinfo(
entries = await asyncio.get_event_loop().getaddrinfo(
host, port, family=family, type=socket.SOCK_STREAM, proto=SOL_TCP
)
entries_num = len(entries)
Expand All @@ -285,7 +290,7 @@ async def _connect(self, host, port, timeout):
except NotImplementedError:
pass
self.sock.settimeout(timeout)
await self.loop.sock_connect(self.sock, sa)
await asyncio.get_event_loop().sock_connect(self.sock, sa)
except socket.error as ex:
e = ex
if self.sock is not None:
Expand All @@ -302,6 +307,17 @@ def _init_socket(self, socket_settings):
self.sock.settimeout(None) # set socket back to blocking mode
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self._set_socket_options(socket_settings)
try:
# Building ssl opts here instead of constructor, so that invalid cert error is raised
# when client is connecting, rather then during creation. For uamqp exception parity.
self.sslopts = self._build_ssl_opts(self.sslopts)
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# Assuming that this must be ca_certs, since this is the only file path that
# users can pass in (`connection_verify` in the EH/SB clients) through sslopts above.
# For uamqp exception parity. Remove later when resolving issue #27128.
exc.filename = self.sslopts
raise exc
self.sock.settimeout(1) # set socket back to non-blocking mode

def _get_tcp_socket_defaults(self, sock): # pylint: disable=no-self-use
Expand Down Expand Up @@ -386,6 +402,7 @@ async def _write(self, s):

async def close(self):
if self.writer is not None:
# Closing the writer closes the underlying socket.
self.writer.close()
if self.sslopts:
# see issue: https://github.com/encode/httpx/issues/914
Expand Down Expand Up @@ -444,7 +461,7 @@ def __init__(
):
self._read_buffer = BytesIO()
self.socket_lock = asyncio.Lock()
self.sslopts = self._build_ssl_opts(ssl_opts) if isinstance(ssl_opts, dict) else None
self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else None
self._connect_timeout = connect_timeout or TIMEOUT_INTERVAL
self._custom_endpoint = kwargs.get("custom_endpoint")
self.host, self.port = to_host_port(host, port)
Expand All @@ -454,6 +471,7 @@ def __init__(
self.connected = False

async def connect(self):
self.sslopts = self._build_ssl_opts(self.sslopts)
username, password = None, None
http_proxy_host, http_proxy_port = None, None
http_proxy_auth = None
Expand All @@ -467,7 +485,7 @@ async def connect(self):
password = self._http_proxy.get("password", None)

try:
from aiohttp import ClientSession
from aiohttp import ClientSession, ClientConnectorError
from urllib.parse import urlsplit

if username or password:
Expand All @@ -483,26 +501,33 @@ async def connect(self):
parsed_url = urlsplit(url)
url = f"{parsed_url.scheme}://{parsed_url.netloc}:{self.port}{parsed_url.path}"

# Enabling heartbeat that sends a ping message every n seconds and waits for pong response.
# if pong response is not received then close connection. This raises an error when trying
# to communicate with the websocket which is no longer active.
# We are waiting a bug fix in aiohttp for these 2 bugs where aiohttp ws might hang on network disconnect
# and the heartbeat mechanism helps mitigate these two.
# https://github.com/aio-libs/aiohttp/pull/5860
# https://github.com/aio-libs/aiohttp/issues/2309

self.ws = await self.session.ws_connect(
url=url,
timeout=self._connect_timeout,
protocols=[AMQP_WS_SUBPROTOCOL],
autoclose=False,
proxy=http_proxy_host,
proxy_auth=http_proxy_auth,
ssl=self.sslopts,
heartbeat=DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS,
)
try:
# Enabling heartbeat that sends a ping message every n seconds and waits for pong response.
# if pong response is not received then close connection. This raises an error when trying
# to communicate with the websocket which is no longer active.
# We are waiting a bug fix in aiohttp for these 2 bugs where aiohttp ws might hang on network disconnect
# and the heartbeat mechanism helps mitigate these two.
# https://github.com/aio-libs/aiohttp/pull/5860
# https://github.com/aio-libs/aiohttp/issues/2309

self.ws = await self.session.ws_connect(
url=url,
timeout=self._connect_timeout,
protocols=[AMQP_WS_SUBPROTOCOL],
autoclose=False,
proxy=http_proxy_host,
proxy_auth=http_proxy_auth,
ssl=self.sslopts,
heartbeat=DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS,
)
except ClientConnectorError as exc:
if self._custom_endpoint:
raise AuthenticationException(
ErrorCondition.ClientError,
description="Failed to authenticate the connection due to exception: " + str(exc),
error=exc,
)
self.connected = True

except ImportError:
raise ValueError(
"Please install aiohttp library to use websocket transport."
Expand Down Expand Up @@ -532,7 +557,7 @@ async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-argum
n = 0
return view
except asyncio.TimeoutError as te:
raise ConnectionError('recv timed out (%s)' % te)
raise ConnectionError('Receive timed out (%s)' % te)

async def close(self):
"""Do any preliminary work in shutting down the connection."""
Expand All @@ -549,4 +574,4 @@ async def write(self, s):
try:
await self.ws.send_bytes(s)
except asyncio.TimeoutError as te:
raise ConnectionError('send timed out (%s)' % te)
raise ConnectionError('Send timed out (%s)' % te)
Loading

0 comments on commit 4666e53

Please sign in to comment.