Skip to content

Commit

Permalink
Revert "AMQP websocket implementation (Azure#23722)" (Azure#24344)
Browse files Browse the repository at this point in the history
  • Loading branch information
rakshith91 authored and swathipil committed Aug 23, 2022
1 parent 69a0b6a commit f63f352
Show file tree
Hide file tree
Showing 13 changed files with 120 additions and 440 deletions.
2 changes: 0 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ def _create_handler(self, auth: uamqp_JWTTokenAuth) -> None:
source=source,
auth=auth,
network_trace=self._client._config.network_tracing, # pylint:disable=protected-access
transport_type=transport_type,
http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access
link_credit=self._prefetch,
link_properties=self._link_properties,
timeout=self._timeout,
Expand Down
23 changes: 5 additions & 18 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ssl import SSLError

from ._transport import Transport
from .sasl import SASLTransport, SASLWithWebSocket
from .sasl import SASLTransport
from .session import Session
from .performatives import OpenFrame, CloseFrame
from .constants import (
Expand All @@ -22,8 +22,7 @@
MAX_FRAME_SIZE_BYTES,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME,
TransportType
EMPTY_FRAME
)

from .error import (
Expand Down Expand Up @@ -78,19 +77,12 @@ class Connection(object):
Default value is `0.1`.
:keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames
will be logged at the logging.INFO level.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
# type(str, Any) -> None
parsed_url = urlparse(endpoint)
self._hostname = parsed_url.hostname
endpoint = self._hostname
if parsed_url.port:
self._port = parsed_url.port
elif parsed_url.scheme == 'amqps':
Expand All @@ -100,21 +92,16 @@ def __init__(self, endpoint, **kwargs):
self.state = None # type: Optional[ConnectionState]

transport = kwargs.get('transport')
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if transport:
self._transport = transport
elif 'sasl_credential' in kwargs:
sasl_transport = SASLTransport
if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self._transport = sasl_transport(
host=endpoint,
self._transport = SASLTransport(
host=parsed_url.netloc,
credential=kwargs['sasl_credential'],
**kwargs
)
else:
self._transport = Transport(parsed_url.netloc, self._transport_type, **kwargs)
self._transport = Transport(parsed_url.netloc, **kwargs)

self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str
self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int
Expand Down
78 changes: 4 additions & 74 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack
from ._encode import encode_frame
from ._decode import decode_frame, decode_empty_frame
from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL
from .constants import TLS_HEADER_FRAME


try:
Expand Down Expand Up @@ -456,6 +456,7 @@ def send_frame(self, channel, frame, **kwargs):
else:
encoded_channel = struct.pack('>H', channel)
data = header + encoded_channel + performative

self.write(data)

def negotiate(self, encode, decode):
Expand Down Expand Up @@ -646,82 +647,11 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
return result


def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs):
def Transport(host, connect_timeout=None, ssl=False, **kwargs):
"""Create transport.
Given a few parameters from the Connection constructor,
select and create a subclass of _AbstractTransport.
"""
if transport_type == TransportType.AmqpOverWebsocket:
transport = WebSocketTransport
else:
transport = SSLTransport if ssl else TCPTransport
transport = SSLTransport if ssl else TCPTransport
return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class WebSocketTransport(_AbstractTransport):
def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs
):
self.sslopts = ssl if isinstance(ssl, dict) else {}
self._connect_timeout = connect_timeout
self._host = host
super().__init__(
host, port, connect_timeout, **kwargs
)
self.ws = None
self._http_proxy = kwargs.get('http_proxy', None)

def connect(self):
http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None
if self._http_proxy:
http_proxy_host = self._http_proxy['proxy_hostname']
http_proxy_port = self._http_proxy['proxy_port']
username = self._http_proxy.get('username', None)
password = self._http_proxy.get('password', None)
if username or password:
http_proxy_auth = (username, password)
try:
from websocket import create_connection
self.ws = create_connection(
url="wss://{}".format(self._host),
subprotocols=[AMQP_WS_SUBPROTOCOL],
timeout=self._connect_timeout,
skip_utf8_validation=True,
sslopt=self.sslopts,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_proxy_auth=http_proxy_auth
)
except ImportError:
raise ValueError("Please install websocket-client library to use websocket transport.")

def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments
"""Read exactly n bytes from the peer."""

length = 0
view = buffer or memoryview(bytearray(n))
nbytes = self._read_buffer.readinto(view)
length += nbytes
n -= nbytes
while n:
data = self.ws.recv()

if len(data) <= n:
view[length: length + len(data)] = data
n -= len(data)
else:
view[length: length + n] = data[0:n]
self._read_buffer = BytesIO(data[n:])
n = 0
return view

def _shutdown_transport(self):
"""Do any preliminary work in shutting down the connection."""
self.ws.close()

def _write(self, s):
"""Completely write a string to the peer.
ABNF, OPCODE_BINARY = 0x2
See http://tools.ietf.org/html/rfc5234
http://tools.ietf.org/html/rfc6455#section-5.2
"""
self.ws.send_binary(s)
Original file line number Diff line number Diff line change
Expand Up @@ -176,23 +176,6 @@ async def _do_retryable_operation_async(self, operation, *args, **kwargs):
absolute_timeout -= (end_time - start_time)
raise retry_settings['history'][-1]

async def _keep_alive_worker_async(self):
interval = 10 if self._keep_alive is True else self._keep_alive
start_time = time.time()
try:
while self._connection and not self._shutdown:
current_time = time.time()
elapsed_time = (current_time - start_time)
if elapsed_time >= interval:
_logger.info("Keeping %r connection alive. %r",
self.__class__.__name__,
self._connection._container_id)
await self._connection._get_remote_timeout(current_time)
start_time = current_time
await asyncio.sleep(1)
except Exception as e: # pylint: disable=broad-except
_logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e)

async def open_async(self):
"""Asynchronously open the client. The client can create a new Connection
or an existing Connection can be passed in. This existing Connection
Expand All @@ -217,8 +200,6 @@ async def open_async(self):
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
transport_type=self._transport_type,
http_proxy=self._http_proxy,
properties=self._properties,
network_trace=self._network_trace
)
Expand All @@ -236,8 +217,6 @@ async def open_async(self):
auth_timeout=self._auth_timeout
)
await self._cbs_authenticator.open()
if self._keep_alive:
self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_worker_async())
self._shutdown = False

async def close_async(self):
Expand All @@ -249,9 +228,6 @@ async def close_async(self):
self._shutdown = True
if not self._session:
return # already closed.
if self._keep_alive_thread:
await self._keep_alive_thread
self._keep_alive_thread = None
await self._close_link_async(close=True)
if self._cbs_authenticator:
await self._cbs_authenticator.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio

from ._transport_async import AsyncTransport
from ._sasl_async import SASLTransport, SASLWithWebSocket
from ._sasl_async import SASLTransport
from ._session_async import Session
from ..performatives import OpenFrame, CloseFrame
from .._connection import get_local_timeout
Expand All @@ -27,8 +27,7 @@
MAX_CHANNELS,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME,
TransportType
EMPTY_FRAME
)

from ..error import (
Expand Down Expand Up @@ -59,36 +58,25 @@ class Connection(object):
:param list(str) offered_capabilities: The extension capabilities the sender supports.
:param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports
:param dict properties: Connection properties.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
parsed_url = urlparse(endpoint)
self.hostname = parsed_url.hostname
endpoint = self._hostname
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if parsed_url.port:
self.port = parsed_url.port
elif parsed_url.scheme == 'amqps':
self.port = SECURE_PORT
else:
self.port = PORT
self.state = None

transport = kwargs.get('transport')
if transport:
self.transport = transport
elif 'sasl_credential' in kwargs:
sasl_transport = SASLTransport
if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self._transport = sasl_transport(
host=endpoint,
self.transport = SASLTransport(
host=parsed_url.netloc,
credential=kwargs['sasl_credential'],
**kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import struct
from enum import Enum

from ._transport_async import AsyncTransport, WebSocketTransportAsync
from ._transport_async import AsyncTransport
from ..types import AMQPTypes, TYPE, VALUE
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT, TransportType
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME
from .._transport import AMQPS_PORT
from ..performatives import (
SASLOutcome,
Expand Down Expand Up @@ -72,7 +72,14 @@ class SASLExternalCredential(object):
def start(self):
return b''

class SASLTransportMixinAsync():

class SASLTransport(AsyncTransport):

def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

async def negotiate(self):
await self.write(SASL_HEADER_FRAME)
_, returned_header = await self.receive_frame()
Expand All @@ -97,26 +104,3 @@ async def negotiate(self):
return
else:
raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields))

class SASLTransport(AsyncTransport, SASLTransportMixinAsync):
def __init__(self, host, credential, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync):
def __init__(
self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs
): # pylint: disable=super-init-not-called
self.credential = credential
ssl = ssl or True
http_proxy = kwargs.pop('http_proxy', None)
self._transport = WebSocketTransportAsync(
host,
port=port,
connect_timeout=connect_timeout,
ssl=ssl,
http_proxy=http_proxy,
**kwargs
)
super().__init__(host, port, connect_timeout, ssl, **kwargs)
Loading

0 comments on commit f63f352

Please sign in to comment.