Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EventHubs] kwargs/error testing #27065

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,14 @@ def _management_request(
return response
raise self._amqp_transport.get_error(status_code, description)
except Exception as exception: # pylint: disable=broad-except
# is_consumer=True passed in here, ALTHOUGH this method is shared by the producer and consumer.
# is_consumer will only be checked if FileNotFoundError is raised by self.mgmt_client.open() due to
# invalid/non-existent connection_verify filepath. The producer will encounter the FileNotFoundError
# when opening the SendClient, so is_consumer=True will not be passed to amqp_transport.handle_exception
# there. This is for uamqp exception parity, which raises FileNotFoundError in the consumer and
# EventHubError in the producer. TODO: Remove `is_consumer` kwarg when resolving issue #27128.
last_exception = self._amqp_transport._handle_exception( # pylint: disable=protected-access
exception, self
exception, self, is_consumer=True
)
self._backoff(
retried_times=retried_times, last_exception=last_exception
Expand Down Expand Up @@ -553,10 +559,10 @@ def _close_connection(self):
self._close_handler()
self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access

def _handle_exception(self, exception):
def _handle_exception(self, exception, *, is_consumer=False):
exception = self._amqp_transport.check_timeout_exception(self, exception)
return self._amqp_transport._handle_exception( # pylint: disable=protected-access
exception, self
exception, self, is_consumer=is_consumer
)

def _do_retryable_operation(self, operation, timeout=None, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,13 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None):
break
except Exception as exception: # pylint: disable=broad-except
self._amqp_transport.check_link_stolen(self, exception)
# TODO: below block hangs when retry_total > 0
# need to remove/refactor, issue #27137
if not self.running: # exit by close
return
if self._last_received_event:
self._offset = self._last_received_event.offset
last_exception = self._handle_exception(exception)
last_exception = self._handle_exception(exception, is_consumer=True)
retried_times += 1
if retried_times > max_retries:
_LOGGER.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def _do_receive(self, partition_id, consumer):
error,
)
self._process_error(self._partition_contexts[partition_id], error)
# TODO: close consumer if non-retryable. issue #27137
# Does OWNERSHIP_LOST make sense for all errors?
self._close_consumer(partition_id, consumer, CloseReason.OWNERSHIP_LOST)

def start(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,15 @@ def _connect(self):
else:
self._set_state(ConnectionState.HDR_SENT)
except (OSError, IOError, SSLError, socket.error) as exc:
# FileNotFoundError is being raised for exception parity with uamqp when invalid
# `connection_verify` file path is passed in. Remove later when resolving issue #27128.
if isinstance(exc, FileNotFoundError) and exc.filename and "ca_certs" in exc.filename:
raise
raise AMQPConnectionError(
ErrorCondition.SocketError,
description="Failed to initiate the connection due to exception: " + str(exc),
error=exc,
)
except Exception: # pylint:disable=try-except-raise
raise

def _disconnect(self):
# type: () -> None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
TransportType,
AMQP_WS_SUBPROTOCOL,
)
from .error import AuthenticationException, ErrorCondition


try:
Expand Down Expand Up @@ -551,7 +552,15 @@ def _wrap_socket_sni( # pylint: disable=no-self-use
}

# TODO: We need to refactor this.
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
try:
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# Assuming that this must be ca_certs, since this is the only file path that
# users can pass in (`connection_verify` in the EH/SB clients) through opts above.
# For uamqp exception parity. Remove later when resolving issue #27128.
exc.filename = {"ca_certs": ca_certs}
raise exc
# Set SNI headers if supported
if (
(server_hostname is not None)
Expand Down Expand Up @@ -684,7 +693,12 @@ def connect(self):
if username or password:
http_proxy_auth = (username, password)
try:
from websocket import create_connection
from websocket import (
create_connection,
WebSocketAddressException,
WebSocketTimeoutException,
WebSocketConnectionClosedException
)

self.ws = create_connection(
url="wss://{}".format(self._custom_endpoint or self._host),
Expand All @@ -696,6 +710,25 @@ def connect(self):
http_proxy_port=http_proxy_port,
http_proxy_auth=http_proxy_auth,
)
except WebSocketAddressException as exc:
raise AuthenticationException(
ErrorCondition.ClientError,
description="Failed to authenticate the connection due to exception: " + str(exc),
error=exc,
)
# TODO: resolve pylance error when type: ignore is removed below, issue #22051
except (WebSocketTimeoutException, SSLError, WebSocketConnectionClosedException) as exc: # type: ignore
self.close()
if isinstance(exc, WebSocketTimeoutException):
message = f'Send timed out ({str(exc)})'
elif isinstance(exc, SSLError):
message = f'Send disconnected by SSL ({str(exc)})'
else:
message = f'Send disconnected ({str(exc)})'
raise ConnectionError(message)
except (OSError, IOError, SSLError):
self.close()
raise
except ImportError:
raise ValueError(
"Please install websocket-client library to use websocket transport."
Expand All @@ -722,7 +755,12 @@ def _read(self, n, initial=False, buffer=None, _errnos=None): # pylint: disable
n = 0
return view
except WebSocketTimeoutException as wte:
raise ConnectionError('recv timed out (%s)' % wte)
raise ConnectionError('Receive timed out (%s)' % wte)

def close(self):
if self.ws:
self._shutdown_transport()
self.ws = None

def _shutdown_transport(self):
# TODO Sync and Async close functions named differently
Expand All @@ -739,9 +777,9 @@ def _write(self, s):
try:
self.ws.send_binary(s)
except WebSocketTimeoutException as e:
raise ConnectionError('send timed out (%s)' % e)
raise ConnectionError('Send timed out (%s)' % e)
except SSLError as e:
raise ConnectionError('send disconnected by SSL (%s)' % e)
raise ConnectionError('Send disconnected by SSL (%s)' % e)
except WebSocketConnectionClosedException as e:
raise ConnectionError('send disconnected (%s)' % e)
raise ConnectionError('Send disconnected (%s)' % e)

Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ async def _connect(self):
else:
await self._set_state(ConnectionState.HDR_SENT)
except (OSError, IOError, SSLError, socket.error, asyncio.TimeoutError) as exc:
# FileNotFoundError is being raised for exception parity with uamqp when invalid
# `connection_verify` file path is passed in. Remove later when resolving issue #27128.
if isinstance(exc, FileNotFoundError) and exc.filename and "ca_certs" in exc.filename:
raise
raise AMQPConnectionError(
ErrorCondition.SocketError,
description="Failed to initiate the connection due to exception: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
AMQP_PORT,
TIMEOUT_INTERVAL,
)
from ..error import AuthenticationException, ErrorCondition


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -178,6 +179,11 @@ def _build_ssl_opts(self, sslopts):
if (certfile is not None) and (keyfile is not None):
context.load_cert_chain(certfile, keyfile)
return context
ca_certs = sslopts.get("ca_certs")
if ca_certs:
context = ssl.SSLContext(ssl_version)
context.load_verify_locations(ca_certs)
return context
return True
except TypeError:
raise TypeError(
Expand Down Expand Up @@ -220,9 +226,8 @@ def __init__(

self.connect_timeout = connect_timeout
self.socket_settings = socket_settings
self.loop = asyncio.get_running_loop()
self.socket_lock = asyncio.Lock()
self.sslopts = self._build_ssl_opts(ssl_opts)
self.sslopts = ssl_opts

async def connect(self):
try:
Expand Down Expand Up @@ -263,7 +268,7 @@ async def _connect(self, host, port, timeout):
for n, family in enumerate(addr_types):
# first, resolve the address for a single address family
try:
entries = await self.loop.getaddrinfo(
entries = await asyncio.get_event_loop().getaddrinfo(
host, port, family=family, type=socket.SOCK_STREAM, proto=SOL_TCP
)
entries_num = len(entries)
Expand All @@ -285,7 +290,7 @@ async def _connect(self, host, port, timeout):
except NotImplementedError:
pass
self.sock.settimeout(timeout)
await self.loop.sock_connect(self.sock, sa)
await asyncio.get_event_loop().sock_connect(self.sock, sa)
except socket.error as ex:
e = ex
if self.sock is not None:
Expand All @@ -302,6 +307,17 @@ def _init_socket(self, socket_settings):
self.sock.settimeout(None) # set socket back to blocking mode
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self._set_socket_options(socket_settings)
try:
# Building ssl opts here instead of constructor, so that invalid cert error is raised
# when client is connecting, rather then during creation. For uamqp exception parity.
self.sslopts = self._build_ssl_opts(self.sslopts)
except FileNotFoundError as exc:
# FileNotFoundError does not have missing filename info, so adding it below.
# Assuming that this must be ca_certs, since this is the only file path that
# users can pass in (`connection_verify` in the EH/SB clients) through sslopts above.
# For uamqp exception parity. Remove later when resolving issue #27128.
exc.filename = self.sslopts
raise exc
self.sock.settimeout(1) # set socket back to non-blocking mode

def _get_tcp_socket_defaults(self, sock): # pylint: disable=no-self-use
Expand Down Expand Up @@ -386,6 +402,7 @@ async def _write(self, s):

async def close(self):
if self.writer is not None:
# Closing the writer closes the underlying socket.
self.writer.close()
if self.sslopts:
# see issue: https://github.com/encode/httpx/issues/914
Expand Down Expand Up @@ -444,7 +461,7 @@ def __init__(
):
self._read_buffer = BytesIO()
self.socket_lock = asyncio.Lock()
self.sslopts = self._build_ssl_opts(ssl_opts) if isinstance(ssl_opts, dict) else None
self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else None
self._connect_timeout = connect_timeout or TIMEOUT_INTERVAL
self._custom_endpoint = kwargs.get("custom_endpoint")
self.host, self.port = to_host_port(host, port)
Expand All @@ -454,6 +471,7 @@ def __init__(
self.connected = False

async def connect(self):
self.sslopts = self._build_ssl_opts(self.sslopts)
username, password = None, None
http_proxy_host, http_proxy_port = None, None
http_proxy_auth = None
Expand All @@ -467,7 +485,7 @@ async def connect(self):
password = self._http_proxy.get("password", None)

try:
from aiohttp import ClientSession
from aiohttp import ClientSession, ClientConnectorError
from urllib.parse import urlsplit

if username or password:
Expand All @@ -483,26 +501,33 @@ async def connect(self):
parsed_url = urlsplit(url)
url = f"{parsed_url.scheme}://{parsed_url.netloc}:{self.port}{parsed_url.path}"

# Enabling heartbeat that sends a ping message every n seconds and waits for pong response.
# if pong response is not received then close connection. This raises an error when trying
# to communicate with the websocket which is no longer active.
# We are waiting a bug fix in aiohttp for these 2 bugs where aiohttp ws might hang on network disconnect
# and the heartbeat mechanism helps mitigate these two.
# https://github.com/aio-libs/aiohttp/pull/5860
# https://github.com/aio-libs/aiohttp/issues/2309

self.ws = await self.session.ws_connect(
url=url,
timeout=self._connect_timeout,
protocols=[AMQP_WS_SUBPROTOCOL],
autoclose=False,
proxy=http_proxy_host,
proxy_auth=http_proxy_auth,
ssl=self.sslopts,
heartbeat=DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS,
)
try:
# Enabling heartbeat that sends a ping message every n seconds and waits for pong response.
# if pong response is not received then close connection. This raises an error when trying
# to communicate with the websocket which is no longer active.
# We are waiting a bug fix in aiohttp for these 2 bugs where aiohttp ws might hang on network disconnect
# and the heartbeat mechanism helps mitigate these two.
# https://github.com/aio-libs/aiohttp/pull/5860
# https://github.com/aio-libs/aiohttp/issues/2309

self.ws = await self.session.ws_connect(
url=url,
timeout=self._connect_timeout,
protocols=[AMQP_WS_SUBPROTOCOL],
autoclose=False,
proxy=http_proxy_host,
proxy_auth=http_proxy_auth,
ssl=self.sslopts,
heartbeat=DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS,
)
except ClientConnectorError as exc:
if self._custom_endpoint:
raise AuthenticationException(
ErrorCondition.ClientError,
description="Failed to authenticate the connection due to exception: " + str(exc),
error=exc,
)
self.connected = True

except ImportError:
raise ValueError(
"Please install aiohttp library to use websocket transport."
Expand Down Expand Up @@ -532,7 +557,7 @@ async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-argum
n = 0
return view
except asyncio.TimeoutError as te:
raise ConnectionError('recv timed out (%s)' % te)
raise ConnectionError('Receive timed out (%s)' % te)

async def close(self):
"""Do any preliminary work in shutting down the connection."""
Expand All @@ -549,4 +574,4 @@ async def write(self, s):
try:
await self.ws.send_bytes(s)
except asyncio.TimeoutError as te:
raise ConnectionError('send timed out (%s)' % te)
raise ConnectionError('Send timed out (%s)' % te)
Loading