diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 7d105b3b0fba..1af1b168b856 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,9 +1,11 @@ # Release History -## 5.8.0b4 (Unreleased) +## 5.8.0a4 (Unreleased) ### Features Added +- Added suppport for connection using websocket and http proxy. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index fdd8c7f297bd..85c57dbe7b81 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -324,8 +324,6 @@ def _create_auth(self): functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, custom_endpoint_hostname=self._config.custom_endpoint_hostname, port=self._config.connection_port, verify=self._config.connection_verify, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index ddb9a14a166f..31c6e8cda89c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -136,6 +136,10 @@ def __init__(self, client, source, **kwargs): def _create_handler(self, auth): # type: (JWTTokenAuth) -> None + transport_type = self._client._config.transport_type, # pylint:disable=protected-access + hostname = urlparse(source.address).hostname + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' source = Source(address=self._source, filters={}) if self._offset is not None: filter_key = ApacheFilters.selector_filter @@ -151,11 +155,13 @@ def _create_handler(self, auth): desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None self._handler = ReceiveClient( - urlparse(source.address).hostname, + hostname, source, auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + transport_type=transport_type, + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, retry_policy=self._retry_policy, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index a72ce0753980..2867b67cea03 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -125,12 +125,18 @@ def __init__(self, client, target, **kwargs): def _create_handler(self, auth): # type: (JWTTokenAuth) -> None + transport_type=self._client._config.transport_type, # pylint:disable=protected-access + hostname = self._client._address.hostname, # pylint: disable=protected-access + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = SendClient( - self._client._address.hostname, # pylint: disable=protected-access + hostname, self._target, auth=auth, idle_timeout=self._idle_timeout, - network_trace=self._client._config.network_tracing, # pylint: disable=protected-access + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + transport_type=transport_type, + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index a26d220f3286..cc84d6870e02 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -12,7 +12,7 @@ from ssl import SSLError from ._transport import Transport -from .sasl import SASLTransport +from .sasl import SASLTransport, SASLWithWebSocket from .session import Session from .performatives import OpenFrame, CloseFrame from .constants import ( @@ -22,7 +22,8 @@ MAX_FRAME_SIZE_BYTES, HEADER_FRAME, ConnectionState, - EMPTY_FRAME + EMPTY_FRAME, + TransportType ) from .error import ( @@ -77,12 +78,19 @@ class Connection(object): Default value is `0.1`. :keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames will be logged at the logging.INFO level. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. + Additionally the following keys may also be present: `'username', 'password'`. """ def __init__(self, endpoint, **kwargs): # type(str, Any) -> None parsed_url = urlparse(endpoint) self._hostname = parsed_url.hostname + endpoint = self._hostname if parsed_url.port: self._port = parsed_url.port elif parsed_url.scheme == 'amqps': @@ -92,16 +100,21 @@ def __init__(self, endpoint, **kwargs): self.state = None # type: Optional[ConnectionState] transport = kwargs.get('transport') + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) if transport: self._transport = transport elif 'sasl_credential' in kwargs: - self._transport = SASLTransport( - host=parsed_url.netloc, + sasl_transport = SASLTransport + if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path + self._transport = sasl_transport( + host=endpoint, credential=kwargs['sasl_credential'], **kwargs ) else: - self._transport = Transport(parsed_url.netloc, **kwargs) + self._transport = Transport(parsed_url.netloc, self._transport_type, **kwargs) self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 85371fdd07d9..29e506177cd3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -51,7 +51,7 @@ from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame -from .constants import TLS_HEADER_FRAME +from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL try: @@ -456,7 +456,6 @@ def send_frame(self, channel, frame, **kwargs): else: encoded_channel = struct.pack('>H', channel) data = header + encoded_channel + performative - self.write(data) def negotiate(self, encode, decode): @@ -647,11 +646,82 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): return result -def Transport(host, connect_timeout=None, ssl=False, **kwargs): +def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs): """Create transport. Given a few parameters from the Connection constructor, select and create a subclass of _AbstractTransport. """ - transport = SSLTransport if ssl else TCPTransport + if transport_type == TransportType.AmqpOverWebsocket: + transport = WebSocketTransport + else: + transport = SSLTransport if ssl else TCPTransport return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + +class WebSocketTransport(_AbstractTransport): + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + ): + self.sslopts = ssl if isinstance(ssl, dict) else {} + self._connect_timeout = connect_timeout + self._host = host + super().__init__( + host, port, connect_timeout, **kwargs + ) + self.ws = None + self._http_proxy = kwargs.get('http_proxy', None) + + def connect(self): + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None + if self._http_proxy: + http_proxy_host = self._http_proxy['proxy_hostname'] + http_proxy_port = self._http_proxy['proxy_port'] + username = self._http_proxy.get('username', None) + password = self._http_proxy.get('password', None) + if username or password: + http_proxy_auth = (username, password) + try: + from websocket import create_connection + self.ws = create_connection( + url="wss://{}".format(self._host), + subprotocols=[AMQP_WS_SUBPROTOCOL], + timeout=self._connect_timeout, + skip_utf8_validation=True, + sslopt=self.sslopts, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_proxy_auth=http_proxy_auth + ) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") + + def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments + """Read exactly n bytes from the peer.""" + + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + n -= nbytes + while n: + data = self.ws.recv() + + if len(data) <= n: + view[length: length + len(data)] = data + n -= len(data) + else: + view[length: length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 + return view + + def _shutdown_transport(self): + """Do any preliminary work in shutting down the connection.""" + self.ws.close() + + def _write(self, s): + """Completely write a string to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + self.ws.send_binary(s) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index e1b88b192690..957588d2a921 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -176,6 +176,23 @@ async def _do_retryable_operation_async(self, operation, *args, **kwargs): absolute_timeout -= (end_time - start_time) raise retry_settings['history'][-1] + async def _keep_alive_worker_async(self): + interval = 10 if self._keep_alive is True else self._keep_alive + start_time = time.time() + try: + while self._connection and not self._shutdown: + current_time = time.time() + elapsed_time = (current_time - start_time) + if elapsed_time >= interval: + _logger.info("Keeping %r connection alive. %r", + self.__class__.__name__, + self._connection._container_id) + await self._connection._get_remote_timeout(current_time) + start_time = current_time + await asyncio.sleep(1) + except Exception as e: # pylint: disable=broad-except + _logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e) + async def open_async(self): """Asynchronously open the client. The client can create a new Connection or an existing Connection can be passed in. This existing Connection @@ -200,6 +217,8 @@ async def open_async(self): max_frame_size=self._max_frame_size, channel_max=self._channel_max, idle_timeout=self._idle_timeout, + transport_type=self._transport_type, + http_proxy=self._http_proxy, properties=self._properties, network_trace=self._network_trace ) @@ -217,6 +236,8 @@ async def open_async(self): auth_timeout=self._auth_timeout ) await self._cbs_authenticator.open() + if self._keep_alive: + self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_worker_async()) self._shutdown = False async def close_async(self): @@ -228,6 +249,9 @@ async def close_async(self): self._shutdown = True if not self._session: return # already closed. + if self._keep_alive_thread: + await self._keep_alive_thread + self._keep_alive_thread = None await self._close_link_async(close=True) if self._cbs_authenticator: await self._cbs_authenticator.close() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index 3bfa62569e9a..7b69e1f2e64e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -16,7 +16,7 @@ import asyncio from ._transport_async import AsyncTransport -from ._sasl_async import SASLTransport +from ._sasl_async import SASLTransport, SASLWithWebSocket from ._session_async import Session from ..performatives import OpenFrame, CloseFrame from .._connection import get_local_timeout @@ -27,7 +27,8 @@ MAX_CHANNELS, HEADER_FRAME, ConnectionState, - EMPTY_FRAME + EMPTY_FRAME, + TransportType ) from ..error import ( @@ -58,11 +59,19 @@ class Connection(object): :param list(str) offered_capabilities: The extension capabilities the sender supports. :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports :param dict properties: Connection properties. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. + Additionally the following keys may also be present: `'username', 'password'`. """ def __init__(self, endpoint, **kwargs): parsed_url = urlparse(endpoint) self.hostname = parsed_url.hostname + endpoint = self._hostname + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) if parsed_url.port: self.port = parsed_url.port elif parsed_url.scheme == 'amqps': @@ -70,13 +79,16 @@ def __init__(self, endpoint, **kwargs): else: self.port = PORT self.state = None - transport = kwargs.get('transport') if transport: self.transport = transport elif 'sasl_credential' in kwargs: - self.transport = SASLTransport( - host=parsed_url.netloc, + sasl_transport = SASLTransport + if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path + self._transport = sasl_transport( + host=endpoint, credential=kwargs['sasl_credential'], **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index dda1931b909b..014681787c27 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -7,9 +7,9 @@ import struct from enum import Enum -from ._transport_async import AsyncTransport +from ._transport_async import AsyncTransport, WebSocketTransportAsync from ..types import AMQPTypes, TYPE, VALUE -from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME +from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT, TransportType from .._transport import AMQPS_PORT from ..performatives import ( SASLOutcome, @@ -72,14 +72,7 @@ class SASLExternalCredential(object): def start(self): return b'' - -class SASLTransport(AsyncTransport): - - def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): - self.credential = credential - ssl = ssl or True - super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) - +class SASLTransportMixinAsync(): async def negotiate(self): await self.write(SASL_HEADER_FRAME) _, returned_header = await self.receive_frame() @@ -104,3 +97,26 @@ async def negotiate(self): return else: raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + +class SASLTransport(AsyncTransport, SASLTransportMixinAsync): + def __init__(self, host, credential, connect_timeout=None, ssl=None, **kwargs): + self.credential = credential + ssl = ssl or True + super(SASLTransport, self).__init__(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + +class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync): + def __init__( + self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + ): # pylint: disable=super-init-not-called + self.credential = credential + ssl = ssl or True + http_proxy = kwargs.pop('http_proxy', None) + self._transport = WebSocketTransportAsync( + host, + port=port, + connect_timeout=connect_timeout, + ssl=ssl, + http_proxy=http_proxy, + **kwargs + ) + super().__init__(host, port, connect_timeout, ssl, **kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index acbdd8af8e76..39d09213eba3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -49,7 +49,7 @@ from .._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from .._encode import encode_frame from .._decode import decode_frame, decode_empty_frame -from ..constants import TLS_HEADER_FRAME +from ..constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, AMQP_WS_SUBPROTOCOL from .._transport import ( AMQP_FRAME, get_errno, @@ -59,7 +59,8 @@ SIGNED_INT_MAX, _UNAVAIL, set_cloexec, - AMQP_PORT + AMQP_PORT, + WebSocketTransport ) @@ -78,11 +79,75 @@ def get_running_loop(): _LOGGER.warning('This version of Python is deprecated, please upgrade to >= v3.6') if loop is None: _LOGGER.warning('No running event loop') - loop = asyncio.get_event_loop() + loop = self.loop return loop +class AsyncTransportMixin(): + async def receive_frame(self, *args, **kwargs): + try: + header, channel, payload = await self.read(**kwargs) + if not payload: + decoded = decode_empty_frame(header) + else: + decoded = decode_frame(payload) + # TODO: Catch decode error and return amqp:decode-error + #_LOGGER.info("ICH%d <- %r", channel, decoded) + return channel, decoded + except (socket.timeout, asyncio.IncompleteReadError, asyncio.TimeoutError): + return None, None + + async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? + async with self.socket_lock: + read_frame_buffer = BytesIO() + try: + frame_header = memoryview(bytearray(8)) + read_frame_buffer.write(await self._read(8, buffer=frame_header, initial=True)) + + channel = struct.unpack('>H', frame_header[6:])[0] + size = frame_header[0:4] + if size == AMQP_FRAME: # Empty frame or AMQP header negotiation + return frame_header, channel, None + size = struct.unpack('>I', size)[0] + offset = frame_header[4] + frame_type = frame_header[5] + + # >I is an unsigned int, but the argument to sock.recv is signed, + # so we know the size can be at most 2 * SIGNED_INT_MAX + payload_size = size - len(frame_header) + payload = memoryview(bytearray(payload_size)) + if size > SIGNED_INT_MAX: + read_frame_buffer.write(await self._read(SIGNED_INT_MAX, buffer=payload)) + read_frame_buffer.write(await self._read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) + else: + read_frame_buffer.write(await self._read(payload_size, buffer=payload)) + except (socket.timeout, asyncio.IncompleteReadError): + read_frame_buffer.write(self._read_buffer.getvalue()) + self._read_buffer = read_frame_buffer + self._read_buffer.seek(0) + raise + except (OSError, IOError, SSLError, socket.error) as exc: + # Don't disconnect for ssl read time outs + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and 'timed out' in str(exc): + raise socket.timeout() + if get_errno(exc) not in _UNAVAIL: + self.connected = False + raise + offset -= 2 + return frame_header, channel, payload[offset:] -class AsyncTransport(object): + async def send_frame(self, channel, frame, **kwargs): + header, performative = encode_frame(frame, **kwargs) + if performative is None: + data = header + else: + encoded_channel = struct.pack('>H', channel) + data = header + encoded_channel + performative + + await self.write(data) + #_LOGGER.info("OCH%d -> %r", channel, frame) + +class AsyncTransport(AsyncTransportMixin): """Common superclass for TCP and SSL transports.""" def __init__(self, host, port=AMQP_PORT, connect_timeout=None, @@ -95,7 +160,6 @@ def __init__(self, host, port=AMQP_PORT, connect_timeout=None, self.raise_on_initial_eintr = raise_on_initial_eintr self._read_buffer = BytesIO() self.host, self.port = to_host_port(host, port) - self.connect_timeout = connect_timeout self.read_timeout = read_timeout self.write_timeout = write_timeout @@ -318,46 +382,6 @@ def close(self): self.sock = None self.connected = False - async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? - async with self.socket_lock: - read_frame_buffer = BytesIO() - try: - frame_header = memoryview(bytearray(8)) - read_frame_buffer.write(await self._read(8, buffer=frame_header, initial=True)) - - channel = struct.unpack('>H', frame_header[6:])[0] - size = frame_header[0:4] - if size == AMQP_FRAME: # Empty frame or AMQP header negotiation - return frame_header, channel, None - size = struct.unpack('>I', size)[0] - offset = frame_header[4] - frame_type = frame_header[5] - - # >I is an unsigned int, but the argument to sock.recv is signed, - # so we know the size can be at most 2 * SIGNED_INT_MAX - payload_size = size - len(frame_header) - payload = memoryview(bytearray(payload_size)) - if size > SIGNED_INT_MAX: - read_frame_buffer.write(await self._read(SIGNED_INT_MAX, buffer=payload)) - read_frame_buffer.write(await self._read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) - else: - read_frame_buffer.write(await self._read(payload_size, buffer=payload)) - except (socket.timeout, asyncio.IncompleteReadError): - read_frame_buffer.write(self._read_buffer.getvalue()) - self._read_buffer = read_frame_buffer - self._read_buffer.seek(0) - raise - except (OSError, IOError, SSLError, socket.error) as exc: - # Don't disconnect for ssl read time outs - # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and 'timed out' in str(exc): - raise socket.timeout() - if get_errno(exc) not in _UNAVAIL: - self.connected = False - raise - offset -= 2 - return frame_header, channel, payload[offset:] - async def write(self, s): try: await self._write(s) @@ -368,19 +392,6 @@ async def write(self, s): self.connected = False raise - async def receive_frame(self, *args, **kwargs): - try: - header, channel, payload = await self.read(**kwargs) - if not payload: - decoded = decode_empty_frame(header) - else: - decoded = decode_frame(payload) - # TODO: Catch decode error and return amqp:decode-error - #_LOGGER.info("ICH%d <- %r", channel, decoded) - return channel, decoded - except (socket.timeout, asyncio.IncompleteReadError, asyncio.TimeoutError): - return None, None - async def receive_frame_with_lock(self, *args, **kwargs): try: async with self.socket_lock: @@ -393,17 +404,6 @@ async def receive_frame_with_lock(self, *args, **kwargs): except socket.timeout: return None, None - async def send_frame(self, channel, frame, **kwargs): - header, performative = encode_frame(frame, **kwargs) - if performative is None: - data = header - else: - encoded_channel = struct.pack('>H', channel) - data = header + encoded_channel + performative - - await self.write(data) - #_LOGGER.info("OCH%d -> %r", channel, frame) - async def negotiate(self): if not self.sslopts: return @@ -412,3 +412,76 @@ async def negotiate(self): if returned_header[1] == TLS_HEADER_FRAME: raise ValueError("Mismatching TLS header protocol. Excpected: {}, received: {}".format( TLS_HEADER_FRAME, returned_header[1])) + + +class WebSocketTransportAsync(AsyncTransportMixin): + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + ): + self._read_buffer = BytesIO() + self.loop = get_running_loop() + self.socket_lock = asyncio.Lock() + self.sslopts = ssl if isinstance(ssl, dict) else {} + self._connect_timeout = connect_timeout + self.host = host + self.ws = None + self._http_proxy = kwargs.get('http_proxy', None) + + async def connect(self): + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None + if self._http_proxy: + http_proxy_host = self._http_proxy['proxy_hostname'] + http_proxy_port = self._http_proxy['proxy_port'] + username = self._http_proxy.get('username', None) + password = self._http_proxy.get('password', None) + if username or password: + http_proxy_auth = (username, password) + try: + from websocket import create_connection + self.ws = create_connection( + url="wss://{}".format(self.host), + subprotocols=[AMQP_WS_SUBPROTOCOL], + timeout=self._connect_timeout, + skip_utf8_validation=True, + sslopt=self.sslopts, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_proxy_auth=http_proxy_auth + ) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") + + async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments + """Read exactly n bytes from the peer.""" + + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + n -= nbytes + while n: + data = await self.loop.run_in_executor( + None, self.ws.recv + ) + + if len(data) <= n: + view[length: length + len(data)] = data + n -= len(data) + else: + view[length: length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 + return view + + def close(self): + """Do any preliminary work in shutting down the connection.""" + self.ws.close() + + async def write(self, s): + """Completely write a string to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + await self.loop.run_in_executor( + None, self.ws.send_binary, s + ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 09d3303a2698..2b6c06070347 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -37,6 +37,7 @@ SenderSettleMode, ReceiverSettleMode, LinkDeliverySettleReason, + TransportType, SEND_DISPOSITION_ACCEPT, SEND_DISPOSITION_REJECT, AUTH_TYPE_CBS, @@ -155,6 +156,12 @@ def __init__(self, hostname, auth=None, **kwargs): self._receive_settle_mode = kwargs.pop('receive_settle_mode', ReceiverSettleMode.Second) self._desired_capabilities = kwargs.pop('desired_capabilities', None) + # transport + if kwargs.get('transport_type') is TransportType.Amqp and kwargs.get('http_proxy') is not None: + raise ValueError("Http proxy settings can't be passed if transport_type is explicitly set to Amqp") + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) + self._http_proxy = kwargs.pop('http_proxy', None) + def __enter__(self): """Run Client in a context manager.""" self.open() @@ -240,7 +247,9 @@ def open(self): channel_max=self._channel_max, idle_timeout=self._idle_timeout, properties=self._properties, - network_trace=self._network_trace + network_trace=self._network_trace, + transport_type=self._transport_type, + http_proxy=self._http_proxy ) self._connection.open() if not self._session: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py index 0e60bbca7a56..66e4ff1ae327 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py @@ -14,6 +14,11 @@ #: The port number is reserved for future transport mappings to these protocols. PORT = 5672 +# default port for AMQP over Websocket +WEBSOCKET_PORT = 443 + +# subprotocol for AMQP over Websocket +AMQP_WS_SUBPROTOCOL = 'AMQPWSB10' #: The IANA assigned port number for secure AMQP (amqps).The standard AMQP port number that has been assigned #: by IANA for secure TCP using TLS. Implementations listening on this port should NOT expect a protocol @@ -302,3 +307,13 @@ class MessageDeliveryState(object): MessageDeliveryState.Timeout, MessageDeliveryState.Cancelled ) + +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. + """ + Amqp = 1 + AmqpOverWebsocket = 2 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index 99dd25d43730..51848304bfae 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -7,9 +7,9 @@ import struct from enum import Enum -from ._transport import SSLTransport, AMQPS_PORT +from ._transport import SSLTransport, WebSocketTransport, AMQPS_PORT from .types import AMQPTypes, TYPE, VALUE -from .constants import FIELD, SASLCode, SASL_HEADER_FRAME +from .constants import FIELD, SASLCode, SASL_HEADER_FRAME, TransportType, WEBSOCKET_PORT from .performatives import ( SASLOutcome, SASLResponse, @@ -68,8 +68,33 @@ class SASLExternalCredential(object): def start(self): return b'' +class SASLTransportMixin(): + def _negotiate(self): + self.write(SASL_HEADER_FRAME) + _, returned_header = self.receive_frame() + if returned_header[1] != SASL_HEADER_FRAME: + raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( + SASL_HEADER_FRAME, returned_header[1])) + + _, supported_mechansisms = self.receive_frame(verify_frame_type=1) + if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms + raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) + sasl_init = SASLInit( + mechanism=self.credential.mechanism, + initial_response=self.credential.start(), + hostname=self.host) + self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) + + _, next_frame = self.receive_frame(verify_frame_type=1) + frame_type, fields = next_frame + if frame_type != 0x00000044: # SASLOutcome + raise NotImplementedError("Unsupported SASL challenge") + if fields[0] == SASLCode.Ok: # code + return + else: + raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) -class SASLTransport(SSLTransport): +class SASLTransport(SSLTransport, SASLTransportMixin): def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential @@ -78,26 +103,23 @@ def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl= def negotiate(self): with self.block(): - self.write(SASL_HEADER_FRAME) - _, returned_header = self.receive_frame() - if returned_header[1] != SASL_HEADER_FRAME: - raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( - SASL_HEADER_FRAME, returned_header[1])) - - _, supported_mechansisms = self.receive_frame(verify_frame_type=1) - if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms - raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) - sasl_init = SASLInit( - mechanism=self.credential.mechanism, - initial_response=self.credential.start(), - hostname=self.host) - self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) - - _, next_frame = self.receive_frame(verify_frame_type=1) - frame_type, fields = next_frame - if frame_type != 0x00000044: # SASLOutcome - raise NotImplementedError("Unsupported SASL challenge") - if fields[0] == SASLCode.Ok: # code - return - else: - raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + self._negotiate() + +class SASLWithWebSocket(WebSocketTransport, SASLTransportMixin): + + def __init__(self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs): + self.credential = credential + ssl = ssl or True + http_proxy = kwargs.pop('http_proxy', None) + self._transport = WebSocketTransport( + host, + port=port, + connect_timeout=connect_timeout, + ssl=ssl, + http_proxy=http_proxy, + **kwargs + ) + super().__init__(host, port, connect_timeout, ssl, **kwargs) + + def negotiate(self): + self._negotiate() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index 440fcc69d1d5..03c76c0832af 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.8.0b4" +VERSION = "5.8.0a4" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index d5be74195636..73d5effc2697 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -143,15 +143,20 @@ def _create_handler(self, auth: "JWTTokenAuthAsync") -> None: ) ) desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None - + hostname = urlparse(source.address).hostname + transport_type = self._client._config.transport_type, # pylint:disable=protected-access + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = ReceiveClientAsync( - urlparse(source.address).hostname, + hostname, source, auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, + transport_type=transport_type, + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access retry_policy=self._retry_policy, client_name=self._name, receive_settle_mode=pyamqp_constants.ReceiverSettleMode.First, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 18467d15a416..4e8a628e5d19 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -99,14 +99,20 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N self._link_properties = {TIMEOUT_SYMBOL: pyamqp_utils.amqp_long_value(int(self._timeout * 1000))} def _create_handler(self, auth: "JWTTokenAsync") -> None: + hostname = self._client._address.hostname, # pylint: disable=protected-access + transport_type = self._client._config.transport_type, # pylint:disable=protected-access + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = SendClientAsync( - self._client._address.hostname, # pylint: disable=protected-access + hostname, self._target, auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint: disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, + transport_type=transport_type, + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access client_name=self._name, link_properties=self._link_properties, properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py new file mode 100644 index 000000000000..443862be0c97 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import pytest +import asyncio +import logging +from uamqp.aio import ReceiveClientAsync, SASTokenAuthAsync +from uamqp.constants import TransportType + +@pytest.mark.asyncio +async def test_event_hubs_client_web_socket(eventhub_config): + uri = "sb://{}/{}".format(eventhub_config['hostname'], eventhub_config['event_hub']) + sas_auth = SASTokenAuthAsync( + uri=uri, + audience=uri, + username=eventhub_config['key_name'], + password=eventhub_config['access_key'] + ) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + eventhub_config['hostname'], + eventhub_config['event_hub'], + eventhub_config['consumer_group'], + eventhub_config['partition']) + + receive_client = ReceiveClientAsync(eventhub_config['hostname'] + '/$servicebus/websocket/', source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) + await receive_client.open_async() + while not await receive_client.client_ready_async(): + await asyncio.sleep(0.05) + messages = await receive_client.receive_message_batch_async(max_batch_size=1) + logging.info(len(messages)) + logging.info(messages[0]) + await receive_client.close_async() diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py new file mode 100644 index 000000000000..7dd9e5bfbe9c --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import pytest + +from azure.eventhub._pyamqp import authentication, ReceiveClient +from azure.eventhub._pyamqp.constants import TransportType + +def test_event_hubs_client_web_socket(live_eventhub): + uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) + sas_auth = authentication.SASTokenAuth( + uri=uri, + audience=uri, + username=live_eventhub['key_name'], + password=live_eventhub['access_key'] + ) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + live_eventhub['hostname'], + live_eventhub['event_hub'], + live_eventhub['consumer_group'], + live_eventhub['partition']) + + with ReceiveClient(live_eventhub['hostname'] + '/$servicebus/websocket/', source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) as receive_client: + receive_client.receive_message_batch(max_batch_size=10)