Skip to content

Commit

Permalink
Adding back Rakshith's websocket changes (#24410)
Browse files Browse the repository at this point in the history
* Adding back Rakshith's sync websocket changes

* fix async send and receive

* fix transport bugs

* add websocket to dev reqs + async fix hostname

* thank you kashif

* fix tests + turn on websocket tests

* update consumer test timing
  • Loading branch information
swathipil authored May 14, 2022
1 parent 737a79e commit b6f8b72
Show file tree
Hide file tree
Showing 22 changed files with 443 additions and 147 deletions.
6 changes: 1 addition & 5 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@

### Features Added

### Breaking Changes

### Bugs Fixed

### Other Changes
- Added support for connection using websocket and http proxy.

## 5.8.0a3 (2022-03-08)

Expand Down
11 changes: 8 additions & 3 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,6 @@ def _create_auth(self):
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
timeout=self._config.auth_timeout,
http_proxy=self._config.http_proxy,
transport_type=self._config.transport_type,
custom_endpoint_hostname=self._config.custom_endpoint_hostname,
port=self._config.connection_port,
verify=self._config.connection_verify,
Expand Down Expand Up @@ -379,8 +377,15 @@ def _management_request(self, mgmt_msg, op_type):
last_exception = None
while retried_times <= self._config.max_retries:
mgmt_auth = self._create_auth()
hostname = self._address.hostname
if self._config.transport_type.name == 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'
mgmt_client = AMQPClient(
self._address.hostname, auth=mgmt_auth, debug=self._config.network_tracing
hostname,
auth=mgmt_auth,
debug=self._config.network_tracing,
transport_type=self._config.transport_type,
http_proxy=self._config.http_proxy
)
try:
mgmt_client.open()
Expand Down
11 changes: 9 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,13 @@ def _create_handler(self, auth):
)
desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None

transport_type = self._client._config.transport_type # pylint:disable=protected-access
hostname = urlparse(source.address).hostname
if transport_type.name == 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'

self._handler = ReceiveClient(
urlparse(source.address).hostname,
hostname,
source,
auth=auth,
idle_timeout=self._idle_timeout,
Expand All @@ -164,7 +169,9 @@ def _create_handler(self, auth):
properties=create_properties(self._client._config.user_agent), # pylint:disable=protected-access
desired_capabilities=desired_capabilities,
streaming_receive=True,
message_received_callback=self._message_received
message_received_callback=self._message_received,
transport_type=transport_type,
http_proxy=self._client._config.http_proxy # pylint:disable=protected-access
)

def _open_with_retry(self):
Expand Down
8 changes: 7 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ def __init__(self, client, target, **kwargs):

def _create_handler(self, auth):
# type: (JWTTokenAuth) -> None
transport_type = self._client._config.transport_type # pylint:disable=protected-access
hostname = self._client._address.hostname # pylint: disable=protected-access
if transport_type.name == 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'
self._handler = SendClient(
self._client._address.hostname, # pylint: disable=protected-access
hostname, # pylint: disable=protected-access
self._target,
auth=auth,
idle_timeout=self._idle_timeout,
Expand All @@ -136,6 +140,8 @@ def _create_handler(self, auth):
client_name=self._name,
link_properties=self._link_properties,
properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access
transport_type=transport_type,
http_proxy=self._client._config.http_proxy # pylint: disable=protected-access
)

def _open_with_retry(self):
Expand Down
23 changes: 18 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ssl import SSLError

from ._transport import Transport
from .sasl import SASLTransport
from .sasl import SASLTransport, SASLWithWebSocket
from .session import Session
from .performatives import OpenFrame, CloseFrame
from .constants import (
Expand All @@ -22,7 +22,8 @@
MAX_FRAME_SIZE_BYTES,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from .error import (
Expand Down Expand Up @@ -77,12 +78,19 @@ class Connection(object):
Default value is `0.1`.
:keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames
will be logged at the logging.INFO level.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
# type(str, Any) -> None
parsed_url = urlparse(endpoint)
self._hostname = parsed_url.hostname
endpoint = self._hostname
if parsed_url.port:
self._port = parsed_url.port
elif parsed_url.scheme == 'amqps':
Expand All @@ -92,16 +100,21 @@ def __init__(self, endpoint, **kwargs):
self.state = None # type: Optional[ConnectionState]

transport = kwargs.get('transport')
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if transport:
self._transport = transport
elif 'sasl_credential' in kwargs:
self._transport = SASLTransport(
host=parsed_url.netloc,
sasl_transport = SASLTransport
if self._transport_type.name == 'AmqpOverWebsocket' or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self._transport = sasl_transport(
host=endpoint,
credential=kwargs['sasl_credential'],
**kwargs
)
else:
self._transport = Transport(parsed_url.netloc, **kwargs)
self._transport = Transport(parsed_url.netloc, transport_type=self._transport_type, **kwargs)

self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str
self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int
Expand Down
80 changes: 75 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack
from ._encode import encode_frame
from ._decode import decode_frame, decode_empty_frame
from .constants import TLS_HEADER_FRAME
from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL


try:
Expand Down Expand Up @@ -439,7 +439,7 @@ def write(self, s):

def receive_frame(self, *args, **kwargs):
try:
header, channel, payload = self.read(**kwargs)
header, channel, payload = self.read(**kwargs)
if not payload:
decoded = decode_empty_frame(header)
else:
Expand Down Expand Up @@ -646,12 +646,82 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
result, self._read_buffer = rbuf[:n], rbuf[n:]
return result


def Transport(host, connect_timeout=None, ssl=False, **kwargs):
def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs):
"""Create transport.
Given a few parameters from the Connection constructor,
select and create a subclass of _AbstractTransport.
"""
transport = SSLTransport if ssl else TCPTransport
if transport_type == TransportType.AmqpOverWebsocket:
transport = WebSocketTransport
else:
transport = SSLTransport if ssl else TCPTransport
return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class WebSocketTransport(_AbstractTransport):
def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs):
self.sslopts = ssl if isinstance(ssl, dict) else {}
self._connect_timeout = connect_timeout
self._host = host
super().__init__(
host, port, connect_timeout, **kwargs
)
self.ws = None
self._http_proxy = kwargs.get('http_proxy', None)

def connect(self):
http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None
if self._http_proxy:
http_proxy_host = self._http_proxy['proxy_hostname']
http_proxy_port = self._http_proxy['proxy_port']
username = self._http_proxy.get('username', None)
password = self._http_proxy.get('password', None)
if username or password:
http_proxy_auth = (username, password)
try:
from websocket import create_connection
self.ws = create_connection(
url="wss://{}".format(self._host),
subprotocols=[AMQP_WS_SUBPROTOCOL],
timeout=self._connect_timeout,
skip_utf8_validation=True,
sslopt=self.sslopts,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_proxy_auth=http_proxy_auth
)
except ImportError:
raise ValueError("Please install websocket-client library to use websocket transport.")

def _read(self, n, initial=False, buffer=None, **kwargs): # pylint: disable=unused-arguments
"""Read exactly n bytes from the peer."""

length = 0
view = buffer or memoryview(bytearray(n))
nbytes = self._read_buffer.readinto(view)
length += nbytes
n -= nbytes
while n:
data = self.ws.recv()

if len(data) <= n:
view[length: length + len(data)] = data
n -= len(data)
else:
view[length: length + n] = data[0:n]
self._read_buffer = BytesIO(data[n:])
n = 0

return view

def _shutdown_transport(self):
"""Do any preliminary work in shutting down the connection."""
self.ws.close()

def _write(self, s):
"""Completely write a string to the peer.
ABNF, OPCODE_BINARY = 0x2
See http://tools.ietf.org/html/rfc5234
http://tools.ietf.org/html/rfc6455#section-5.2
"""
self.ws.send_binary(s)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ._receiver_async import ReceiverLink
from ._sender_async import SenderLink
from ._session_async import Session
from ._sasl_async import SASLTransport
from ._cbs_async import CBSAuthenticator
from ..client import AMQPClient as AMQPClientSync
from ..client import ReceiveClient as ReceiveClientSync
Expand Down Expand Up @@ -201,7 +200,9 @@ async def open_async(self):
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
properties=self._properties,
network_trace=self._network_trace
network_trace=self._network_trace,
transport_type=self._transport_type,
http_proxy=self._http_proxy
)
await self._connection.open()
if not self._session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio

from ._transport_async import AsyncTransport
from ._sasl_async import SASLTransport
from ._sasl_async import SASLTransport, SASLWithWebSocket
from ._session_async import Session
from ..performatives import OpenFrame, CloseFrame
from .._connection import get_local_timeout
Expand All @@ -27,7 +27,8 @@
MAX_CHANNELS,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from ..error import (
Expand Down Expand Up @@ -58,11 +59,19 @@ class Connection(object):
:param list(str) offered_capabilities: The extension capabilities the sender supports.
:param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports
:param dict properties: Connection properties.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
parsed_url = urlparse(endpoint)
self.hostname = parsed_url.hostname
endpoint = self.hostname
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if parsed_url.port:
self.port = parsed_url.port
elif parsed_url.scheme == 'amqps':
Expand All @@ -75,8 +84,12 @@ def __init__(self, endpoint, **kwargs):
if transport:
self.transport = transport
elif 'sasl_credential' in kwargs:
self.transport = SASLTransport(
host=parsed_url.netloc,
sasl_transport = SASLTransport
if self._transport_type.name == "AmqpOverWebsocket" or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self.transport = sasl_transport(
host=endpoint,
credential=kwargs['sasl_credential'],
**kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import struct
from enum import Enum

from ._transport_async import AsyncTransport
from ._transport_async import AsyncTransport, WebSocketTransportAsync
from ..types import AMQPTypes, TYPE, VALUE
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT, TransportType
from .._transport import AMQPS_PORT
from ..performatives import (
SASLOutcome,
Expand Down Expand Up @@ -73,14 +73,8 @@ def start(self):
return b''


class SASLTransport(AsyncTransport):

def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

async def negotiate(self):
class SASLTransportMixinAsync():
async def _negotiate(self):
await self.write(SASL_HEADER_FRAME)
_, returned_header = await self.receive_frame()
if returned_header[1] != SASL_HEADER_FRAME:
Expand All @@ -104,3 +98,35 @@ async def negotiate(self):
return
else:
raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields))


class SASLTransport(AsyncTransport, SASLTransportMixinAsync):

def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

async def negotiate(self):
await self._negotiate()


class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync):
def __init__(
self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs
):
self.credential = credential
ssl = ssl or True
http_proxy = kwargs.pop('http_proxy', None)
self._transport = WebSocketTransportAsync(
host,
port=port,
connect_timeout=connect_timeout,
ssl=ssl,
http_proxy=http_proxy,
**kwargs
)
super().__init__(host, port, connect_timeout, ssl, **kwargs)

async def negotiate(self):
await self._negotiate()
Loading

0 comments on commit b6f8b72

Please sign in to comment.