Skip to content

Commit

Permalink
AMQP websocket implementation (Azure#24345)
Browse files Browse the repository at this point in the history
* Initial implementation

* http proxy support

* change impl

* more changes

* working sol

* async impl

* Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py

* more changes

* sasl mixin

* Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py

* refactor

* Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py

* Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py

* oops

* comments

* comment

* Apply suggestions from code review

Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com>

* comments

* changes

* async test

* rasie

* lint

* changelog

* version

* comments

* move path to EH

* Fix typo

Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com>
  • Loading branch information
2 people authored and kashifkhan committed May 10, 2022
1 parent 7305718 commit e3689b4
Show file tree
Hide file tree
Showing 18 changed files with 469 additions and 130 deletions.
4 changes: 3 additions & 1 deletion sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Release History

## 5.8.0b4 (Unreleased)
## 5.8.0a4 (Unreleased)

### Features Added

- Added suppport for connection using websocket and http proxy.

### Breaking Changes

### Bugs Fixed
Expand Down
2 changes: 0 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,6 @@ def _create_auth(self):
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
timeout=self._config.auth_timeout,
http_proxy=self._config.http_proxy,
transport_type=self._config.transport_type,
custom_endpoint_hostname=self._config.custom_endpoint_hostname,
port=self._config.connection_port,
verify=self._config.connection_verify,
Expand Down
8 changes: 7 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def __init__(self, client, source, **kwargs):

def _create_handler(self, auth):
# type: (JWTTokenAuth) -> None
transport_type = self._client._config.transport_type # pylint:disable=protected-access
hostname = urlparse(source.address).hostname
if transport_type.name is 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'
source = Source(address=self._source, filters={})
if self._offset is not None:
filter_key = ApacheFilters.selector_filter
Expand All @@ -151,11 +155,13 @@ def _create_handler(self, auth):
desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None

self._handler = ReceiveClient(
urlparse(source.address).hostname,
hostname,
source,
auth=auth,
idle_timeout=self._idle_timeout,
network_trace=self._client._config.network_tracing, # pylint:disable=protected-access
transport_type=transport_type,
http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access
link_credit=self._prefetch,
link_properties=self._link_properties,
retry_policy=self._retry_policy,
Expand Down
10 changes: 8 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,18 @@ def __init__(self, client, target, **kwargs):

def _create_handler(self, auth):
# type: (JWTTokenAuth) -> None
transport_type=self._client._config.transport_type # pylint:disable=protected-access
hostname = self._client._address.hostname # pylint: disable=protected-access
if transport_type.name is 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'
self._handler = SendClient(
self._client._address.hostname, # pylint: disable=protected-access
hostname,
self._target,
auth=auth,
idle_timeout=self._idle_timeout,
network_trace=self._client._config.network_tracing, # pylint: disable=protected-access
network_trace=self._client._config.network_tracing, # pylint:disable=protected-access
transport_type=transport_type,
http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access
retry_policy=self._retry_policy,
keep_alive_interval=self._keep_alive,
client_name=self._name,
Expand Down
23 changes: 18 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ssl import SSLError

from ._transport import Transport
from .sasl import SASLTransport
from .sasl import SASLTransport, SASLWithWebSocket
from .session import Session
from .performatives import OpenFrame, CloseFrame
from .constants import (
Expand All @@ -22,7 +22,8 @@
MAX_FRAME_SIZE_BYTES,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from .error import (
Expand Down Expand Up @@ -77,12 +78,19 @@ class Connection(object):
Default value is `0.1`.
:keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames
will be logged at the logging.INFO level.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
# type(str, Any) -> None
parsed_url = urlparse(endpoint)
self._hostname = parsed_url.hostname
endpoint = self._hostname
if parsed_url.port:
self._port = parsed_url.port
elif parsed_url.scheme == 'amqps':
Expand All @@ -92,16 +100,21 @@ def __init__(self, endpoint, **kwargs):
self.state = None # type: Optional[ConnectionState]

transport = kwargs.get('transport')
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if transport:
self._transport = transport
elif 'sasl_credential' in kwargs:
self._transport = SASLTransport(
host=parsed_url.netloc,
sasl_transport = SASLTransport
if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self._transport = sasl_transport(
host=endpoint,
credential=kwargs['sasl_credential'],
**kwargs
)
else:
self._transport = Transport(parsed_url.netloc, **kwargs)
self._transport = Transport(parsed_url.netloc, self._transport_type, **kwargs)

self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str
self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int
Expand Down
78 changes: 74 additions & 4 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack
from ._encode import encode_frame
from ._decode import decode_frame, decode_empty_frame
from .constants import TLS_HEADER_FRAME
from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL


try:
Expand Down Expand Up @@ -456,7 +456,6 @@ def send_frame(self, channel, frame, **kwargs):
else:
encoded_channel = struct.pack('>H', channel)
data = header + encoded_channel + performative

self.write(data)

def negotiate(self, encode, decode):
Expand Down Expand Up @@ -647,11 +646,82 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
return result


def Transport(host, connect_timeout=None, ssl=False, **kwargs):
def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs):
"""Create transport.
Given a few parameters from the Connection constructor,
select and create a subclass of _AbstractTransport.
"""
transport = SSLTransport if ssl else TCPTransport
if transport_type == TransportType.AmqpOverWebsocket:
transport = WebSocketTransport
else:
transport = SSLTransport if ssl else TCPTransport
return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class WebSocketTransport(_AbstractTransport):
def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs
):
self.sslopts = ssl if isinstance(ssl, dict) else {}
self._connect_timeout = connect_timeout
self._host = host
super().__init__(
host, port, connect_timeout, **kwargs
)
self.ws = None
self._http_proxy = kwargs.get('http_proxy', None)

def connect(self):
http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None
if self._http_proxy:
http_proxy_host = self._http_proxy['proxy_hostname']
http_proxy_port = self._http_proxy['proxy_port']
username = self._http_proxy.get('username', None)
password = self._http_proxy.get('password', None)
if username or password:
http_proxy_auth = (username, password)
try:
from websocket import create_connection
self.ws = create_connection(
url="wss://{}".format(self._host),
subprotocols=[AMQP_WS_SUBPROTOCOL],
timeout=self._connect_timeout,
skip_utf8_validation=True,
sslopt=self.sslopts,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_proxy_auth=http_proxy_auth
)
except ImportError:
raise ValueError("Please install websocket-client library to use websocket transport.")

def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments
"""Read exactly n bytes from the peer."""

length = 0
view = buffer or memoryview(bytearray(n))
nbytes = self._read_buffer.readinto(view)
length += nbytes
n -= nbytes
while n:
data = self.ws.recv()

if len(data) <= n:
view[length: length + len(data)] = data
n -= len(data)
else:
view[length: length + n] = data[0:n]
self._read_buffer = BytesIO(data[n:])
n = 0
return view

def _shutdown_transport(self):
"""Do any preliminary work in shutting down the connection."""
self.ws.close()

def _write(self, s):
"""Completely write a string to the peer.
ABNF, OPCODE_BINARY = 0x2
See http://tools.ietf.org/html/rfc5234
http://tools.ietf.org/html/rfc6455#section-5.2
"""
self.ws.send_binary(s)
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ async def _do_retryable_operation_async(self, operation, *args, **kwargs):
absolute_timeout -= (end_time - start_time)
raise retry_settings['history'][-1]

async def _keep_alive_worker_async(self):
interval = 10 if self._keep_alive is True else self._keep_alive
start_time = time.time()
try:
while self._connection and not self._shutdown:
current_time = time.time()
elapsed_time = (current_time - start_time)
if elapsed_time >= interval:
_logger.info("Keeping %r connection alive. %r",
self.__class__.__name__,
self._connection._container_id)
await self._connection._get_remote_timeout(current_time)
start_time = current_time
await asyncio.sleep(1)
except Exception as e: # pylint: disable=broad-except
_logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e)

async def open_async(self):
"""Asynchronously open the client. The client can create a new Connection
or an existing Connection can be passed in. This existing Connection
Expand All @@ -200,6 +217,8 @@ async def open_async(self):
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
transport_type=self._transport_type,
http_proxy=self._http_proxy,
properties=self._properties,
network_trace=self._network_trace
)
Expand All @@ -217,6 +236,8 @@ async def open_async(self):
auth_timeout=self._auth_timeout
)
await self._cbs_authenticator.open()
if self._keep_alive:
self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_worker_async())
self._shutdown = False

async def close_async(self):
Expand All @@ -228,6 +249,9 @@ async def close_async(self):
self._shutdown = True
if not self._session:
return # already closed.
if self._keep_alive_thread:
await self._keep_alive_thread
self._keep_alive_thread = None
await self._close_link_async(close=True)
if self._cbs_authenticator:
await self._cbs_authenticator.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio

from ._transport_async import AsyncTransport
from ._sasl_async import SASLTransport
from ._sasl_async import SASLTransport, SASLWithWebSocket
from ._session_async import Session
from ..performatives import OpenFrame, CloseFrame
from .._connection import get_local_timeout
Expand All @@ -27,7 +27,8 @@
MAX_CHANNELS,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from ..error import (
Expand Down Expand Up @@ -58,25 +59,36 @@ class Connection(object):
:param list(str) offered_capabilities: The extension capabilities the sender supports.
:param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports
:param dict properties: Connection properties.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
parsed_url = urlparse(endpoint)
self.hostname = parsed_url.hostname
endpoint = self._hostname
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if parsed_url.port:
self.port = parsed_url.port
elif parsed_url.scheme == 'amqps':
self.port = SECURE_PORT
else:
self.port = PORT
self.state = None

transport = kwargs.get('transport')
if transport:
self.transport = transport
elif 'sasl_credential' in kwargs:
self.transport = SASLTransport(
host=parsed_url.netloc,
sasl_transport = SASLTransport
if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self._transport = sasl_transport(
host=endpoint,
credential=kwargs['sasl_credential'],
**kwargs
)
Expand Down
Loading

0 comments on commit e3689b4

Please sign in to comment.