Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EventHubs & AMQP Python] Send Port #19745

Merged
merged 5 commits into from
Jul 14, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
except ImportError:
from urllib.parse import urlparse, quote_plus


from .pyamqp.client import AMQPClient as PyAMQPClient
from .pyamqp.authentication import _generate_sas_token as Py_generate_sas_token
from .pyamqp.message import Message as PyMessage, Properties as PyMessageProperties
from uamqp import authentication
from .pyamqp import constants, error as errors, utils
from .pyamqp.authentication import JWTTokenAuth as PyJWTTokenAuth
from .pyamqp.client import AMQPClient
from .pyamqp.message import Message


import six
from azure.core.credentials import AccessToken, AzureSasCredential, AzureNamedKeyCredential
from azure.core.utils import parse_connection_string as core_parse_connection_string
Expand Down Expand Up @@ -175,7 +179,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> _AccessToken
if not scopes:
raise ValueError("No token scope provided.")
return _generate_sas_token(scopes[0], self.policy, self.key)

return Py_generate_sas_token(scopes[0], self.policy, self.key)


class EventhubAzureNamedKeyTokenCredential(object):
"""The named key credential used for authentication.
Expand Down Expand Up @@ -264,7 +270,7 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
self._credential = credential #type: ignore
self._keep_alive = kwargs.get("keep_alive", 30)
self._auto_reconnect = kwargs.get("auto_reconnect", True)
self._mgmt_target = "amqps://{}/{}".format(
self._mgmt_target = "{}/{}".format(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: why is the format of the target changing?

Copy link
Contributor Author

@yunhaoling yunhaoling Jul 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was testing the mgmt link and found that the "amqps://" is prepended by the python amqp impl.

this is a good point regarding to how we construct the uri (protocol "amqps://" in eh or in amqp).
I revert the change first.

will address that when doing the mgmt link work

self._address.hostname, self.eventhub_name
)
self._auth_uri = "sb://{}{}".format(self._address.hostname, self._address.path)
Expand Down Expand Up @@ -344,7 +350,7 @@ def _management_request(self, mgmt_msg, op_type):
last_exception = None
while retried_times <= self._config.max_retries:
mgmt_auth = self._create_auth()
mgmt_client = AMQPClient(
mgmt_client = PyAMQPClient(
self._address.hostname, auth=mgmt_auth, debug=self._config.network_tracing
)
try:
Expand Down Expand Up @@ -446,7 +452,7 @@ def _get_partition_ids(self):

def _get_partition_properties(self, partition_id):
# type:(str) -> Dict[str, Any]
mgmt_msg = Message(
mgmt_msg = PyMessage(
application_properties={
"name": self.eventhub_name,
"partition": partition_id,
Expand Down Expand Up @@ -507,15 +513,15 @@ def _open(self):
auth = self._client._create_auth()
self._create_handler(auth)
self._handler.open(
connection=self._client._conn_manager.get_connection(
self._client._address.hostname, auth
) # pylint: disable=protected-access
# connection=self._client._conn_manager.get_connection(
# self._client._address.hostname, auth
# ) # pylint: disable=protected-access
)
while not self._handler.client_ready():
time.sleep(0.05)
self._max_message_size_on_link = (
self._handler.message_handler._link.peer_max_message_size
or MAX_MESSAGE_LENGTH_BYTES
self._handler._link.remote_max_message_size
or constants.MAX_FRAME_SIZE_BYTES
) # pylint: disable=protected-access
self.running = True

Expand Down
24 changes: 13 additions & 11 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
)

import six

from .pyamqp import constants, _encode as encode
from .pyamqp.message import BatchMessage, Message


from ._utils import (
set_message_partition_key,
trace_message,
Expand Down Expand Up @@ -57,6 +57,8 @@
AmqpMessageProperties,
)

from .pyamqp import utils as pyutils
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING:
import datetime

Expand Down Expand Up @@ -108,8 +110,8 @@ def __init__(self, body=None):

# Internal usage only for transforming AmqpAnnotatedMessage to outgoing EventData
self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore
data_body=body, annotations={}, application_properties={}
)
data_body=[body], annotations={}, application_properties={}
)
self.message = (self._raw_amqp_message._message) # pylint:disable=protected-access
self._raw_amqp_message.header = AmqpMessageHeader()
self._raw_amqp_message.properties = AmqpMessageProperties()
Expand Down Expand Up @@ -483,13 +485,14 @@ def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None
"partition_key to only be string type, they might fail to parse the non-string value."
)

self.max_size_in_bytes = max_size_in_bytes #TODO: FIND REPLACEMENT - or constants.MAX_MESSAGE_LENGTH_BYTES
self.message = BatchMessage(data=[], multi_messages=False, properties=None)
self.max_size_in_bytes = max_size_in_bytes or constants.MAX_FRAME_SIZE_BYTES
self.message = BatchMessage(data=[])
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved
self._partition_id = partition_id
self._partition_key = partition_key

set_message_partition_key(self.message, self._partition_key)
self._size = len(encode.encode_payload(b"", self.message))
# TODO: test whether we need to set partition key of a batch message, or setting each single message if enough
# this is performance related
#set_message_partition_key(self.message, self._partition_key)
self._size = pyutils.get_message_encoded_size(self.message)
self._count = 0

def __repr__(self):
Expand Down Expand Up @@ -562,8 +565,7 @@ def add(self, event_data):
)

trace_message(outgoing_event_data)
event_data_size = outgoing_event_data.message.get_message_encoded_size()

event_data_size = pyutils.get_message_encoded_size(outgoing_event_data.message)
# For a BatchMessage, if the encoded_message_size of event_data is < 256, then the overhead cost to encode that
# message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes.
size_after_add = (
Expand All @@ -579,7 +581,7 @@ def add(self, event_data):
)
)

self.message._body_gen.append(outgoing_event_data) # pylint: disable=protected-access
pyutils.add_batch(self.message, outgoing_event_data.message)
self._size = size_after_add
self._count += 1

Expand Down
33 changes: 15 additions & 18 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
transform_outbound_single_message,
)
from ._constants import TIMEOUT_SYMBOL
from .pyamqp import SendClient as PySendClient

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -120,17 +121,12 @@ def __init__(self, client, target, **kwargs):

def _create_handler(self, auth):
# type: (JWTTokenAuth) -> None
self._handler = SendClient(
self._handler = PySendClient(
self._client._address.hostname,
self._target,
auth=auth,
debug=self._client._config.network_tracing, # pylint:disable=protected-access
msg_timeout=self._timeout * 1000,
idle_timeout=self._idle_timeout,
error_policy=self._retry_policy,
keep_alive_interval=self._keep_alive,
client_name=self._name,
link_properties=self._link_properties,
properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access
idle_timeout=10,
network_trace=self._client._config.network_tracing
)

def _open_with_retry(self):
Expand All @@ -156,14 +152,14 @@ def _send_event_data(self, timeout_time=None, last_exception=None):
if self._unsent_events:
self._open()
self._set_msg_timeout(timeout_time, last_exception)
self._handler.queue_message(*self._unsent_events) # type: ignore
self._handler.wait() # type: ignore
self._unsent_events = self._handler.pending_messages # type: ignore
if self._outcome != constants.MessageSendResult.Ok:
if self._outcome == constants.MessageSendResult.Timeout:
self._condition = OperationTimeoutError("Send operation timed out")
if self._condition:
raise self._condition
self._handler.send_message(self._unsent_events[0])
self._unsent_events = None
# self._unsent_events = self._handler.pending_messages # type: ignore
# if self._outcome != constants.MessageSendResult.Ok:
# if self._outcome == constants.MessageSendResult.Timeout:
# self._condition = OperationTimeoutError("Send operation timed out")
# if self._condition:
# raise self._condition

def _send_event_data_with_retry(self, timeout=None):
# type: (Optional[float]) -> None
Expand Down Expand Up @@ -205,7 +201,8 @@ def _wrap_eventdata(
raise ValueError(
"The partition_key does not match the one of the EventDataBatch"
)
for event in event_data.message._body_gen: # pylint: disable=protected-access

for event in event_data.message.data: # pylint: disable=protected-access
trace_message(event, span)
wrapper_event_data = event_data # type:ignore
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _get_max_mesage_size(self):
self._max_message_size_on_link = (
self._producers[ # type: ignore
ALL_PARTITIONS
]._handler.message_handler._link.peer_max_message_size
]._handler._link.remote_max_message_size
or constants.MAX_MESSAGE_LENGTH_BYTES
)

Expand Down
68 changes: 40 additions & 28 deletions sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import uamqp

from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType
from ..pyamqp.message import Message as PyMessage, Header as PyHeader, Properties as PyProperties


class DictMixin(object):
Expand Down Expand Up @@ -136,7 +137,7 @@ def __init__(self, **kwargs):
self._body = kwargs.get("value_body")
self._body_type = uamqp.MessageBodyType.Value

self._message = uamqp.message.Message(body=self._body, body_type=self._body_type)
#self._message = uamqp.message.Message(body=self._body, body_type=self._body_type)
header_dict = cast(Mapping, kwargs.get("header"))
self._header = AmqpMessageHeader(**header_dict) if "header" in kwargs else None
self._footer = kwargs.get("footer")
Expand Down Expand Up @@ -214,17 +215,18 @@ def _from_amqp_message(self, message):

def _to_outgoing_amqp_message(self):
message_header = None
if self.header:
message_header = uamqp.message.MessageHeader()
message_header.delivery_count = self.header.delivery_count
message_header.time_to_live = self.header.time_to_live
message_header.first_acquirer = self.header.first_acquirer
message_header.durable = self.header.durable
message_header.priority = self.header.priority
if self.header and any(self.header.values()):
message_header = PyHeader(
delivery_count=self.header.delivery_count,
ttl=self.header.time_to_live,
first_acquirer=self.header.first_acquirer,
durable=self.header.durable,
priority=self.header.priority
)

message_properties = None
if self.properties:
message_properties = uamqp.message.MessageProperties(
if self.properties and any(self.properties.values()):
message_properties = PyProperties(
message_id=self.properties.message_id,
user_id=self.properties.user_id,
to=self.properties.to,
Expand All @@ -238,33 +240,43 @@ def _to_outgoing_amqp_message(self):
if self.properties.absolute_expiry_time else None,
group_id=self.properties.group_id,
group_sequence=self.properties.group_sequence,
reply_to_group_id=self.properties.reply_to_group_id,
encoding=self._encoding
reply_to_group_id=self.properties.reply_to_group_id
)

amqp_body = self._message._body # pylint: disable=protected-access
if isinstance(amqp_body, uamqp.message.DataBody):
amqp_body_type = uamqp.MessageBodyType.Data
amqp_body = list(amqp_body.data)
elif isinstance(amqp_body, uamqp.message.SequenceBody):
amqp_body_type = uamqp.MessageBodyType.Sequence
amqp_body = list(amqp_body.data)
else:
# amqp_body is type of uamqp.message.ValueBody
amqp_body_type = uamqp.MessageBodyType.Value
amqp_body = amqp_body.data

return uamqp.message.Message(
body=amqp_body,
body_type=amqp_body_type,
# TODO: let's only support data body for prototyping
return PyMessage(
data=self._body,
header=message_header,
properties=message_properties,
application_properties=self.application_properties,
annotations=self.annotations,
message_annotations=self.annotations,
delivery_annotations=self.delivery_annotations,
footer=self.footer
)

# amqp_body = self._message._body # pylint: disable=protected-access
# if isinstance(amqp_body, uamqp.message.DataBody):
# amqp_body_type = uamqp.MessageBodyType.Data
# amqp_body = list(amqp_body.data)
# elif isinstance(amqp_body, uamqp.message.SequenceBody):
# amqp_body_type = uamqp.MessageBodyType.Sequence
# amqp_body = list(amqp_body.data)
# else:
# # amqp_body is type of uamqp.message.ValueBody
# amqp_body_type = uamqp.MessageBodyType.Value
# amqp_body = amqp_body.data
#
# return uamqp.message.Message(
# body=amqp_body,
# body_type=amqp_body_type,
# header=message_header,
# properties=message_properties,
# application_properties=self.application_properties,
# annotations=self.annotations,
# delivery_annotations=self.delivery_annotations,
# footer=self.footer
# )

@property
def body(self):
# type: () -> Any
Expand Down
Loading