From b6f8b721d3d0ccb9724fa370f80a080bc15d94d0 Mon Sep 17 00:00:00 2001 From: swathipil <76007337+swathipil@users.noreply.github.com> Date: Fri, 13 May 2022 20:00:36 -0700 Subject: [PATCH] Adding back Rakshith's websocket changes (#24410) * Adding back Rakshith's sync websocket changes * fix async send and receive * fix transport bugs * add websocket to dev reqs + async fix hostname * thank you kashif * fix tests + turn on websocket tests * update consumer test timing --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 6 +- .../azure/eventhub/_client_base.py | 11 +- .../azure/eventhub/_consumer.py | 11 +- .../azure/eventhub/_producer.py | 8 +- .../azure/eventhub/_pyamqp/_connection.py | 23 +- .../azure/eventhub/_pyamqp/_transport.py | 80 ++++++- .../eventhub/_pyamqp/aio/_client_async.py | 5 +- .../eventhub/_pyamqp/aio/_connection_async.py | 21 +- .../azure/eventhub/_pyamqp/aio/_sasl_async.py | 46 +++- .../eventhub/_pyamqp/aio/_transport_async.py | 215 ++++++++++++------ .../azure/eventhub/_pyamqp/client.py | 13 +- .../azure/eventhub/_pyamqp/constants.py | 19 ++ .../azure/eventhub/_pyamqp/sasl.py | 76 ++++--- .../azure-eventhub/azure/eventhub/_version.py | 2 +- .../azure/eventhub/aio/_client_base_async.py | 9 +- .../azure/eventhub/aio/_consumer_async.py | 11 +- .../azure/eventhub/aio/_producer_async.py | 8 +- .../eventhub/aio/_producer_client_async.py | 4 - .../azure-eventhub/dev_requirements.txt | 1 + .../livetest/asynctests/test_send_async.py | 1 - .../synctests/test_consumer_client.py | 2 +- .../tests/livetest/synctests/test_send.py | 18 +- 22 files changed, 443 insertions(+), 147 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index fd6783749385..d72a775683b3 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -4,11 +4,7 @@ ### Features Added -### Breaking Changes - -### Bugs Fixed - -### Other Changes +- Added support for connection using websocket and http proxy. ## 5.8.0a3 (2022-03-08) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 7800e1ce04e0..e1fa271fd398 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -333,8 +333,6 @@ def _create_auth(self): functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, custom_endpoint_hostname=self._config.custom_endpoint_hostname, port=self._config.connection_port, verify=self._config.connection_verify, @@ -379,8 +377,15 @@ def _management_request(self, mgmt_msg, op_type): last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = self._create_auth() + hostname = self._address.hostname + if self._config.transport_type.name == 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' mgmt_client = AMQPClient( - self._address.hostname, auth=mgmt_auth, debug=self._config.network_tracing + hostname, + auth=mgmt_auth, + debug=self._config.network_tracing, + transport_type=self._config.transport_type, + http_proxy=self._config.http_proxy ) try: mgmt_client.open() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index ddb9a14a166f..677954c4937e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -150,8 +150,13 @@ def _create_handler(self, auth): ) desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + transport_type = self._client._config.transport_type # pylint:disable=protected-access + hostname = urlparse(source.address).hostname + if transport_type.name == 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' + self._handler = ReceiveClient( - urlparse(source.address).hostname, + hostname, source, auth=auth, idle_timeout=self._idle_timeout, @@ -164,7 +169,9 @@ def _create_handler(self, auth): properties=create_properties(self._client._config.user_agent), # pylint:disable=protected-access desired_capabilities=desired_capabilities, streaming_receive=True, - message_received_callback=self._message_received + message_received_callback=self._message_received, + transport_type=transport_type, + http_proxy=self._client._config.http_proxy # pylint:disable=protected-access ) def _open_with_retry(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index a72ce0753980..06ffd4733ee3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -125,8 +125,12 @@ def __init__(self, client, target, **kwargs): def _create_handler(self, auth): # type: (JWTTokenAuth) -> None + transport_type = self._client._config.transport_type # pylint:disable=protected-access + hostname = self._client._address.hostname # pylint: disable=protected-access + if transport_type.name == 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = SendClient( - self._client._address.hostname, # pylint: disable=protected-access + hostname, # pylint: disable=protected-access self._target, auth=auth, idle_timeout=self._idle_timeout, @@ -136,6 +140,8 @@ def _create_handler(self, auth): client_name=self._name, link_properties=self._link_properties, properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access + transport_type=transport_type, + http_proxy=self._client._config.http_proxy # pylint: disable=protected-access ) def _open_with_retry(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index a26d220f3286..c73417d1e56f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -12,7 +12,7 @@ from ssl import SSLError from ._transport import Transport -from .sasl import SASLTransport +from .sasl import SASLTransport, SASLWithWebSocket from .session import Session from .performatives import OpenFrame, CloseFrame from .constants import ( @@ -22,7 +22,8 @@ MAX_FRAME_SIZE_BYTES, HEADER_FRAME, ConnectionState, - EMPTY_FRAME + EMPTY_FRAME, + TransportType ) from .error import ( @@ -77,12 +78,19 @@ class Connection(object): Default value is `0.1`. :keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames will be logged at the logging.INFO level. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. + Additionally the following keys may also be present: `'username', 'password'`. """ def __init__(self, endpoint, **kwargs): # type(str, Any) -> None parsed_url = urlparse(endpoint) self._hostname = parsed_url.hostname + endpoint = self._hostname if parsed_url.port: self._port = parsed_url.port elif parsed_url.scheme == 'amqps': @@ -92,16 +100,21 @@ def __init__(self, endpoint, **kwargs): self.state = None # type: Optional[ConnectionState] transport = kwargs.get('transport') + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) if transport: self._transport = transport elif 'sasl_credential' in kwargs: - self._transport = SASLTransport( - host=parsed_url.netloc, + sasl_transport = SASLTransport + if self._transport_type.name == 'AmqpOverWebsocket' or kwargs.get("http_proxy"): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path + self._transport = sasl_transport( + host=endpoint, credential=kwargs['sasl_credential'], **kwargs ) else: - self._transport = Transport(parsed_url.netloc, **kwargs) + self._transport = Transport(parsed_url.netloc, transport_type=self._transport_type, **kwargs) self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 85371fdd07d9..1fe302998024 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -51,7 +51,7 @@ from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame -from .constants import TLS_HEADER_FRAME +from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL try: @@ -439,7 +439,7 @@ def write(self, s): def receive_frame(self, *args, **kwargs): try: - header, channel, payload = self.read(**kwargs) + header, channel, payload = self.read(**kwargs) if not payload: decoded = decode_empty_frame(header) else: @@ -646,12 +646,82 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): result, self._read_buffer = rbuf[:n], rbuf[n:] return result - -def Transport(host, connect_timeout=None, ssl=False, **kwargs): +def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs): """Create transport. Given a few parameters from the Connection constructor, select and create a subclass of _AbstractTransport. """ - transport = SSLTransport if ssl else TCPTransport + if transport_type == TransportType.AmqpOverWebsocket: + transport = WebSocketTransport + else: + transport = SSLTransport if ssl else TCPTransport return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + +class WebSocketTransport(_AbstractTransport): + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs): + self.sslopts = ssl if isinstance(ssl, dict) else {} + self._connect_timeout = connect_timeout + self._host = host + super().__init__( + host, port, connect_timeout, **kwargs + ) + self.ws = None + self._http_proxy = kwargs.get('http_proxy', None) + + def connect(self): + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None + if self._http_proxy: + http_proxy_host = self._http_proxy['proxy_hostname'] + http_proxy_port = self._http_proxy['proxy_port'] + username = self._http_proxy.get('username', None) + password = self._http_proxy.get('password', None) + if username or password: + http_proxy_auth = (username, password) + try: + from websocket import create_connection + self.ws = create_connection( + url="wss://{}".format(self._host), + subprotocols=[AMQP_WS_SUBPROTOCOL], + timeout=self._connect_timeout, + skip_utf8_validation=True, + sslopt=self.sslopts, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_proxy_auth=http_proxy_auth + ) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") + + def _read(self, n, initial=False, buffer=None, **kwargs): # pylint: disable=unused-arguments + """Read exactly n bytes from the peer.""" + + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + n -= nbytes + while n: + data = self.ws.recv() + + if len(data) <= n: + view[length: length + len(data)] = data + n -= len(data) + else: + view[length: length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 + + return view + + def _shutdown_transport(self): + """Do any preliminary work in shutting down the connection.""" + self.ws.close() + + def _write(self, s): + """Completely write a string to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + self.ws.send_binary(s) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index e1b88b192690..863285f7ca59 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -21,7 +21,6 @@ from ._receiver_async import ReceiverLink from ._sender_async import SenderLink from ._session_async import Session -from ._sasl_async import SASLTransport from ._cbs_async import CBSAuthenticator from ..client import AMQPClient as AMQPClientSync from ..client import ReceiveClient as ReceiveClientSync @@ -201,7 +200,9 @@ async def open_async(self): channel_max=self._channel_max, idle_timeout=self._idle_timeout, properties=self._properties, - network_trace=self._network_trace + network_trace=self._network_trace, + transport_type=self._transport_type, + http_proxy=self._http_proxy ) await self._connection.open() if not self._session: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index 3bfa62569e9a..790b02bae084 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -16,7 +16,7 @@ import asyncio from ._transport_async import AsyncTransport -from ._sasl_async import SASLTransport +from ._sasl_async import SASLTransport, SASLWithWebSocket from ._session_async import Session from ..performatives import OpenFrame, CloseFrame from .._connection import get_local_timeout @@ -27,7 +27,8 @@ MAX_CHANNELS, HEADER_FRAME, ConnectionState, - EMPTY_FRAME + EMPTY_FRAME, + TransportType ) from ..error import ( @@ -58,11 +59,19 @@ class Connection(object): :param list(str) offered_capabilities: The extension capabilities the sender supports. :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports :param dict properties: Connection properties. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. + Additionally the following keys may also be present: `'username', 'password'`. """ def __init__(self, endpoint, **kwargs): parsed_url = urlparse(endpoint) self.hostname = parsed_url.hostname + endpoint = self.hostname + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) if parsed_url.port: self.port = parsed_url.port elif parsed_url.scheme == 'amqps': @@ -75,8 +84,12 @@ def __init__(self, endpoint, **kwargs): if transport: self.transport = transport elif 'sasl_credential' in kwargs: - self.transport = SASLTransport( - host=parsed_url.netloc, + sasl_transport = SASLTransport + if self._transport_type.name == "AmqpOverWebsocket" or kwargs.get("http_proxy"): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path + self.transport = sasl_transport( + host=endpoint, credential=kwargs['sasl_credential'], **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index dda1931b909b..88ee25917c7c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -7,9 +7,9 @@ import struct from enum import Enum -from ._transport_async import AsyncTransport +from ._transport_async import AsyncTransport, WebSocketTransportAsync from ..types import AMQPTypes, TYPE, VALUE -from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME +from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT, TransportType from .._transport import AMQPS_PORT from ..performatives import ( SASLOutcome, @@ -73,14 +73,8 @@ def start(self): return b'' -class SASLTransport(AsyncTransport): - - def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): - self.credential = credential - ssl = ssl or True - super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) - - async def negotiate(self): +class SASLTransportMixinAsync(): + async def _negotiate(self): await self.write(SASL_HEADER_FRAME) _, returned_header = await self.receive_frame() if returned_header[1] != SASL_HEADER_FRAME: @@ -104,3 +98,35 @@ async def negotiate(self): return else: raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + + +class SASLTransport(AsyncTransport, SASLTransportMixinAsync): + + def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): + self.credential = credential + ssl = ssl or True + super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + + async def negotiate(self): + await self._negotiate() + + +class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync): + def __init__( + self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + ): + self.credential = credential + ssl = ssl or True + http_proxy = kwargs.pop('http_proxy', None) + self._transport = WebSocketTransportAsync( + host, + port=port, + connect_timeout=connect_timeout, + ssl=ssl, + http_proxy=http_proxy, + **kwargs + ) + super().__init__(host, port, connect_timeout, ssl, **kwargs) + + async def negotiate(self): + await self._negotiate() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index acbdd8af8e76..7f586bec9e5e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -49,7 +49,7 @@ from .._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from .._encode import encode_frame from .._decode import decode_frame, decode_empty_frame -from ..constants import TLS_HEADER_FRAME +from ..constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, AMQP_WS_SUBPROTOCOL from .._transport import ( AMQP_FRAME, get_errno, @@ -59,7 +59,8 @@ SIGNED_INT_MAX, _UNAVAIL, set_cloexec, - AMQP_PORT + AMQP_PORT, + WebSocketTransport ) @@ -82,7 +83,73 @@ def get_running_loop(): return loop -class AsyncTransport(object): +class AsyncTransportMixin(): + async def receive_frame(self, *args, **kwargs): + try: + header, channel, payload = await self.read(**kwargs) + if not payload: + decoded = decode_empty_frame(header) + else: + decoded = decode_frame(payload) + # TODO: Catch decode error and return amqp:decode-error + #_LOGGER.info("ICH%d <- %r", channel, decoded) + return channel, decoded + except (socket.timeout, asyncio.IncompleteReadError, asyncio.TimeoutError): + return None, None + + async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? + async with self.socket_lock: + read_frame_buffer = BytesIO() + try: + frame_header = memoryview(bytearray(8)) + read_frame_buffer.write(await self._read(8, buffer=frame_header, initial=True)) + + channel = struct.unpack('>H', frame_header[6:])[0] + size = frame_header[0:4] + if size == AMQP_FRAME: # Empty frame or AMQP header negotiation + return frame_header, channel, None + size = struct.unpack('>I', size)[0] + offset = frame_header[4] + frame_type = frame_header[5] + + # >I is an unsigned int, but the argument to sock.recv is signed, + # so we know the size can be at most 2 * SIGNED_INT_MAX + payload_size = size - len(frame_header) + payload = memoryview(bytearray(payload_size)) + if size > SIGNED_INT_MAX: + read_frame_buffer.write(await self._read(SIGNED_INT_MAX, buffer=payload)) + read_frame_buffer.write(await self._read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) + else: + read_frame_buffer.write(await self._read(payload_size, buffer=payload)) + except (socket.timeout, asyncio.IncompleteReadError): + read_frame_buffer.write(self._read_buffer.getvalue()) + self._read_buffer = read_frame_buffer + self._read_buffer.seek(0) + raise + except (OSError, IOError, SSLError, socket.error) as exc: + # Don't disconnect for ssl read time outs + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and 'timed out' in str(exc): + raise socket.timeout() + if get_errno(exc) not in _UNAVAIL: + self.connected = False + raise + offset -= 2 + return frame_header, channel, payload[offset:] + + async def send_frame(self, channel, frame, **kwargs): + header, performative = encode_frame(frame, **kwargs) + if performative is None: + data = header + else: + encoded_channel = struct.pack('>H', channel) + data = header + encoded_channel + performative + + await self.write(data) + #_LOGGER.info("OCH%d -> %r", channel, frame) + + +class AsyncTransport(AsyncTransportMixin): """Common superclass for TCP and SSL transports.""" def __init__(self, host, port=AMQP_PORT, connect_timeout=None, @@ -318,46 +385,6 @@ def close(self): self.sock = None self.connected = False - async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? - async with self.socket_lock: - read_frame_buffer = BytesIO() - try: - frame_header = memoryview(bytearray(8)) - read_frame_buffer.write(await self._read(8, buffer=frame_header, initial=True)) - - channel = struct.unpack('>H', frame_header[6:])[0] - size = frame_header[0:4] - if size == AMQP_FRAME: # Empty frame or AMQP header negotiation - return frame_header, channel, None - size = struct.unpack('>I', size)[0] - offset = frame_header[4] - frame_type = frame_header[5] - - # >I is an unsigned int, but the argument to sock.recv is signed, - # so we know the size can be at most 2 * SIGNED_INT_MAX - payload_size = size - len(frame_header) - payload = memoryview(bytearray(payload_size)) - if size > SIGNED_INT_MAX: - read_frame_buffer.write(await self._read(SIGNED_INT_MAX, buffer=payload)) - read_frame_buffer.write(await self._read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) - else: - read_frame_buffer.write(await self._read(payload_size, buffer=payload)) - except (socket.timeout, asyncio.IncompleteReadError): - read_frame_buffer.write(self._read_buffer.getvalue()) - self._read_buffer = read_frame_buffer - self._read_buffer.seek(0) - raise - except (OSError, IOError, SSLError, socket.error) as exc: - # Don't disconnect for ssl read time outs - # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and 'timed out' in str(exc): - raise socket.timeout() - if get_errno(exc) not in _UNAVAIL: - self.connected = False - raise - offset -= 2 - return frame_header, channel, payload[offset:] - async def write(self, s): try: await self._write(s) @@ -368,19 +395,6 @@ async def write(self, s): self.connected = False raise - async def receive_frame(self, *args, **kwargs): - try: - header, channel, payload = await self.read(**kwargs) - if not payload: - decoded = decode_empty_frame(header) - else: - decoded = decode_frame(payload) - # TODO: Catch decode error and return amqp:decode-error - #_LOGGER.info("ICH%d <- %r", channel, decoded) - return channel, decoded - except (socket.timeout, asyncio.IncompleteReadError, asyncio.TimeoutError): - return None, None - async def receive_frame_with_lock(self, *args, **kwargs): try: async with self.socket_lock: @@ -393,17 +407,6 @@ async def receive_frame_with_lock(self, *args, **kwargs): except socket.timeout: return None, None - async def send_frame(self, channel, frame, **kwargs): - header, performative = encode_frame(frame, **kwargs) - if performative is None: - data = header - else: - encoded_channel = struct.pack('>H', channel) - data = header + encoded_channel + performative - - await self.write(data) - #_LOGGER.info("OCH%d -> %r", channel, frame) - async def negotiate(self): if not self.sslopts: return @@ -412,3 +415,81 @@ async def negotiate(self): if returned_header[1] == TLS_HEADER_FRAME: raise ValueError("Mismatching TLS header protocol. Excpected: {}, received: {}".format( TLS_HEADER_FRAME, returned_header[1])) + + +class WebSocketTransportAsync(AsyncTransportMixin): + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + ): + self._read_buffer = BytesIO() + self.loop = get_running_loop() + self.socket_lock = asyncio.Lock() + self.sslopts = ssl if isinstance(ssl, dict) else {} + self._connect_timeout = connect_timeout + self.host = host + self.ws = None + self._http_proxy = kwargs.get('http_proxy', None) + + async def connect(self): + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None + if self._http_proxy: + http_proxy_host = self._http_proxy['proxy_hostname'] + http_proxy_port = self._http_proxy['proxy_port'] + username = self._http_proxy.get('username', None) + password = self._http_proxy.get('password', None) + if username or password: + http_proxy_auth = (username, password) + try: + from websocket import create_connection + self.ws = create_connection( + url="wss://{}".format(self.host), + subprotocols=[AMQP_WS_SUBPROTOCOL], + timeout=self._connect_timeout, + skip_utf8_validation=True, + sslopt=self.sslopts, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_proxy_auth=http_proxy_auth + ) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") + + async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments + """Read exactly n bytes from the peer.""" + + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + n -= nbytes + while n: + data = await self.loop.run_in_executor( + None, self.ws.recv + ) + + if len(data) <= n: + view[length: length + len(data)] = data + n -= len(data) + else: + view[length: length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 + + return view + + def close(self): + """Do any preliminary work in shutting down the connection.""" + # TODO: async close doesn't: + # 1) shutdown socket and close. --> self.sock.shutdown(socket.SHUT_RDWR) and self.sock.close() + # 2) set self.connected = False + # I think we need to do this, like in sync + self.ws.close() + + async def write(self, s): + """Completely write a string to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + await self.loop.run_in_executor( + None, self.ws.send_binary, s + ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 09d3303a2698..25fbb125a4fc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -6,9 +6,7 @@ # pylint: disable=too-many-lines -from collections import namedtuple import logging -import threading import time import uuid import certifi @@ -37,6 +35,7 @@ SenderSettleMode, ReceiverSettleMode, LinkDeliverySettleReason, + TransportType, SEND_DISPOSITION_ACCEPT, SEND_DISPOSITION_REJECT, AUTH_TYPE_CBS, @@ -155,6 +154,12 @@ def __init__(self, hostname, auth=None, **kwargs): self._receive_settle_mode = kwargs.pop('receive_settle_mode', ReceiverSettleMode.Second) self._desired_capabilities = kwargs.pop('desired_capabilities', None) + # transport + if kwargs.get('transport_type') is TransportType.Amqp and kwargs.get('http_proxy') is not None: + raise ValueError("Http proxy settings can't be passed if transport_type is explicitly set to Amqp") + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) + self._http_proxy = kwargs.pop('http_proxy', None) + def __enter__(self): """Run Client in a context manager.""" self.open() @@ -240,7 +245,9 @@ def open(self): channel_max=self._channel_max, idle_timeout=self._idle_timeout, properties=self._properties, - network_trace=self._network_trace + network_trace=self._network_trace, + transport_type=self._transport_type, + http_proxy=self._http_proxy ) self._connection.open() if not self._session: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py index 0e60bbca7a56..7083c724c222 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py @@ -21,6 +21,14 @@ SECURE_PORT = 5671 +# default port for AMQP over Websocket +WEBSOCKET_PORT = 443 + + +# subprotocol for AMQP over Websocket +AMQP_WS_SUBPROTOCOL = 'AMQPWSB10' + + MAJOR = 1 #: Major protocol version. MINOR = 0 #: Minor protocol version. REV = 0 #: Protocol revision. @@ -302,3 +310,14 @@ class MessageDeliveryState(object): MessageDeliveryState.Timeout, MessageDeliveryState.Cancelled ) + + +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. + """ + Amqp = 1 + AmqpOverWebsocket = 2 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index 99dd25d43730..7353a886b388 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -7,9 +7,9 @@ import struct from enum import Enum -from ._transport import SSLTransport, AMQPS_PORT +from ._transport import SSLTransport, WebSocketTransport, AMQPS_PORT from .types import AMQPTypes, TYPE, VALUE -from .constants import FIELD, SASLCode, SASL_HEADER_FRAME +from .constants import FIELD, SASLCode, SASL_HEADER_FRAME, TransportType, WEBSOCKET_PORT from .performatives import ( SASLOutcome, SASLResponse, @@ -69,7 +69,34 @@ def start(self): return b'' -class SASLTransport(SSLTransport): +class SASLTransportMixin(): + def _negotiate(self): + self.write(SASL_HEADER_FRAME) + _, returned_header = self.receive_frame() + if returned_header[1] != SASL_HEADER_FRAME: + raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( + SASL_HEADER_FRAME, returned_header[1])) + + _, supported_mechansisms = self.receive_frame(verify_frame_type=1) + if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms + raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) + sasl_init = SASLInit( + mechanism=self.credential.mechanism, + initial_response=self.credential.start(), + hostname=self.host) + self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) + + _, next_frame = self.receive_frame(verify_frame_type=1) + frame_type, fields = next_frame + if frame_type != 0x00000044: # SASLOutcome + raise NotImplementedError("Unsupported SASL challenge") + if fields[0] == SASLCode.Ok: # code + return + else: + raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + + +class SASLTransport(SSLTransport, SASLTransportMixin): def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential @@ -78,26 +105,23 @@ def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl= def negotiate(self): with self.block(): - self.write(SASL_HEADER_FRAME) - _, returned_header = self.receive_frame() - if returned_header[1] != SASL_HEADER_FRAME: - raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( - SASL_HEADER_FRAME, returned_header[1])) - - _, supported_mechansisms = self.receive_frame(verify_frame_type=1) - if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms - raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) - sasl_init = SASLInit( - mechanism=self.credential.mechanism, - initial_response=self.credential.start(), - hostname=self.host) - self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) - - _, next_frame = self.receive_frame(verify_frame_type=1) - frame_type, fields = next_frame - if frame_type != 0x00000044: # SASLOutcome - raise NotImplementedError("Unsupported SASL challenge") - if fields[0] == SASLCode.Ok: # code - return - else: - raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + self._negotiate() + +class SASLWithWebSocket(WebSocketTransport, SASLTransportMixin): + + def __init__(self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs): + self.credential = credential + ssl = ssl or True + http_proxy = kwargs.pop('http_proxy', None) + self._transport = WebSocketTransport( + host, + port=port, + connect_timeout=connect_timeout, + ssl=ssl, + http_proxy=http_proxy, + **kwargs + ) + super().__init__(host, port, connect_timeout, ssl, **kwargs) + + def negotiate(self): + self._negotiate() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index 440fcc69d1d5..03c76c0832af 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.8.0b4" +VERSION = "5.8.0a4" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index c03b510b02b2..f5606c717965 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -302,8 +302,15 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = await self._create_auth_async() + hostname = self._address.hostname + if self._config.transport_type.name == 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' mgmt_client = AMQPClientAsync( - self._address.hostname, auth=mgmt_auth, debug=self._config.network_tracing + hostname, + auth=mgmt_auth, + debug=self._config.network_tracing, + transport_type=self._config.transport_type, + http_proxy=self._config.http_proxy ) try: await mgmt_client.open_async() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index d5be74195636..d02849d746fd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -144,8 +144,13 @@ def _create_handler(self, auth: "JWTTokenAuthAsync") -> None: ) desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + hostname = urlparse(source.address).hostname + transport_type = self._client._config.transport_type # pylint:disable=protected-access + if transport_type.name == 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' + self._handler = ReceiveClientAsync( - urlparse(source.address).hostname, + hostname, source, auth=auth, idle_timeout=self._idle_timeout, @@ -158,7 +163,9 @@ def _create_handler(self, auth: "JWTTokenAuthAsync") -> None: properties=create_properties(self._client._config.user_agent), # pylint:disable=protected-access desired_capabilities=desired_capabilities, streaming_receive=True, - message_received_callback=self._message_received + message_received_callback=self._message_received, + transport_type = transport_type, + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access ) async def _open_with_retry(self) -> None: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 18467d15a416..bf8fcc0ccb94 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -99,8 +99,12 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N self._link_properties = {TIMEOUT_SYMBOL: pyamqp_utils.amqp_long_value(int(self._timeout * 1000))} def _create_handler(self, auth: "JWTTokenAsync") -> None: + hostname = self._client._address.hostname # pylint: disable=protected-access + transport_type = self._client._config.transport_type # pylint:disable=protected-access + if transport_type.name == 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = SendClientAsync( - self._client._address.hostname, # pylint: disable=protected-access + hostname, self._target, auth=auth, idle_timeout=self._idle_timeout, @@ -110,6 +114,8 @@ def _create_handler(self, auth: "JWTTokenAsync") -> None: client_name=self._name, link_properties=self._link_properties, properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access + transport_type=transport_type, + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access **self._internal_kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index 8f5ffe61528c..fdb350107c46 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -185,11 +185,9 @@ def from_connection_string( *, eventhub_name: Optional[str] = None, logging_enable: bool = False, - http_proxy: Optional[Dict[str, Union[str, int]]] = None, auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = None, **kwargs: Any ) -> "EventHubProducerClient": """Create an EventHubProducerClient from a connection string. @@ -246,11 +244,9 @@ def from_connection_string( conn_str, eventhub_name=eventhub_name, logging_enable=logging_enable, - http_proxy=http_proxy, auth_timeout=auth_timeout, user_agent=user_agent, retry_total=retry_total, - transport_type=transport_type, **kwargs ) return cls(**constructor_args) diff --git a/sdk/eventhub/azure-eventhub/dev_requirements.txt b/sdk/eventhub/azure-eventhub/dev_requirements.txt index df47262912ac..9c91833e14d8 100644 --- a/sdk/eventhub/azure-eventhub/dev_requirements.txt +++ b/sdk/eventhub/azure-eventhub/dev_requirements.txt @@ -4,5 +4,6 @@ azure-mgmt-eventhub==10.0.0 azure-mgmt-resource==20.0.0 aiohttp>=3.0 +websocket-client -e ../../../tools/azure-devtools -e ../../servicebus/azure-servicebus \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py index e3560f6e7e2f..016b1bdc9c86 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py @@ -273,7 +273,6 @@ async def test_send_multiple_partition_with_app_prop_async(connstr_receivers): @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_over_websocket_async(connstr_receivers): - pytest.skip("websocket unsupported") connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index 8da5ddeb6ead..39b0f06f4d3d 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -133,7 +133,7 @@ def on_event_batch(partition_context, event_batch): worker = threading.Thread(target=client.receive_batch, args=(on_event_batch,), kwargs={"starting_position": "-1"}) worker.start() - time.sleep(10) + time.sleep(20) assert on_event_batch.received == 2 checkpoints = list(client._event_processors.values())[0]._checkpoint_store.list_checkpoints( diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index f9ea0c55c773..9b484340855c 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -218,9 +218,22 @@ def test_send_partition(connstr_receivers): client.send_batch(batch) partition_0 = receivers[0].receive_message_batch(timeout=5) - assert len(partition_0) == 0 partition_1 = receivers[1].receive_message_batch(timeout=5) - assert len(partition_1) == 1 + assert len(partition_0) + len(partition_1) == 2 + + with client: + batch = client.create_batch() + batch.add(EventData(b"Data")) + client.send_batch(batch) + + with client: + batch = client.create_batch(partition_id="1") + batch.add(EventData(b"Data")) + client.send_batch(batch) + + partition_0 = receivers[0].receive_message_batch(timeout=5) + partition_1 = receivers[1].receive_message_batch(timeout=5) + assert len(partition_0) + len(partition_1) == 2 @pytest.mark.liveTest @@ -273,7 +286,6 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers): @pytest.mark.liveTest def test_send_over_websocket_sync(connstr_receivers): - pytest.skip("websocket not supported") connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket)