Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EventHub] merge uamqp and pyamqp #25193

Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

__version__ = VERSION

from ._constants import TransportType
from ._producer_client import EventHubProducerClient
from ._consumer_client import EventHubConsumerClient
from ._client_base import EventHubSharedKeyCredential
Expand All @@ -17,7 +18,6 @@
parse_connection_string,
EventHubConnectionStringProperties
)
from ._constants import TransportType

__all__ = [
"EventData",
Expand Down
211 changes: 123 additions & 88 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py

Large diffs are not rendered by default.

166 changes: 95 additions & 71 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import six

from ._utils import (
set_message_partition_key,
trace_message,
utc_from_timestamp,
transform_outbound_single_message,
Expand Down Expand Up @@ -53,22 +52,28 @@
AmqpMessageProperties,
)

from ._pyamqp import constants, utils as pyutils
from ._pyamqp.message import BatchMessage, Message

if TYPE_CHECKING:
import datetime
from ._pyamqp.message import Message

PrimitiveTypes = Optional[Union[
int,
float,
bytes,
bool,
str,
Dict,
List,
uuid.UUID,
]]
try:
from uamqp import uamqp_Message
except ImportError:
uamqp_Message = None
import datetime
from ._transport._base import AmqpTransport

PrimitiveTypes = Optional[
Union[
int,
float,
bytes,
bool,
str,
Dict,
List,
uuid.UUID,
]
]

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -122,7 +127,7 @@ def __init__(
self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore
data_body=body, annotations={}, application_properties={}
)
self.message = (self._raw_amqp_message._message) # pylint:disable=protected-access
self.message = None # amqp message to be set right before sending
self._raw_amqp_message.header = AmqpMessageHeader()
self._raw_amqp_message.properties = AmqpMessageProperties()
self.message_id = None
Expand All @@ -137,53 +142,55 @@ def __repr__(self):
body_str = self.body_as_str()
except:
body_str = "<read-error>"
event_repr = "body='{}'".format(body_str)
event_repr = f"body='{body_str}'"
try:
event_repr += ", properties={}".format(self.properties)
event_repr += f", properties={self.properties}"
except:
event_repr += ", properties=<read-error>"
try:
event_repr += ", offset={}".format(self.offset)
event_repr += f", offset={self.offset}"
except:
event_repr += ", offset=<read-error>"
try:
event_repr += ", sequence_number={}".format(self.sequence_number)
event_repr += f", sequence_number={self.sequence_number}"
except:
event_repr += ", sequence_number=<read-error>"
try:
event_repr += ", partition_key={!r}".format(self.partition_key)
event_repr += f", partition_key={self.partition_key!r}"
except:
event_repr += ", partition_key=<read-error>"
try:
event_repr += ", enqueued_time={!r}".format(self.enqueued_time)
event_repr += f", enqueued_time={self.enqueued_time!r}"
except:
event_repr += ", enqueued_time=<read-error>"
return "EventData({})".format(event_repr)
return f"EventData({event_repr})"

def __str__(self):
# type: () -> str
try:
body_str = self.body_as_str()
except: # pylint: disable=bare-except
body_str = "<read-error>"
event_str = "{{ body: '{}'".format(body_str)
event_str = f"{{ body: '{body_str}'"
try:
event_str += ", properties: {}".format(self.properties)
event_str += f", properties: {self.properties}"
if self.offset:
event_str += ", offset: {}".format(self.offset)
event_str += f", offset: {self.offset}"
if self.sequence_number:
event_str += ", sequence_number: {}".format(self.sequence_number)
event_str += f", sequence_number: {self.sequence_number}"
if self.partition_key:
event_str += ", partition_key={!r}".format(self.partition_key)
event_str += f", partition_key={self.partition_key!r}"
if self.enqueued_time:
event_str += ", enqueued_time={!r}".format(self.enqueued_time)
event_str += f", enqueued_time={self.enqueued_time!r}"
except: # pylint: disable=bare-except
pass
event_str += " }"
return event_str

@classmethod
def _from_message(cls, message, raw_amqp_message=None):
def _from_message(
cls, message: Union["uamqp_Message", "Message"], raw_amqp_message=None
):
# type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData
# pylint:disable=protected-access
"""Internal use only.
Expand All @@ -197,14 +204,13 @@ def _from_message(cls, message, raw_amqp_message=None):
event_data = cls(body="")
event_data.message = message
# pylint: disable=protected-access
event_data._raw_amqp_message = raw_amqp_message if raw_amqp_message else AmqpAnnotatedMessage(message=message)
event_data._raw_amqp_message = (
raw_amqp_message
if raw_amqp_message
else AmqpAnnotatedMessage(message=message)
)
return event_data

def _encode_message(self):
# type: () -> bytes
# pylint: disable=protected-access
return self._raw_amqp_message._message.encode_message()

def _decode_non_data_body_as_str(self, encoding="UTF-8"):
# type: (str) -> str
# pylint: disable=protected-access
Expand All @@ -217,11 +223,6 @@ def _decode_non_data_body_as_str(self, encoding="UTF-8"):
seq_list = [d for seq_section in body for d in seq_section]
return str(decode_with_recurse(seq_list, encoding))

def _to_outgoing_message(self):
# type: () -> EventData
self.message = (self._raw_amqp_message._to_outgoing_amqp_message()) # pylint:disable=protected-access
return self

@property
def raw_amqp_message(self):
# type: () -> AmqpAnnotatedMessage
Expand Down Expand Up @@ -268,10 +269,12 @@ def partition_key(self):

:rtype: bytes
"""
try:
return self._raw_amqp_message.annotations[PROP_PARTITION_KEY]
except KeyError:
return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None)
# TODO: Ask Anna. I think just trying this is reasonable? Haven't seen a case where symbol is used to get.
# try:
# return self._raw_amqp_message.annotations[types.AMQPSymbol(PROP_PARTITION_KEY)]
# except KeyError:
# return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None)
return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None)

@property
def properties(self):
Expand Down Expand Up @@ -378,9 +381,7 @@ def body_as_str(self, encoding="UTF-8"):
try:
return cast(bytes, data).decode(encoding)
except Exception as e:
raise TypeError(
"Message data is not compatible with string type: {}".format(e)
)
raise TypeError(f"Message data is not compatible with string type: {e}")

def body_as_json(self, encoding="UTF-8"):
# type: (str) -> Dict[str, Any]
Expand All @@ -394,7 +395,7 @@ def body_as_json(self, encoding="UTF-8"):
try:
return json.loads(data_str)
except Exception as e:
raise TypeError("Event data is not compatible with JSON type: {}".format(e))
raise TypeError(f"Event data is not compatible with JSON type: {e}")

@property
def content_type(self):
Expand Down Expand Up @@ -489,8 +490,14 @@ class EventDataBatch(object):
Event Hub decided by the service.
"""

def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None):
# type: (Optional[int], Optional[str], Optional[Union[str, bytes]]) -> None
def __init__(
self,
max_size_in_bytes: Optional[int] = None,
partition_id: Optional[str] = None,
partition_key: Optional[Union[str, bytes]] = None,
**kwargs,
) -> None:
self._amqp_transport = kwargs.pop("amqp_transport")

if partition_key and not isinstance(
partition_key, (six.text_type, six.binary_type)
Expand All @@ -502,33 +509,48 @@ def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None
"partition_key to only be string type, they might fail to parse the non-string value."
)

self.max_size_in_bytes = max_size_in_bytes or constants.MAX_FRAME_SIZE_BYTES
self.message = BatchMessage(data=[])
self.max_size_in_bytes = (
max_size_in_bytes or self._amqp_transport.MAX_FRAME_SIZE_BYTES
)
self.message = self._amqp_transport.BATCH_MESSAGE(data=[])
self._partition_id = partition_id
self._partition_key = partition_key
self.message = set_message_partition_key(self.message, self._partition_key)
self._size = pyutils.get_message_encoded_size(self.message)

self.message = self._amqp_transport.set_message_partition_key(
self.message, self._partition_key
)
self._size = self._amqp_transport.get_batch_message_encoded_size(self.message)
self._count = 0
self._internal_events: List[
Union[EventData, AmqpAnnotatedMessage]
] = [] # TODO: only used by uamqp

def __repr__(self):
# type: () -> str
batch_repr = "max_size_in_bytes={}, partition_id={}, partition_key={!r}, event_count={}".format(
self.max_size_in_bytes, self._partition_id, self._partition_key, self._count
batch_repr = (
f"max_size_in_bytes={self.max_size_in_bytes}, partition_id={self._partition_id}, "
f"partition_key={self._partition_key!r}, event_count={self._count}"
)
return "EventDataBatch({})".format(batch_repr)
return f"EventDataBatch({batch_repr})"

def __len__(self):
return self._count

@classmethod
def _from_batch(cls, batch_data, partition_key=None):
# type: (Iterable[EventData], Optional[AnyStr]) -> EventDataBatch
outgoing_batch_data = [transform_outbound_single_message(m, EventData) for m in batch_data]
batch_data_instance = cls(partition_key=partition_key)

def _from_batch(cls, batch_data, amqp_transport, partition_key=None):
# type: (Iterable[EventData], AmqpTransport, Optional[AnyStr]) -> EventDataBatch
outgoing_batch_data = [
transform_outbound_single_message(
m, EventData, amqp_transport.to_outgoing_amqp_message
)
for m in batch_data
]
batch_data_instance = cls(
partition_key=partition_key, amqp_transport=amqp_transport
)

for event_data in outgoing_batch_data:
batch_data_instance.add(event_data)

return batch_data_instance

def _load_events(self, events):
Expand Down Expand Up @@ -565,7 +587,9 @@ def add(self, event_data):
:raise: :class:`ValueError`, when exceeding the size limit.
"""

outgoing_event_data = transform_outbound_single_message(event_data, EventData)
outgoing_event_data = transform_outbound_single_message(
event_data, EventData, self._amqp_transport.to_outgoing_amqp_message
)

if self._partition_key:
if (
Expand All @@ -576,12 +600,14 @@ def add(self, event_data):
"The partition key of event_data does not match the partition key of this batch."
)
if not outgoing_event_data.partition_key:
outgoing_event_data.message = set_message_partition_key(
self._amqp_transport.set_message_partition_key(
outgoing_event_data.message, self._partition_key
)

trace_message(outgoing_event_data)
event_data_size = pyutils.get_message_encoded_size(outgoing_event_data.message)
event_data_size = self._amqp_transport.get_message_encoded_size(
outgoing_event_data.message
)
# For a BatchMessage, if the encoded_message_size of event_data is < 256, then the overhead cost to encode that
# message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes.
size_after_add = (
Expand All @@ -592,11 +618,9 @@ def add(self, event_data):

if size_after_add > self.max_size_in_bytes:
raise ValueError(
"EventDataBatch has reached its size limit: {}".format(
self.max_size_in_bytes
)
f"EventDataBatch has reached its size limit: {self.max_size_in_bytes}"
)

pyutils.add_batch(self.message, outgoing_event_data.message)
self._amqp_transport.add_batch(self, outgoing_event_data, event_data)
self._size = size_after_add
self._count += 1
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from urllib.parse import urlparse

from azure.core.pipeline.policies import RetryMode

from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT


Expand Down Expand Up @@ -34,10 +33,14 @@ def __init__(self, **kwargs):
self.connection_verify = kwargs.get("connection_verify") # type: Optional[str]
self.connection_port = DEFAULT_AMQPS_PORT
self.custom_endpoint_hostname = None
self.hostname = kwargs.pop("hostname")
uamqp_transport = kwargs.pop("uamqp_transport")

if self.http_proxy or self.transport_type == TransportType.AmqpOverWebsocket:
self.transport_type = TransportType.AmqpOverWebsocket
self.connection_port = DEFAULT_AMQP_WSS_PORT
if not uamqp_transport:
self.hostname += "/$servicebus/websocket"

# custom end point
if self.custom_endpoint_address:
Expand All @@ -48,5 +51,7 @@ def __init__(self, **kwargs):
endpoint = urlparse(self.custom_endpoint_address)
self.transport_type = TransportType.AmqpOverWebsocket
self.custom_endpoint_hostname = endpoint.hostname
if not uamqp_transport:
self.custom_endpoint_address += "/$servicebus/websocket"
# in case proxy and custom endpoint are both provided, we default port to 443 if it's not provided
self.connection_port = endpoint.port or DEFAULT_AMQP_WSS_PORT
Loading