diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py index c88edfa66292..5272d4fecd5f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py @@ -7,6 +7,7 @@ __version__ = VERSION +from ._constants import TransportType from ._producer_client import EventHubProducerClient from ._consumer_client import EventHubConsumerClient from ._client_base import EventHubSharedKeyCredential @@ -17,7 +18,6 @@ parse_connection_string, EventHubConnectionStringProperties ) -from ._constants import TransportType __all__ = [ "EventData", diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 4d0f017827a2..5696fe3e45bf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -23,26 +23,48 @@ from azure.core.pipeline.policies import RetryMode -from ._pyamqp.client import AMQPClient -from ._pyamqp.message import Message -from ._pyamqp import constants, error as errors, utils as pyamqp_utils -from ._pyamqp.authentication import JWTTokenAuth -from .exceptions import _handle_exception, ClientClosedError +try: + from ._transport._uamqp_transport import UamqpTransport +except ImportError: + UamqpTransport = None +from ._transport._pyamqp_transport import PyamqpTransport +from .exceptions import ClientClosedError from ._configuration import Configuration from ._utils import utc_from_timestamp, parse_sas_credential from ._connection_manager import get_connection_manager from ._constants import ( CONTAINER_PREFIX, JWT_TOKEN_SCOPE, - MGMT_OPERATION, - MGMT_PARTITION_OPERATION, + READ_OPERATION, MGMT_STATUS_CODE, MGMT_STATUS_DESC, - READ_OPERATION + MGMT_OPERATION, + MGMT_PARTITION_OPERATION, ) +from ._pyamqp import utils as pyamqp_utils, error as errors if TYPE_CHECKING: from azure.core.credentials import TokenCredential + try: + from uamqp import Message as uamqp_Message + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + from ._transport._uamqp_transport import ( + EventHubSharedKeyCredential as uamqp_EventHubSharedKeyCredential, + ) + except ImportError: + uamqp_Message = None + uamqp_JWTTokenAuth = None + uamqp_EventHubSharedKeyCredential = None + from ._pyamqp.message import Message + from ._pyamqp.authentication import JWTTokenAuth + + CredentialTypes = Union[ + AzureSasCredential, + AzureNamedKeyCredential, + uamqp_EventHubSharedKeyCredential, + "EventHubSharedKeyCredential", + TokenCredential, + ] _LOGGER = logging.getLogger(__name__) _Address = collections.namedtuple("_Address", "hostname path") @@ -163,7 +185,7 @@ def _get_backoff_time(retry_mode, backoff_factor, backoff_max, retried_times): if retry_mode == RetryMode.Fixed: backoff_value = backoff_factor else: - backoff_value = backoff_factor * (2 ** retried_times) + backoff_value = backoff_factor * (2**retried_times) return min(backoff_max, backoff_value) @@ -262,8 +284,22 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument class ClientBase(object): # pylint:disable=too-many-instance-attributes - def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwargs): - # type: (str, str, Union[AzureSasCredential, TokenCredential, AzureNamedKeyCredential], Any) -> None + def __init__( + self, + fully_qualified_namespace: str, + eventhub_name: str, + credential: "CredentialTypes", + **kwargs: Any, + ) -> None: + self._uamqp_transport = kwargs.pop("uamqp_transport", False) + if not self._uamqp_transport: + self._amqp_transport = PyamqpTransport() + else: + try: + self._amqp_transport = UamqpTransport() + except TypeError: + raise ImportError("uamqp package is not installed") + self.eventhub_name = eventhub_name if not eventhub_name: raise ValueError("The eventhub name can not be None or empty.") @@ -273,16 +309,22 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredential(credential) elif isinstance(credential, AzureNamedKeyCredential): - self._credential = EventhubAzureNamedKeyTokenCredential(credential) # type: ignore + self._credential = EventhubAzureNamedKeyTokenCredential(credential) + # TODO: see if pyamqp generated token works for uamqp + # if self._uamqp_transport: + # self._credential = UamqpTransport.create_named_key_token_credential(credential) # type: ignore + # else: + # raise NotImplementedError('pyamqp named key token credential') else: self._credential = credential # type: ignore self._keep_alive = kwargs.get("keep_alive", 30) self._auto_reconnect = kwargs.get("auto_reconnect", True) - self._mgmt_target = "amqps://{}/{}".format( - self._address.hostname, self.eventhub_name + self._auth_uri = f"sb://{self._address.hostname}{self._address.path}" + self._config = Configuration( + uamqp_transport=self._uamqp_transport, + hostname=self._address.hostname, + **kwargs, ) - self._auth_uri = "sb://{}{}".format(self._address.hostname, self._address.path) - self._config = Configuration(**kwargs) self._debug = self._config.network_tracing self._conn_manager = get_connection_manager(**kwargs) self._idle_timeout = kwargs.get("idle_timeout", None) @@ -298,14 +340,15 @@ def _from_connection_string(conn_str, **kwargs): if token and token_expiry: kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry) elif policy and key: + # TODO: see if pyamqp generated token works for uamqp. pyamqp by default here, else uamqp kwargs["credential"] = EventHubSharedKeyCredential(policy, key) + # kwargs["credential"] = UamqpTransport.create_shared_key_credential(policy, key) return kwargs - def _create_auth(self): - # type: () -> JWTTokenAuth + def _create_auth(self) -> Union["uamqp_JWTTokenAuth", "JWTTokenAuth"]: """ - Create an ~uamqp.authentication.SASTokenAuth instance to authenticate - the session. + Create an ~uamqp.authentication.SASTokenAuth or pyamqp.JWTTokenAuth instance + to authenticate the session. """ try: # ignore mypy's warning because token_type is Optional @@ -313,20 +356,19 @@ def _create_auth(self): except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - return JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, - functools.partial(self._credential.get_token, self._auth_uri) + functools.partial(self._credential.get_token, self._auth_uri), + token_type=token_type, + config=self._config, + update_token=True, # TODO: discarded by pyamqp transport ) - return JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=self._config.auth_timeout, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, + config=self._config, + update_token=False, ) def _close_connection(self): @@ -361,48 +403,40 @@ def _backoff( ) raise last_exception - def _management_request(self, mgmt_msg, op_type): - # type: (Message, bytes) -> Any + def _management_request( + self, mgmt_msg: Union["uamqp_Message", "Message"], op_type: bytes + ) -> Any: + # pylint:disable=assignment-from-none retried_times = 0 last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = self._create_auth() - hostname = self._address.hostname - custom_endpoint_address = self._config.custom_endpoint_address - if self._config.transport_type.name == 'AmqpOverWebsocket': - hostname += '/$servicebus/websocket/' - if custom_endpoint_address: - custom_endpoint_address += '/$servicebus/websocket/' - mgmt_client = AMQPClient( - hostname, - auth=mgmt_auth, - network_trace=self._config.network_tracing, - transport_type=self._config.transport_type, - http_proxy=self._config.http_proxy, - custom_endpoint_address=custom_endpoint_address, - connection_verify=self._config.connection_verify + mgmt_client = self._amqp_transport.create_mgmt_client( + self._address, mgmt_auth=mgmt_auth, config=self._config ) try: mgmt_client.open() while not mgmt_client.client_ready(): time.sleep(0.05) - access_token = mgmt_auth.get_token() - + + access_token = self._amqp_transport.get_updated_token(mgmt_auth) if not access_token: _LOGGER.debug("Management client received an empty access token object") - elif not access_token.token: _LOGGER.debug("Management client received an empty token") - else: _LOGGER.debug(f"Management client token expires on: {datetime.fromtimestamp(access_token.expires_on)}") - - mgmt_msg.application_properties["security_token"] = access_token.token - response = mgmt_client.mgmt_request( + # TODO: double check whether access_token or access_token.token + mgmt_msg.application_properties[ + "security_token" + ] = access_token + + response = self._amqp_transport.mgmt_client_request( + mgmt_client, mgmt_msg, - operation=READ_OPERATION.decode(), - operation_type=op_type.decode(), + operation=READ_OPERATION, + operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, ) @@ -415,30 +449,28 @@ def _management_request(self, mgmt_msg, op_type): if status_code < 400: return response if status_code in [401]: - raise errors.AuthenticationException( - errors.ErrorCondition.UnauthorizedAccess, - description="Management authentication failed. Status code: {}, Description: {!r}".format( - status_code, - description - ) - ) - if status_code in [404]: - raise errors.AMQPConnectionError( - errors.ErrorCondition.NotFound, - description="Management connection failed. Status code: {}, Description: {!r}".format( - status_code, - description - ) + raise self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + f"Management authentication failed. Status code: {status_code}, Description: {description!r}", + condition=errors.ErrorCondition.UnauthorizedAccess, ) - raise errors.AMQPConnectionError( - errors.ErrorCondition.UnknownError, - description="Management operation failed. Status code: {}, Description: {!r}".format( - status_code, - description + if status_code in [ + 404 + ]: # TODO: make sure the error surfaced is the same across pyamqp and uamqp + return self._amqp_transport.get_error( + self._amqp_transport.CONNECTION_ERROR, + f"Management connection failed. Status code: {status_code}, Description: {description!r}", + condition=errors.ErrorCondition.NotFound, ) + return self._amqp_transport.get_error( + self._amqp_transport.AMQP_CONNECTION_ERROR, + f"Management request error. Status code: {status_code}, Description: {description!r}", + condition=errors.ErrorCondition.UnknownError, ) except Exception as exception: # pylint: disable=broad-except - last_exception = _handle_exception(exception, self) + last_exception = self._amqp_transport._handle_exception( + exception, self + ) # pylint: disable=protected-access self._backoff( retried_times=retried_times, last_exception=last_exception ) @@ -457,12 +489,13 @@ def _add_span_request_attributes(self, span): span.add_attribute("message_bus.destination", self._address.path) span.add_attribute("peer.address", self._address.hostname) - def _get_eventhub_properties(self): - # type:() -> Dict[str, Any] - mgmt_msg = Message(application_properties={"name": self.eventhub_name}) + def _get_eventhub_properties(self) -> Dict[str, Any]: + mgmt_msg = self._amqp_transport.MESSAGE( + application_properties={"name": self.eventhub_name} + ) response = self._management_request(mgmt_msg, op_type=MGMT_OPERATION) output = {} - eh_info = response.value # type: Dict[bytes, Any] + eh_info: Dict[bytes, Any] = response.value if eh_info: output["eventhub_name"] = eh_info[b"name"].decode("utf-8") output["created_at"] = utc_from_timestamp( @@ -471,7 +504,7 @@ def _get_eventhub_properties(self): output["partition_ids"] = [ p.decode("utf-8") for p in eh_info[b"partition_ids"] ] - return output + return output def _get_partition_ids(self): # type:() -> List[str] @@ -479,7 +512,7 @@ def _get_partition_ids(self): def _get_partition_properties(self, partition_id): # type:(str) -> Dict[str, Any] - mgmt_msg = Message( + mgmt_msg = self._amqp_transport.MESSAGE( application_properties={ "name": self.eventhub_name, "partition": partition_id, @@ -524,9 +557,7 @@ def _create_handler(self, auth): def _check_closed(self): if self.closed: raise ClientClosedError( - "{} has been closed. Please create a new one to handle event data.".format( - self._name - ) + f"{self._name} has been closed. Please create a new one to handle event data." ) def _open(self): @@ -541,9 +572,9 @@ def _open(self): while not self._handler.client_ready(): time.sleep(0.05) self._max_message_size_on_link = ( - self._handler._link.remote_max_message_size - or constants.MAX_FRAME_SIZE_BYTES - ) # pylint: disable=protected-access + self._amqp_transport.get_remote_max_message_size(self._handler) + or self._amqp_transport.MAX_FRAME_SIZE_BYTES + ) self.running = True def _close_handler(self): @@ -556,12 +587,17 @@ def _close_connection(self): self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access def _handle_exception(self, exception): - if not self.running and isinstance(exception, TimeoutError): - exception = errors.AuthenticationException( - errors.ErrorCondition.InternalError, - description="Authorization timeout." - ) - return _handle_exception(exception, self) + if not self.running and isinstance( + exception, self._amqp_transport.TIMEOUT_EXCEPTION + ): + exception = self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + "Authorization timeout.", + condition=errors.ErrorCondition.InternalError, + ) + return self._amqp_transport._handle_exception( # pylint: disable=protected-access + exception, self + ) def _do_retryable_operation(self, operation, timeout=None, **kwargs): # pylint:disable=protected-access @@ -579,7 +615,7 @@ def _do_retryable_operation(self, operation, timeout=None, **kwargs): return operation( timeout_time=timeout_time, last_exception=last_exception, - **kwargs + **kwargs, ) return operation() except Exception as exception: # pylint:disable=broad-except diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 3ce98d6efda4..22bb8ad60470 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -2,9 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import json +import datetime import logging import uuid from typing import ( @@ -21,7 +22,6 @@ import six from ._utils import ( - set_message_partition_key, trace_message, utc_from_timestamp, transform_outbound_single_message, @@ -52,23 +52,30 @@ AmqpMessageHeader, AmqpMessageProperties, ) - -from ._pyamqp import constants, utils as pyutils -from ._pyamqp.message import BatchMessage, Message +from ._pyamqp.message import Message +from ._pyamqp._message_backcompat import LegacyMessage, LegacyBatchMessage +from ._transport._pyamqp_transport import PyamqpTransport if TYPE_CHECKING: + try: + from uamqp import uamqp_Message + except ImportError: + uamqp_Message = None import datetime - -PrimitiveTypes = Optional[Union[ - int, - float, - bytes, - bool, - str, - Dict, - List, - uuid.UUID, -]] + from ._transport._base import AmqpTransport + +PrimitiveTypes = Optional[ + Union[ + int, + float, + bytes, + bool, + str, + Dict, + List, + uuid.UUID, + ] +] _LOGGER = logging.getLogger(__name__) @@ -117,74 +124,76 @@ def __init__( self._sys_properties = None # type: Optional[Dict[bytes, Any]] if body is None: raise ValueError("EventData cannot be None.") + self._uamqp_message = None # Internal usage only for transforming AmqpAnnotatedMessage to outgoing EventData self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore data_body=body, annotations={}, application_properties={} ) - self.message = (self._raw_amqp_message._message) # pylint:disable=protected-access + self._message = None # amqp message to be set right before sending self._raw_amqp_message.header = AmqpMessageHeader() self._raw_amqp_message.properties = AmqpMessageProperties() self.message_id = None self.content_type = None self.correlation_id = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: # pylint: disable=bare-except try: # TODO: below call won't work b/c pyamqp.message.message doesn't have body_type body_str = self.body_as_str() except: body_str = "" - event_repr = "body='{}'".format(body_str) + event_repr = f"body='{body_str}'" try: - event_repr += ", properties={}".format(self.properties) + event_repr += f", properties={self.properties}" except: event_repr += ", properties=" try: - event_repr += ", offset={}".format(self.offset) + event_repr += f", offset={self.offset}" except: event_repr += ", offset=" try: - event_repr += ", sequence_number={}".format(self.sequence_number) + event_repr += f", sequence_number={self.sequence_number}" except: event_repr += ", sequence_number=" try: - event_repr += ", partition_key={!r}".format(self.partition_key) + event_repr += f", partition_key={self.partition_key!r}" except: event_repr += ", partition_key=" try: - event_repr += ", enqueued_time={!r}".format(self.enqueued_time) + event_repr += f", enqueued_time={self.enqueued_time!r}" except: event_repr += ", enqueued_time=" - return "EventData({})".format(event_repr) + return f"EventData({event_repr})" - def __str__(self): - # type: () -> str + def __str__(self) -> str: try: body_str = self.body_as_str() except: # pylint: disable=bare-except body_str = "" - event_str = "{{ body: '{}'".format(body_str) + event_str = f"{{ body: '{body_str}'" try: - event_str += ", properties: {}".format(self.properties) + event_str += f", properties: {self.properties}" if self.offset: - event_str += ", offset: {}".format(self.offset) + event_str += f", offset: {self.offset}" if self.sequence_number: - event_str += ", sequence_number: {}".format(self.sequence_number) + event_str += f", sequence_number: {self.sequence_number}" if self.partition_key: - event_str += ", partition_key={!r}".format(self.partition_key) + event_str += f", partition_key={self.partition_key!r}" if self.enqueued_time: - event_str += ", enqueued_time={!r}".format(self.enqueued_time) + event_str += f", enqueued_time={self.enqueued_time!r}" except: # pylint: disable=bare-except pass event_str += " }" return event_str @classmethod - def _from_message(cls, message, raw_amqp_message=None): - # type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData + def _from_message( + cls, + message: Union["uamqp_Message", Message], + raw_amqp_message: Optional[AmqpAnnotatedMessage] = None, + ) -> EventData: # pylint:disable=protected-access """Internal use only. @@ -195,18 +204,16 @@ def _from_message(cls, message, raw_amqp_message=None): :rtype: ~azure.eventhub.EventData """ event_data = cls(body="") - event_data.message = message + event_data._message = message # pylint: disable=protected-access - event_data._raw_amqp_message = raw_amqp_message if raw_amqp_message else AmqpAnnotatedMessage(message=message) + event_data._raw_amqp_message = ( + raw_amqp_message + if raw_amqp_message + else AmqpAnnotatedMessage(message=message) + ) return event_data - def _encode_message(self): - # type: () -> bytes - # pylint: disable=protected-access - return self._raw_amqp_message._message.encode_message() - - def _decode_non_data_body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def _decode_non_data_body_as_str(self, encoding: str = "UTF-8") -> str: # pylint: disable=protected-access body = self.raw_amqp_message.body if self.body_type == AmqpMessageBodyType.VALUE: @@ -217,20 +224,28 @@ def _decode_non_data_body_as_str(self, encoding="UTF-8"): seq_list = [d for seq_section in body for d in seq_section] return str(decode_with_recurse(seq_list, encoding)) - def _to_outgoing_message(self): - # type: () -> EventData - self.message = (self._raw_amqp_message._to_outgoing_amqp_message()) # pylint:disable=protected-access - return self + @property + def message(self) -> LegacyMessage: + if not self._uamqp_message: + self._uamqp_message = LegacyMessage( + self._raw_amqp_message, + to_outgoing_amqp_message=PyamqpTransport().to_outgoing_amqp_message, + ) + return self._uamqp_message + + # TODO: make message property mutable + @message.setter + def message(self, value: Union["uamqp_Message", Message]) -> None: + self._message = value + self._raw_amqp_message = AmqpAnnotatedMessage(message=value) @property - def raw_amqp_message(self): - # type: () -> AmqpAnnotatedMessage + def raw_amqp_message(self) -> AmqpAnnotatedMessage: """Advanced usage only. The internal AMQP message payload that is sent or received.""" return self._raw_amqp_message @property - def sequence_number(self): - # type: () -> Optional[int] + def sequence_number(self) -> Optional[int]: """The sequence number of the event. :rtype: int @@ -238,8 +253,7 @@ def sequence_number(self): return self._raw_amqp_message.annotations.get(PROP_SEQ_NUMBER, None) @property - def offset(self): - # type: () -> Optional[str] + def offset(self) -> Optional[str]: """The offset of the event. :rtype: str @@ -250,8 +264,7 @@ def offset(self): return None @property - def enqueued_time(self): - # type: () -> Optional[datetime.datetime] + def enqueued_time(self) -> Optional[datetime.datetime]: """The enqueued timestamp of the event. :rtype: datetime.datetime @@ -262,20 +275,20 @@ def enqueued_time(self): return None @property - def partition_key(self): - # type: () -> Optional[bytes] + def partition_key(self) -> Optional[bytes]: """The partition key of the event. :rtype: bytes """ - try: - return self._raw_amqp_message.annotations[PROP_PARTITION_KEY] - except KeyError: - return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) + # TODO: Ask Anna. I think just trying this is reasonable? Haven't seen a case where symbol is used to get. + # try: + # return self._raw_amqp_message.annotations[types.AMQPSymbol(PROP_PARTITION_KEY)] + # except KeyError: + # return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) + return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) @property - def properties(self): - # type: () -> Dict[Union[str, bytes], Any] + def properties(self) -> Dict[Union[str, bytes], Any]: """Application-defined properties on the event. :rtype: dict @@ -283,7 +296,7 @@ def properties(self): return self._raw_amqp_message.application_properties @properties.setter - def properties(self, value): + def properties(self, value: Dict[Union[str, bytes], Any]): # type: (Dict[Union[str, bytes], Any]) -> None """Application-defined properties on the event. @@ -293,8 +306,7 @@ def properties(self, value): self._raw_amqp_message.application_properties = properties @property - def system_properties(self): - # type: () -> Dict[bytes, Any] + def system_properties(self) -> Dict[bytes, Any]: """Metadata set by the Event Hubs Service associated with the event. An EventData could have some or all of the following meta data depending on the source @@ -332,8 +344,7 @@ def system_properties(self): return self._sys_properties @property - def body(self): - # type: () -> PrimitiveTypes + def body(self) -> PrimitiveTypes: """The body of the Message. The format may vary depending on the body type: For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, the body could be bytes or Iterable[bytes]. @@ -350,16 +361,14 @@ def body(self): raise ValueError("Event content empty.") @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ return self._raw_amqp_message.body_type - def body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def body_as_str(self, encoding: str = "UTF-8") -> str: """The content of the event as a string, if the data is of a compatible type. :param encoding: The encoding to use for decoding event data. @@ -378,12 +387,9 @@ def body_as_str(self, encoding="UTF-8"): try: return cast(bytes, data).decode(encoding) except Exception as e: - raise TypeError( - "Message data is not compatible with string type: {}".format(e) - ) + raise TypeError(f"Message data is not compatible with string type: {e}") - def body_as_json(self, encoding="UTF-8"): - # type: (str) -> Dict[str, Any] + def body_as_json(self, encoding: str = "UTF-8") -> Dict[str, Any]: """The content of the event loaded as a JSON object, if the data is compatible. :param encoding: The encoding to use for decoding event data. @@ -394,11 +400,10 @@ def body_as_json(self, encoding="UTF-8"): try: return json.loads(data_str) except Exception as e: - raise TypeError("Event data is not compatible with JSON type: {}".format(e)) + raise TypeError(f"Event data is not compatible with JSON type: {e}") @property - def content_type(self): - # type: () -> Optional[str] + def content_type(self) -> Optional[str]: """The content type descriptor. Optionally describes the payload of the message, with a descriptor following the format of RFC2045, Section 5, for example "application/json". @@ -412,15 +417,13 @@ def content_type(self): return self._raw_amqp_message.properties.content_type @content_type.setter - def content_type(self, value): - # type: (str) -> None + def content_type(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.content_type = value @property - def correlation_id(self): - # type: () -> Optional[str] + def correlation_id(self) -> Optional[str]: """The correlation identifier. Allows an application to specify a context for the message for the purposes of correlation, for example reflecting the MessageId of a message that is being replied to. @@ -434,15 +437,13 @@ def correlation_id(self): return self._raw_amqp_message.properties.correlation_id @correlation_id.setter - def correlation_id(self, value): - # type: (str) -> None + def correlation_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.correlation_id = value @property - def message_id(self): - # type: () -> Optional[str] + def message_id(self) -> Optional[str]: """The id to identify the message. The message identifier is an application-defined value that uniquely identifies the message and its payload. The identifier is a free-form string and can reflect a GUID or an identifier derived from the @@ -458,7 +459,7 @@ def message_id(self): return self._raw_amqp_message.properties.message_id @message_id.setter - def message_id(self, value): + def message_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.message_id = value @@ -489,8 +490,14 @@ class EventDataBatch(object): Event Hub decided by the service. """ - def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None): - # type: (Optional[int], Optional[str], Optional[Union[str, bytes]]) -> None + def __init__( + self, + max_size_in_bytes: Optional[int] = None, + partition_id: Optional[str] = None, + partition_key: Optional[Union[str, bytes]] = None, + **kwargs, + ) -> None: + self._amqp_transport = kwargs.pop("amqp_transport") if partition_key and not isinstance( partition_key, (six.text_type, six.binary_type) @@ -502,33 +509,52 @@ def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None "partition_key to only be string type, they might fail to parse the non-string value." ) - self.max_size_in_bytes = max_size_in_bytes or constants.MAX_FRAME_SIZE_BYTES - self.message = BatchMessage(data=[]) + self.max_size_in_bytes = ( + max_size_in_bytes or self._amqp_transport.MAX_FRAME_SIZE_BYTES + ) + self._message = self._amqp_transport.BATCH_MESSAGE(data=[]) self._partition_id = partition_id self._partition_key = partition_key - self.message = set_message_partition_key(self.message, self._partition_key) - self._size = pyutils.get_message_encoded_size(self.message) - self._count = 0 - def __repr__(self): - # type: () -> str - batch_repr = "max_size_in_bytes={}, partition_id={}, partition_key={!r}, event_count={}".format( - self.max_size_in_bytes, self._partition_id, self._partition_key, self._count + self._message = self._amqp_transport.set_message_partition_key( + self._message, self._partition_key ) - return "EventDataBatch({})".format(batch_repr) + self._size = self._amqp_transport.get_batch_message_encoded_size(self._message) + self._count = 0 + self._internal_events: List[ + Union[EventData, AmqpAnnotatedMessage] + ] = [] # TODO: only used by uamqp + self._uamqp_message = None + + def __repr__(self) -> str: + batch_repr = ( + f"max_size_in_bytes={self.max_size_in_bytes}, partition_id={self._partition_id}, " + f"partition_key={self._partition_key!r}, event_count={self._count}" + ) + return f"EventDataBatch({batch_repr})" - def __len__(self): + def __len__(self) -> int: return self._count @classmethod - def _from_batch(cls, batch_data, partition_key=None): - # type: (Iterable[EventData], Optional[AnyStr]) -> EventDataBatch - outgoing_batch_data = [transform_outbound_single_message(m, EventData) for m in batch_data] - batch_data_instance = cls(partition_key=partition_key) - + def _from_batch( + cls, + batch_data: Iterable[EventData], + amqp_transport: AmqpTransport, + partition_key: Optional[AnyStr] = None, + ) -> EventDataBatch: + outgoing_batch_data = [ + transform_outbound_single_message( + m, EventData, amqp_transport.to_outgoing_amqp_message + ) + for m in batch_data + ] + batch_data_instance = cls( + partition_key=partition_key, amqp_transport=amqp_transport + ) + for event_data in outgoing_batch_data: batch_data_instance.add(event_data) - return batch_data_instance def _load_events(self, events): @@ -543,16 +569,21 @@ def _load_events(self, events): ) @property - def size_in_bytes(self): - # type: () -> int + def size_in_bytes(self) -> int: """The combined size of the events in the batch, in bytes. :rtype: int """ return self._size - def add(self, event_data): - # type: (Union[EventData, AmqpAnnotatedMessage]) -> None + @property + def message(self) -> LegacyBatchMessage: + if not self._uamqp_message: + message = AmqpAnnotatedMessage(message=Message(*self._message)) + self._uamqp_message = LegacyBatchMessage(message) + return self._uamqp_message + + def add(self, event_data: Union[EventData, AmqpAnnotatedMessage]) -> None: """Try to add an EventData to the batch. The total size of an added event is the sum of its body, properties, etc. @@ -565,7 +596,9 @@ def add(self, event_data): :raise: :class:`ValueError`, when exceeding the size limit. """ - outgoing_event_data = transform_outbound_single_message(event_data, EventData) + outgoing_event_data = transform_outbound_single_message( + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message + ) if self._partition_key: if ( @@ -576,12 +609,15 @@ def add(self, event_data): "The partition key of event_data does not match the partition key of this batch." ) if not outgoing_event_data.partition_key: - outgoing_event_data.message = set_message_partition_key( - outgoing_event_data.message, self._partition_key + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, # pylint: disable=protected-access + self._partition_key, ) trace_message(outgoing_event_data) - event_data_size = pyutils.get_message_encoded_size(outgoing_event_data.message) + event_data_size = self._amqp_transport.get_message_encoded_size( + outgoing_event_data._message # pylint: disable=protected-access + ) # For a BatchMessage, if the encoded_message_size of event_data is < 256, then the overhead cost to encode that # message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes. size_after_add = ( @@ -592,11 +628,9 @@ def add(self, event_data): if size_after_add > self.max_size_in_bytes: raise ValueError( - "EventDataBatch has reached its size limit: {}".format( - self.max_size_in_bytes - ) + f"EventDataBatch has reached its size limit: {self.max_size_in_bytes}" ) - pyutils.add_batch(self.message, outgoing_event_data.message) + self._amqp_transport.add_batch(self, outgoing_event_data, event_data) self._size = size_after_add self._count += 1 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index e9aaeb17e1a9..b8306d284398 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -6,7 +6,6 @@ from urllib.parse import urlparse from azure.core.pipeline.policies import RetryMode - from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT @@ -34,10 +33,14 @@ def __init__(self, **kwargs): self.connection_verify = kwargs.get("connection_verify") # type: Optional[str] self.connection_port = DEFAULT_AMQPS_PORT self.custom_endpoint_hostname = None + self.hostname = kwargs.pop("hostname") + uamqp_transport = kwargs.pop("uamqp_transport") if self.http_proxy or self.transport_type == TransportType.AmqpOverWebsocket: self.transport_type = TransportType.AmqpOverWebsocket self.connection_port = DEFAULT_AMQP_WSS_PORT + if not uamqp_transport: + self.hostname += "/$servicebus/websocket" # custom end point if self.custom_endpoint_address: @@ -48,5 +51,7 @@ def __init__(self, **kwargs): endpoint = urlparse(self.custom_endpoint_address) self.transport_type = TransportType.AmqpOverWebsocket self.custom_endpoint_hostname = endpoint.hostname + if not uamqp_transport: + self.custom_endpoint_address += "/$servicebus/websocket" # in case proxy and custom endpoint are both provided, we default port to 443 if it's not provided self.connection_port = endpoint.port or DEFAULT_AMQP_WSS_PORT diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 66dace638f3e..cdcfbdddd3f4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -3,24 +3,32 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from threading import Lock from enum import Enum -from ._pyamqp._connection import Connection, _CLOSING_STATES from ._constants import TransportType -if TYPE_CHECKING: - from uamqp.authentication import JWTTokenAuth +from ._pyamqp._connection import Connection, _CLOSING_STATES +if TYPE_CHECKING: try: from typing_extensions import Protocol except ImportError: Protocol = object # type: ignore + try: + from uamqp import Connection as uamqp_Connection + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + except ImportError: + uamqp_Connection = None + uamqp_JWTTokenAuth = None + from ._pyamqp.authentication import JWTTokenAuth + class ConnectionManager(Protocol): - def get_connection(self, host, auth): - # type: (str, 'JWTTokenAuth') -> Connection + def get_connection( + self, host: str, auth: Union["uamqp_JWTTokenAuth", "JWTTokenAuth"] + ) -> Union["uamqp_Connection", "Connection"]: pass def close_connection(self): @@ -35,10 +43,11 @@ class _ConnectionMode(Enum): SeparateConnection = 2 +# TODO: see if we want to use this, and if so, make compatible with uamqp class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes def __init__(self, **kwargs): self._lock = Lock() - self._conn = None # type: Connection + self._conn: Union["uamqp_Connection", "Connection"] = None self._container_id = kwargs.get("container_id") self._debug = kwargs.get("debug") @@ -79,6 +88,7 @@ def close_connection(self): self._conn.close() self._conn = None + # TODO: fix and add uamqp stuff def reset_connection_if_broken(self): # type: () -> None with self._lock: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 8790a22d2a69..62f9c48348f4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -3,23 +3,13 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- from __future__ import unicode_literals +from multiprocessing import Event import time import uuid import logging from collections import deque -from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Deque -from urllib.parse import urlparse - -from ._pyamqp import ( - ReceiveClient, - types, - utils as pyamqp_utils, - error, - constants as pyamqp_constants -) -from ._pyamqp.endpoints import Source, ApacheFilters -from ._pyamqp.message import Message +from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Deque, Union from ._common import EventData from ._client_base import ConsumerProducerMixin @@ -28,11 +18,20 @@ EPOCH_SYMBOL, TIMEOUT_SYMBOL, RECEIVER_RUNTIME_METRIC_SYMBOL, - NO_RETRY_ERRORS, - CUSTOM_CONDITION_BACKOFF, ) if TYPE_CHECKING: + from typing import Deque + try: + from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + except ImportError: + uamqp_ReceiveClient = None + uamqp_Message = None + uamqp_JWTTokenAuth = None + + from ._pyamqp import ReceiveClient + from ._pyamqp.message import Message from ._pyamqp.authentication import JWTTokenAuth from ._consumer_client import EventHubConsumerClient @@ -76,8 +75,7 @@ class EventHubConsumer( It is set to `False` by default. """ - def __init__(self, client, source, **kwargs): - # type: (EventHubConsumerClient, str, Any) -> None + def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) -> None: event_position = kwargs.get("event_position", None) prefetch = kwargs.get("prefetch", 300) owner_level = kwargs.get("owner_level", None) @@ -93,9 +91,10 @@ def __init__(self, client, source, **kwargs): self.stop = False # used by event processor self.handler_ready = False - self._on_event_received = kwargs[ + self._amqp_transport = kwargs.pop("amqp_transport") + self._on_event_received: Callable[[EventData], None] = kwargs[ "on_event_received" - ] # type: Callable[[EventData], None] + ] self._client = client self._source = source self._offset = event_position @@ -104,87 +103,62 @@ def __init__(self, client, source, **kwargs): self._owner_level = owner_level self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = error.RetryPolicy( - retry_total=self._client._config.max_retries, # pylint:disable=protected-access - retry_backoff_factor=self._client._config.backoff_factor, # pylint:disable=protected-access - retry_backoff_max=self._client._config.backoff_max, # pylint:disable=protected-access - retry_mode=self._client._config.retry_mode, # pylint:disable=protected-access - no_retry_condition=NO_RETRY_ERRORS, - custom_condition_backoff=CUSTOM_CONDITION_BACKOFF, - ) + self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 - self._link_properties = {} # type: Dict[types.AMQPType, types.AMQPType] + link_properties: Dict[bytes, int] = {} self._error = None self._timeout = 0 - self._idle_timeout = idle_timeout if idle_timeout else None + self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None partition = self._source.split("/")[-1] self._partition = partition - self._name = "EHConsumer-{}-partition{}".format(uuid.uuid4(), partition) + self._name = f"EHConsumer-{uuid.uuid4()}-partition{partition}" if owner_level is not None: - self._link_properties[EPOCH_SYMBOL] = pyamqp_utils.amqp_long_value(int(owner_level)) + link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access - ) * 1000 - self._link_properties[TIMEOUT_SYMBOL] = pyamqp_utils.amqp_long_value(int(link_property_timeout_ms)) - self._handler = None # type: Optional[ReceiveClient] + ) * self._amqp_transport.IDLE_TIMEOUT_FACTOR + link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) + self._link_properties = self._amqp_transport.create_link_properties(link_properties) + self._handler: Optional[Union["ReceiveClient", "uamqp_ReceiveClient"]] = None self._track_last_enqueued_event_properties = ( track_last_enqueued_event_properties ) - self._message_buffer = deque() # type: Deque[Message] - self._last_received_event = None # type: Optional[EventData] - self._receive_start_time = None # type: Optional[float] + self._message_buffer: Deque[Union["Message", "uamqp_Message"]] = deque() + self._last_received_event: Optional[EventData] = None + self._receive_start_time: Optional[float]= None def _create_handler(self, auth): # type: (JWTTokenAuth) -> None - source = Source(address=self._source, filters={}) - if self._offset is not None: - filter_key = ApacheFilters.selector_filter - source.filters[filter_key] = ( - filter_key, - pyamqp_utils.amqp_string_value( - event_position_selector( - self._offset, - self._offset_inclusive - ) - ) - ) + source = self._amqp_transport.create_source( + self._source, + self._offset, + event_position_selector(self._offset, self._offset_inclusive) + ) desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None - custom_endpoint_address = self._client._config.custom_endpoint_address # pylint:disable=protected-access - transport_type = self._client._config.transport_type # pylint:disable=protected-access - hostname = urlparse(source.address).hostname - if transport_type.name == 'AmqpOverWebsocket': - hostname += '/$servicebus/websocket/' - if custom_endpoint_address: - custom_endpoint_address += '/$servicebus/websocket/' - - self._handler = ReceiveClient( - hostname, - source, + self._handler = self._amqp_transport.create_receive_client( + config=self._client._config, # pylint:disable=protected-access + source=source, auth=auth, - idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access - transport_type=transport_type, - http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, + idle_timeout=self._idle_timeout, retry_policy=self._retry_policy, + keep_alive_interval=self._keep_alive, client_name=self._name, - receive_settle_mode=pyamqp_constants.ReceiverSettleMode.First, - properties=create_properties(self._client._config.user_agent), # pylint:disable=protected-access + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint:disable=protected-access + ), desired_capabilities=desired_capabilities, streaming_receive=True, message_received_callback=self._message_received, - custom_endpoint_address=custom_endpoint_address, - connection_verify=self._client._config.connection_verify, ) - def _open_with_retry(self): - # type: () -> None + def _open_with_retry(self) -> None: self._do_retryable_operation(self._open, operation_need_param=False) - def _message_received(self, message): - # type: (Message) -> None + def _message_received(self, message: Union["Message", "uamqp_Message"]) -> None: # pylint:disable=protected-access self._message_buffer.append(message) @@ -195,10 +169,8 @@ def _next_message_in_buffer(self): self._last_received_event = event_data return event_data - def _open(self): - # type: () -> bool + def _open(self) -> bool: """Open the EventHubConsumer/EventHubProducer using the supplied connection. - """ # pylint: disable=protected-access if not self.running: @@ -206,7 +178,7 @@ def _open(self): self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open() + self._handler.open() # TODO: uamqp handler is not using the passed in connection anyway while not self._handler.client_ready(): time.sleep(0.05) self.handler_ready = True @@ -228,12 +200,14 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None): while retried_times <= max_retries: try: if self._open(): + # TODO: for pyamqp, this will pass in batch. But, in the ReceiveClient._client_run, + # can remove (batch=self._link_credit)? self._handler.do_work(batch=self._prefetch) # type: ignore break except Exception as exception: # pylint: disable=broad-except if ( - isinstance(exception, error.AMQPLinkError) - and exception.condition == error.ErrorCondition.LinkStolen # pylint: disable=no-member + isinstance(exception, self._amqp_transport.AMQP_LINK_ERROR) + and exception.condition == self._amqp_transport.LINK_STOLEN_CONDITION # pylint: disable=no-member ): raise self._handle_exception(exception) if not self.running: # exit by close diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py index fa00d9aa5dd5..f8c4c3361a2f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py @@ -136,6 +136,7 @@ def __init__( **kwargs # type: Any ): # type: (...) -> None + self._checkpoint_store = kwargs.pop("checkpoint_store", None) self._load_balancing_interval = kwargs.pop("load_balancing_interval", None) if self._load_balancing_interval is None: @@ -200,6 +201,7 @@ def _create_consumer( prefetch=prefetch, idle_timeout=self._idle_timeout, track_last_enqueued_event_properties=track_last_enqueued_event_properties, + amqp_transport=self._amqp_transport, ) return handler diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 0ebc1f62a548..878995065d93 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -6,7 +6,6 @@ import uuid import logging -import time import threading from typing import ( Iterable, @@ -18,45 +17,49 @@ TYPE_CHECKING, ) # pylint: disable=unused-import -from azure.core.tracing import AbstractSpan - -from .exceptions import OperationTimeoutError from ._common import EventData, EventDataBatch from ._client_base import ConsumerProducerMixin from ._utils import ( create_properties, - set_message_partition_key, trace_message, send_context_manager, transform_outbound_single_message, ) from ._constants import ( TIMEOUT_SYMBOL, - NO_RETRY_ERRORS, - CUSTOM_CONDITION_BACKOFF -) -from ._pyamqp import ( - error, - utils as pyamqp_utils, - SendClient ) -_LOGGER = logging.getLogger(__name__) - if TYPE_CHECKING: - from uamqp.authentication import JWTTokenAuth # pylint: disable=ungrouped-imports + from azure.core.tracing import AbstractSpan + + try: + from uamqp import constants as uamqp_constants, SendClient as uamqp_SendClient + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + except ImportError: + uamqp_constants = None + uamqp_SendClient = None + uamqp_JWTTokenAuth = None + from ._pyamqp import SendClient + from ._pyamqp.authentication import JWTTokenAuth + from ._transport._base import AmqpTransport from ._producer_client import EventHubProducerClient +_LOGGER = logging.getLogger(__name__) + -def _set_partition_key(event_datas, partition_key): - # type: (Iterable[EventData], AnyStr) -> Iterable[EventData] +def _set_partition_key( + event_datas: Iterable[EventData], + partition_key: AnyStr, + amqp_transport: "AmqpTransport", +) -> Iterable[EventData]: for ed in iter(event_datas): - set_message_partition_key(ed.message, partition_key) + amqp_transport.set_message_partition_key(ed._message, partition_key) # pylint: disable=protected-access yield ed -def _set_trace_message(event_datas, parent_span=None): - # type: (Iterable[EventData], Optional[AbstractSpan]) -> Iterable[EventData] +def _set_trace_message( + event_datas: Iterable[EventData], parent_span: Optional["AbstractSpan"] = None +) -> Iterable[EventData]: for ed in iter(event_datas): trace_message(ed, parent_span) yield ed @@ -87,8 +90,11 @@ class EventHubProducer( Default value is `True`. """ - def __init__(self, client, target, **kwargs): - # type: (EventHubProducerClient, str, Any) -> None + def __init__( + self, client: "EventHubProducerClient", target: str, **kwargs: Any + ) -> None: + + self._amqp_transport = kwargs.pop("amqp_transport") partition = kwargs.get("partition", None) send_timeout = kwargs.get("send_timeout", 60) keep_alive = kwargs.get("keep_alive", None) @@ -103,79 +109,95 @@ def __init__(self, client, target, **kwargs): self._target = target self._partition = partition self._timeout = send_timeout - self._idle_timeout = idle_timeout if idle_timeout else None + self._idle_timeout = ( + (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) + if idle_timeout + else None + ) self._error = None self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = error.RetryPolicy( - retry_total=self._client._config.max_retries, # pylint: disable=protected-access - no_retry_condition=NO_RETRY_ERRORS, - custom_condition_backoff=CUSTOM_CONDITION_BACKOFF + self._retry_policy = self._amqp_transport.create_retry_policy( + config=self._client._config ) self._reconnect_backoff = 1 - self._name = "EHProducer-{}".format(uuid.uuid4()) - self._unsent_events = [] # type: List[Any] + self._name = f"EHProducer-{uuid.uuid4()}" + self._unsent_events: List[Any] = [] if partition: self._target += "/Partitions/" + partition - self._name += "-partition{}".format(partition) - self._handler = None # type: Optional[SendClient] - self._condition = None # type: Optional[Exception] + self._name += f"-partition{partition}" + self._handler: Optional[Union["SendClient", "uamqp_SendClient"]] = None + self._outcome: Optional["uamqp_constants.MessageSendResult"] = None + self._condition: Optional[Exception] = None self._lock = threading.Lock() - self._link_properties = {TIMEOUT_SYMBOL: pyamqp_utils.amqp_long_value(int(self._timeout * 1000))} + self._link_properties = self._amqp_transport.create_link_properties( + {TIMEOUT_SYMBOL: int(self._timeout * 1000)} + ) - def _create_handler(self, auth): - # type: (JWTTokenAuth) -> None - transport_type = self._client._config.transport_type # pylint:disable=protected-access - custom_endpoint_address = self._client._config.custom_endpoint_address # pylint: disable=protected-access - hostname = self._client._address.hostname # pylint: disable=protected-access - if transport_type.name == 'AmqpOverWebsocket': - hostname += '/$servicebus/websocket/' - if custom_endpoint_address: - custom_endpoint_address += '/$servicebus/websocket/' - self._handler = SendClient( - hostname, - self._target, + def _create_handler( + self, auth: Union["uamqp_JWTTokenAuth", "JWTTokenAuth"] + ) -> None: + self._handler = self._amqp_transport.create_send_client( + config=self._client._config, # pylint:disable=protected-access + target=self._target, auth=auth, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access idle_timeout=self._idle_timeout, - network_trace=self._client._config.network_tracing, # pylint:disable=protected-access - transport_type=transport_type, - http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, - properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access - custom_endpoint_address=custom_endpoint_address, - connection_verify=self._client._config.connection_verify + properties=create_properties( + self._client._config.user_agent, # pylint: disable=protected-access + amqp_transport=self._amqp_transport, + ), + msg_timeout=self._timeout * 1000, ) - def _open_with_retry(self): - # type: () -> None + def _open_with_retry(self) -> None: return self._do_retryable_operation(self._open, operation_need_param=False) - def _send_event_data(self, timeout_time=None): - # type: (Optional[float]) -> None + def _on_outcome( + self, + outcome: "uamqp_constants.MessageSendResult", + condition: Optional[Exception], + ) -> None: + """ + Called when the outcome is received for a delivery. + :param outcome: The outcome of the message delivery - success or failure. + :type outcome: ~uamqp.constants.MessageSendResult + :param condition: Detail information of the outcome. + """ + self._outcome = outcome + self._condition = condition + + def _send_event_data( + self, + timeout_time: Optional[float] = None, + last_exception: Optional[Exception] = None, + ) -> None: if self._unsent_events: - self._open() - timeout = timeout_time - time.time() if timeout_time else 0 - self._handler.send_message(self._unsent_events[0], timeout=timeout) - self._unsent_events = None + self._amqp_transport.send_messages( + self, timeout_time, last_exception, _LOGGER + ) - def _send_event_data_with_retry(self, timeout=None): - # type: (Optional[float]) -> None + def _send_event_data_with_retry(self, timeout: Optional[float] = None) -> None: return self._do_retryable_operation(self._send_event_data, timeout=timeout) - @staticmethod def _wrap_eventdata( - event_data, # type: Union[EventData, EventDataBatch, Iterable[EventData]] - span, # type: Optional[AbstractSpan] - partition_key, # type: Optional[AnyStr] - ): - # type: (...) -> Union[EventData, EventDataBatch] + self, + event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + span: Optional["AbstractSpan"], + partition_key: Optional[AnyStr], + ) -> Union[EventData, EventDataBatch]: if isinstance(event_data, EventData): - outgoing_event_data = transform_outbound_single_message(event_data, EventData) + outgoing_event_data = transform_outbound_single_message( + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message + ) if partition_key: - set_message_partition_key(outgoing_event_data.message, partition_key) + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, partition_key # pylint: disable=protected-access + ) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) else: @@ -183,29 +205,35 @@ def _wrap_eventdata( event_data, EventDataBatch ): # The partition_key in the param will be omitted. if ( - partition_key and partition_key != event_data._partition_key # pylint: disable=protected-access + partition_key + and partition_key + != event_data._partition_key # pylint: disable=protected-access ): raise ValueError( "The partition_key does not match the one of the EventDataBatch" ) - - for event in event_data.message.data: # pylint: disable=protected-access + for ( + event + ) in event_data._message.data: # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: if partition_key: - event_data = _set_partition_key(event_data, partition_key) + event_data = _set_partition_key( + event_data, partition_key, self._amqp_transport + ) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # type: ignore # pylint: disable=protected-access + wrapper_event_data = EventDataBatch._from_batch( # type: ignore # pylint: disable=protected-access + event_data, self._amqp_transport, partition_key=partition_key + ) return wrapper_event_data def send( self, - event_data, # type: Union[EventData, EventDataBatch, Iterable[EventData]] - partition_key=None, # type: Optional[AnyStr] - timeout=None, # type: Optional[float] - ): - # type:(...) -> None + event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + partition_key: Optional[AnyStr] = None, + timeout: Optional[float] = None, + ) -> None: """ Sends an event data and blocks until acknowledgement is received or operation times out. @@ -233,22 +261,17 @@ def send( with self._lock: with send_context_manager() as child: self._check_closed() - wrapper_event_data = self._wrap_eventdata(event_data, child, partition_key) + wrapper_event_data = self._wrap_eventdata( + event_data, child, partition_key + ) + self._unsent_events = [wrapper_event_data._message] # pylint: disable=protected-access if child: self._client._add_span_request_attributes( # pylint: disable=protected-access child ) + self._send_event_data_with_retry(timeout=timeout) - try: - self._open() - self._handler.send_message(wrapper_event_data.message, timeout=timeout) - except TimeoutError as exception: - raise OperationTimeoutError(message=str(exception), details=exception) - except Exception as exception: # pylint:disable=broad-except - raise self._handle_exception(exception) - - def close(self): - # type:() -> None + def close(self) -> None: """ Close down the handler. If the handler has already closed, this will be a no op. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index b3aea5088707..79c4ce5e0e8c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -122,7 +122,7 @@ def _get_partitions(self): for p_id in cast(List[str], self._partition_ids): self._producers[p_id] = None - def _get_max_mesage_size(self): + def _get_max_message_size(self): # type: () -> None # pylint: disable=protected-access,line-too-long with self._lock: @@ -131,9 +131,9 @@ def _get_max_mesage_size(self): EventHubProducer, self._producers[ALL_PARTITIONS] )._open_with_retry() self._max_message_size_on_link = ( - self._producers[ # type: ignore - ALL_PARTITIONS - ]._handler._link.remote_max_message_size + self._amqp_transport.get_remote_max_message_size( + self._producers[ALL_PARTITIONS]._handler # type: ignore + ) or MAX_MESSAGE_LENGTH_BYTES ) @@ -175,6 +175,7 @@ def _create_producer(self, partition_id=None, send_timeout=None): partition=partition_id, send_timeout=send_timeout, idle_timeout=self._idle_timeout, + amqp_transport=self._amqp_transport, ) return handler @@ -350,7 +351,7 @@ def create_batch(self, **kwargs): """ if not self._max_message_size_on_link: - self._get_max_mesage_size() + self._get_max_message_size() max_size_in_bytes = kwargs.get("max_size_in_bytes", None) partition_id = kwargs.get("partition_id", None) @@ -367,6 +368,7 @@ def create_batch(self, **kwargs): max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, partition_key=partition_key, + amqp_transport=self._amqp_transport, ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py new file mode 100644 index 000000000000..fd00604d282b --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py @@ -0,0 +1,240 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +# pylint: disable=too-many-lines +from typing import Callable +from enum import Enum + +from ._encode import encode_payload +from .utils import get_message_encoded_size +from .error import AMQPError +from .message import Header, Properties +#from uamqp import constants, errors + + +class MessageState(Enum): + WaitingToBeSent = 0 + WaitingForSendAck = 1 + SendComplete = 2 + SendFailed = 3 + ReceivedUnsettled = 4 + ReceivedSettled = 5 + + def __eq__(self, __o: object) -> bool: + try: + return self.value == __o.value + except AttributeError: + return super().__eq__(__o) + + +class MessageAlreadySettled(Exception): + pass + + +DONE_STATES = (MessageState.SendComplete, MessageState.SendFailed) +RECEIVE_STATES = (MessageState.ReceivedSettled, MessageState.ReceivedUnsettled) +PENDING_STATES = (MessageState.WaitingForSendAck, MessageState.WaitingToBeSent) + + +class LegacyMessage(object): + def __init__(self, message, **kwargs): + self._message = message + self.state = MessageState.SendComplete + self.idle_time = 0 + self.retries = 0 + self._settler = kwargs.get('settler') + self._encoding = kwargs.get('encoding') + self.delivery_no = kwargs.get('delivery_no') + self.delivery_tag = kwargs.get('delivery_tag') or None + self.on_send_complete = None + self.properties = LegacyMessageProperties(self._message.properties) if self._message.properties else None + self.application_properties = self._message.application_properties + self.annotations = self._message.annotations + self.header = LegacyMessageHeader(self._message.header) if self._message.header else None + self.footer = self._message.footer + self.delivery_annotations = self._message.delivery_annotations + if self._settler: + self.state = MessageState.ReceivedUnsettled + elif self.delivery_no: + self.state = MessageState.ReceivedSettled + self._to_outgoing_amqp_message: Callable = kwargs.get('to_outgoing_amqp_message') + + def __str__(self): + return str(self._message) + + def _can_settle_message(self): + if self.state not in RECEIVE_STATES: + raise TypeError("Only received messages can be settled.") + if self.settled: + return False + return True + + @property + def settled(self): + if self.state == MessageState.ReceivedUnsettled: + return False + return True + + def get_message_encoded_size(self): + return get_message_encoded_size(self._to_outgoing_amqp_message(self._message)) + + def encode_message(self): + output = bytearray() + encode_payload(output, self._to_outgoing_amqp_message(self._message)) + return bytes(output) + + def get_data(self): + return self._message.body + + def gather(self): + if self.state in RECEIVE_STATES: + raise TypeError("Only new messages can be gathered.") + if not self._message: + raise ValueError("Message data already consumed.") + if self.state in DONE_STATES: + raise MessageAlreadySettled() + return [self] + + def get_message(self): + return self._to_outgoing_amqp_message(self._message) + + def accept(self): + if self._can_settle_message(): + self._settler.settle_messages(self.delivery_no, 'accepted') + self.state = MessageState.ReceivedSettled + return True + return False + + def reject(self, condition=None, description=None, info=None): + if self._can_settle_message(): + self._settler.settle_messages( + self.delivery_no, + 'rejected', + error=AMQPError( + condition=condition, + description=description, + info=info + ) + ) + self.state = MessageState.ReceivedSettled + return True + return False + + def release(self): + if self._can_settle_message(): + self._settler.settle_messages(self.delivery_no, 'released') + self.state = MessageState.ReceivedSettled + return True + return False + + def modify(self, failed, deliverable, annotations=None): + if self._can_settle_message(): + self._settler.settle_messages( + self.delivery_no, + 'modified', + delivery_failed=failed, + undeliverable_here=deliverable, + message_annotations=annotations, + ) + self.state = MessageState.ReceivedSettled + return True + return False + + +class LegacyBatchMessage(LegacyMessage): + batch_format = 0x80013700 + max_message_length = 1024 * 1024 + size_offset = 0 + + +class LegacyMessageProperties(object): + + def __init__(self, properties): + self.message_id = self._encode_property(properties.message_id) + self.user_id = self._encode_property(properties.user_id) + self.to = self._encode_property(properties.to) + self.subject = self._encode_property(properties.subject) + self.reply_to = self._encode_property(properties.reply_to) + self.correlation_id = self._encode_property(properties.correlation_id) + self.content_type = self._encode_property(properties.content_type) + self.content_encoding = self._encode_property(properties.content_encoding) + self.absolute_expiry_time = properties.absolute_expiry_time + self.creation_time = properties.creation_time + self.group_id = self._encode_property(properties.group_id) + self.group_sequence = properties.group_sequence + self.reply_to_group_id = self._encode_property(properties.reply_to_group_id) + + def __str__(self): + return str( + { + "message_id": self.message_id, + "user_id": self.user_id, + "to": self.to, + "subject": self.subject, + "reply_to": self.reply_to, + "correlation_id": self.correlation_id, + "content_type": self.content_type, + "content_encoding": self.content_encoding, + "absolute_expiry_time": self.absolute_expiry_time, + "creation_time": self.creation_time, + "group_id": self.group_id, + "group_sequence": self.group_sequence, + "reply_to_group_id": self.reply_to_group_id, + } + ) + + def _encode_property(self, value): + try: + return value.encode("UTF-8") + except AttributeError: + return value + + def get_properties_obj(self): + return Properties( + self.message_id, + self.user_id, + self.to, + self.subject, + self.reply_to, + self.correlation_id, + self.content_type, + self.content_encoding, + self.absolute_expiry_time, + self.creation_time, + self.group_id, + self.group_sequence, + self.reply_to_group_id + ) + + +class LegacyMessageHeader(object): + + def __init__(self, header): + self.delivery_count = header.delivery_count # or 0 + self.time_to_live = header.time_to_live + self.first_acquirer = header.first_acquirer + self.durable = header.durable + self.priority = header.priority + + def __str__(self): + return str( + { + "delivery_count": self.delivery_count, + "time_to_live": self.time_to_live, + "first_acquirer": self.first_acquirer, + "durable": self.durable, + "priority": self.priority, + } + ) + + def get_header_obj(self): + return Header( + self.durable, + self.priority, + self.time_to_live, + self.first_acquirer, + self.delivery_count + ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py new file mode 100644 index 000000000000..324085403361 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -0,0 +1,213 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from abc import ABC, abstractmethod + +class AmqpTransport(ABC): + """ + Abstract class that defines a set of common methods needed by producer and consumer. + """ + # define constants + BATCH_MESSAGE = None + MAX_FRAME_SIZE_BYTES = None + IDLE_TIMEOUT_FACTOR = None + MESSAGE = None + + # define symbols + PRODUCT_SYMBOL = None + VERSION_SYMBOL = None + FRAMEWORK_SYMBOL = None + PLATFORM_SYMBOL = None + USER_AGENT_SYMBOL = None + PROP_PARTITION_KEY_AMQP_SYMBOL = None + + # errors + AMQP_LINK_ERROR = None + LINK_STOLEN_CONDITION = None + MGMT_AUTH_EXCEPTION = None + CONNECTION_ERROR = None + AMQP_CONNECTION_ERROR = None + + @abstractmethod + def to_outgoing_amqp_message(self, annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message or pyamqp.Message + """ + + @abstractmethod + def get_message_encoded_size(self, message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message or pyamqp.Message message: Message to get encoded size of. + :rtype: int + """ + + @abstractmethod + def get_remote_max_message_size(self, handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + + @abstractmethod + def create_retry_policy(self, config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + + @abstractmethod + def create_link_properties(self, link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + + @abstractmethod + def create_send_client(self, *, config, **kwargs): + """ + Creates and returns the send client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @abstractmethod + def send_messages(self, producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + + # TODO: delete after data property added to uamqp.BatchMessage + #@abstractmethod + #def get_batch_message_data(self, batch_message): + # """ + # Gets the data body of the BatchMessage. + # :param batch_message: BatchMessage to retrieve data body from. + # """ + + @abstractmethod + def set_message_partition_key(self, message, partition_key, **kwargs): + """Set the partition key as an annotation on a uamqp message. + + :param message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + + @abstractmethod + def add_batch(self, batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + + @abstractmethod + def create_source(self, source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + + @abstractmethod + def create_receive_client(self, *, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + @abstractmethod + def open_receive_client(self, *, handler, client, auth): + """ + Opens the receive client. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + """ + + @abstractmethod + def create_token_auth(self, auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @abstractmethod + def create_mgmt_client(self, address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + @abstractmethod + def get_updated_token(self, mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + + @abstractmethod + def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + + @abstractmethod + def get_error(self, error, message, *, condition=None): + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py new file mode 100644 index 000000000000..66dce6e383f0 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -0,0 +1,487 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import logging +import time +from typing import TYPE_CHECKING, Optional, Union, Any + +from .._pyamqp import ( + error as errors, + utils, + SendClient, + constants, + AMQPClient, + ReceiveClient, +) +from .._pyamqp.message import Message, BatchMessage, Header, Properties +from .._pyamqp.authentication import JWTTokenAuth +from .._pyamqp.endpoints import Source, ApacheFilters + +from .._constants import ( + NO_RETRY_ERRORS, + CUSTOM_CONDITION_BACKOFF, +) + +from ._base import AmqpTransport +from ..amqp._constants import AmqpMessageBodyType +from .._constants import ( + NO_RETRY_ERRORS, + PROP_PARTITION_KEY, +) + +from ..exceptions import ( + ConnectError, + EventDataSendError, + OperationTimeoutError, + EventHubError, + AuthenticationError, + ConnectionLostError, + EventDataSendError, +) + +_LOGGER = logging.getLogger(__name__) + + +class PyamqpTransport(AmqpTransport): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + + # define constants + BATCH_MESSAGE = BatchMessage + IDLE_TIMEOUT_FACTOR = 1 + MESSAGE = Message + MAX_FRAME_SIZE_BYTES = constants.MAX_FRAME_SIZE_BYTES + + # define symbols + PRODUCT_SYMBOL = "product" + VERSION_SYMBOL = "version" + FRAMEWORK_SYMBOL = "framework" + PLATFORM_SYMBOL = "platform" + USER_AGENT_SYMBOL = "user-agent" + PROP_PARTITION_KEY_AMQP_SYMBOL = PROP_PARTITION_KEY + + # define errors and conditions + AMQP_LINK_ERROR = errors.AMQPLinkError + LINK_STOLEN_CONDITION = errors.ErrorCondition.LinkStolen + AUTH_EXCEPTION = errors.AuthenticationException + CONNECTION_ERROR = errors.AMQPConnectionError + AMQP_CONNECTION_ERROR = errors.AMQPConnectionError + TIMEOUT_EXCEPTION = TimeoutError + + def to_outgoing_amqp_message(self, annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: pyamqp.Message + """ + message_header = None + if annotated_message.header and any(annotated_message.header.values()): + message_header = Header( + delivery_count=annotated_message.header.delivery_count, + ttl=annotated_message.header.time_to_live, + first_acquirer=annotated_message.header.first_acquirer, + durable=annotated_message.header.durable, + priority=annotated_message.header.priority, + ) + + message_properties = None + if annotated_message.properties and any(annotated_message.properties.values()): + message_properties = Properties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=int(annotated_message.properties.creation_time) + if annotated_message.properties.creation_time + else None, + absolute_expiry_time=int( + annotated_message.properties.absolute_expiry_time + ) + if annotated_message.properties.absolute_expiry_time + else None, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + ) + + message_dict = { + "header": message_header, + "properties": message_properties, + "application_properties": annotated_message.application_properties, + "message_annotations": annotated_message.annotations, + "delivery_annotations": annotated_message.delivery_annotations, + "data": annotated_message._data_body, # pylint: disable=protected-access + "sequence": annotated_message._sequence_body, # pylint: disable=protected-access + "value": annotated_message._value_body, # pylint: disable=protected-access + "footer": annotated_message.footer + } + + return Message(**message_dict) + + def get_batch_message_encoded_size(self, message): + """ + Gets the batch message encoded size given an underlying Message. + :param pyamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return utils.get_message_encoded_size(message) + + def get_message_encoded_size(self, message): + """ + Gets the message encoded size given an underlying Message. + :param pyamqp.Message: Message to get encoded size of. + :rtype: int + """ + return utils.get_message_encoded_size(message) + + def get_remote_max_message_size(self, handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + return handler._link.remote_max_message_size # pylint: disable=protected-access + + def create_retry_policy(self, config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + return errors.RetryPolicy( + retry_total=config.max_retries, # pylint:disable=protected-access + retry_backoff_factor=config.backoff_factor, # pylint:disable=protected-access + retry_backoff_max=config.backoff_max, # pylint:disable=protected-access + retry_mode=config.retry_mode, # pylint:disable=protected-access + no_retry_condition=NO_RETRY_ERRORS, + custom_condition_backoff=CUSTOM_CONDITION_BACKOFF, + ) + + def create_link_properties(self, link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + return { + symbol: utils.amqp_long_value(value) + for (symbol, value) in link_properties.items() + } + + def create_send_client(self, *, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + target = kwargs.pop("target") + # TODO: extra passed in to pyamqp, but not used. should be used? + msg_timeout = kwargs.pop( # pylint: disable=unused-variable + "msg_timeout" + ) # TODO: not used by pyamqp? + + return SendClient( + config.hostname, + target, + custom_endpoint_address=config.custom_endpoint_address, + connection_verify=config.connection_verify, + transport_type=config.transport_type, + http_proxy=config.http_proxy, + **kwargs, + ) + + def send_messages(self, producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + producer._open() + timeout = timeout_time - time.time() if timeout_time else 0 + producer._handler.send_message(producer._unsent_events[0], timeout=timeout) + producer._unsent_events = None + # TODO: figure out if we want to use below, and see if it affects error story + #try: + # producer._open() + # producer._handler.send_message( + # producer._unsent_events[0], timeout=timeout_time + # ) + #except self.TIMEOUT_EXCEPTION as exc: + # raise OperationTimeoutError(message=str(exc), details=exc) + #except Exception as exc: + # raise producer._handle_exception(exc) + + def set_message_partition_key( + self, message, partition_key, **kwargs + ): + # type: (Message, Optional[Union[bytes, str]], Any) -> Message + """Set the partition key as an annotation on a uamqp message. + :param Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: Message + """ + encoding = kwargs.pop("encoding", 'utf-8') + if partition_key: + annotations = message.message_annotations + if annotations is None: + annotations = {} + try: + partition_key = partition_key.decode(encoding) + except AttributeError: + pass + annotations[ + PROP_PARTITION_KEY + ] = partition_key # pylint:disable=protected-access + header = Header(durable=True) + return message._replace(message_annotations=annotations, header=header) + return message + + def add_batch(self, batch_message, outgoing_event_data, event_data): # pylint: disable=unused-argument + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + utils.add_batch(batch_message._message, outgoing_event_data._message) # pylint: disable=protected-access + + def create_source(self, source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + source = Source(address=source, filters={}) + if offset is not None: + filter_key = ApacheFilters.selector_filter + source.filters[filter_key] = ( + filter_key, + utils.amqp_string_value(selector) + ) + return source + + def create_receive_client(self, *, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + return ReceiveClient( + config.hostname, + source, + receive_settle_mode=constants.ReceiverSettleMode.First, # TODO: make more descriptive in pyamqp? + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_address=config.custom_endpoint_address, + connection_verify=config.connection_verify, + **kwargs, + ) + + def open_receive_client(self, *, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + # pylint:disable=protected-access + handler.open( + connection=client._conn_manager.get_connection( + client._address.hostname, auth + ) + ) + + def create_token_auth(self, auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, then pass 300 to refresh_window. + """ + # TODO: figure out why we're passing all these args to pyamqp JWTTokenAuth, which aren't being used + update_token = kwargs.pop("update_token") # pylint: disable=unused-variable + if update_token: + # update_token not actually needed by pyamqp + # just using to detect wh + return JWTTokenAuth( + auth_uri, + auth_uri, + get_token + ) + return JWTTokenAuth( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + ) + #if update_token: + # token_auth.update_token() # TODO: why don't we need to update in pyamqp? + + def create_mgmt_client(self, address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + return AMQPClient( + config.hostname, + auth=mgmt_auth, + network_trace=config.network_tracing, + transport_type=config.transport_type, + http_proxy=config.http_proxy, + custom_endpoint_address=config.custom_endpoint_address, + connection_verify=config.connection_verify + ) + + def get_updated_token(self, mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.get_token() + + def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return mgmt_client.mgmt_request( + mgmt_msg, operation=operation.decode(), operation_type=operation_type.decode(), **kwargs + ) + + def get_error(self, error, message, *, condition=None): + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ + return error(condition, message) + + def _create_eventhub_exception(self, exception): + if isinstance(exception, errors.AuthenticationException): + error = AuthenticationError(str(exception), exception) + elif isinstance(exception, errors.AMQPLinkError): + error = ConnectError(str(exception), exception) + # TODO: do we need MessageHanlderError in amqp any more + # if connection/session/link error are enough? + # elif isinstance(exception, errors.MessageHandlerError): + # error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.AMQPConnectionError): + error = ConnectError(str(exception), exception) + elif isinstance(exception, TimeoutError): + error = ConnectionLostError(str(exception), exception) + else: + error = EventHubError(str(exception), exception) + return error + + + def _handle_exception( + self, exception, closable + ): # pylint:disable=too-many-branches, too-many-statements + try: # closable is a producer/consumer object + name = closable._name # pylint: disable=protected-access + except AttributeError: # closable is an client object + name = closable._container_id # pylint: disable=protected-access + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + closable._close_connection() # pylint:disable=protected-access + raise exception + elif isinstance(exception, EventHubError): + closable._close_handler() # pylint:disable=protected-access + raise exception + # TODO: The following errors seem to be useless in EH + # elif isinstance( + # exception, + # ( + # errors.MessageAccepted, + # errors.MessageAlreadySettled, + # errors.MessageModified, + # errors.MessageRejected, + # errors.MessageReleased, + # errors.MessageContentTooLarge, + # ), + # ): + # _LOGGER.info("%r Event data error (%r)", name, exception) + # error = EventDataError(str(exception), exception) + # raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + if isinstance(exception, errors.AuthenticationException): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.AMQPLinkError): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + elif isinstance(exception, errors.AMQPConnectionError): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + # TODO: add MessageHandlerError in amqp? + # elif isinstance(exception, errors.MessageHandlerError): + # if hasattr(closable, "_close_handler"): + # closable._close_handler() # pylint:disable=protected-access + else: # errors.AMQPConnectionError, compat.TimeoutException + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + return self._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py new file mode 100644 index 000000000000..8dd8120beef5 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -0,0 +1,539 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import time +import logging +from typing import TYPE_CHECKING, Optional, Union, Any + +try: + from uamqp import ( + BatchMessage, + constants, + MessageBodyType, + Message, + types, + SendClient, + ReceiveClient, + Source, + utils, + authentication, + AMQPClient, + compat, + errors, + ) + from uamqp.message import ( + MessageHeader, + MessageProperties, + ) + uamqp_installed = True +except ImportError: + uamqp_installed = False + +from ._base import AmqpTransport +from ..amqp._constants import AmqpMessageBodyType +from .._constants import ( + NO_RETRY_ERRORS, + PROP_PARTITION_KEY, +) + +from ..exceptions import ( + ConnectError, + EventDataError, + EventDataSendError, + OperationTimeoutError, + EventHubError, + AuthenticationError, + ConnectionLostError, + EventDataError, + EventDataSendError, +) + +_LOGGER = logging.getLogger(__name__) + +if uamqp_installed: + def _error_handler(error): + """ + Called internally when an event has failed to send so we + can parse the error to determine whether we should attempt + to retry sending the event again. + Returns the action to take according to error type. + + :param error: The error received in the send attempt. + :type error: Exception + :rtype: ~uamqp.errors.ErrorAction + """ + if error.condition == b"com.microsoft:server-busy": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition == b"com.microsoft:timeout": + return errors.ErrorAction(retry=True, backoff=2) + if error.condition == b"com.microsoft:operation-cancelled": + return errors.ErrorAction(retry=True) + if error.condition == b"com.microsoft:container-close": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition in NO_RETRY_ERRORS: + return errors.ErrorAction(retry=False) + return errors.ErrorAction(retry=True) + + + class UamqpTransport(AmqpTransport): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + # define constants + BATCH_MESSAGE = BatchMessage + MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES + IDLE_TIMEOUT_FACTOR = 1000 + MESSAGE = Message + + # define symbols + PRODUCT_SYMBOL = types.AMQPSymbol("product") + VERSION_SYMBOL = types.AMQPSymbol("version") + FRAMEWORK_SYMBOL = types.AMQPSymbol("framework") + PLATFORM_SYMBOL = types.AMQPSymbol("platform") + USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") + PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) + + # define errors and conditions + AMQP_LINK_ERROR = errors.LinkDetach + LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen + AUTH_EXCEPTION = errors.AuthenticationException + CONNECTION_ERROR = ConnectError + AMQP_CONNECTION_ERROR = errors.AMQPConnectionError + TIMEOUT_EXCEPTION = compat.TimeoutException + + def to_outgoing_amqp_message(self, annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message + """ + message_header = None + if annotated_message.header: + message_header = MessageHeader() + message_header.delivery_count = annotated_message.header.delivery_count + message_header.time_to_live = annotated_message.header.time_to_live + message_header.first_acquirer = annotated_message.header.first_acquirer + message_header.durable = annotated_message.header.durable + message_header.priority = annotated_message.header.priority + + message_properties = None + if annotated_message.properties: + message_properties = MessageProperties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=int(annotated_message.properties.creation_time) + if annotated_message.properties.creation_time else None, + absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) + if annotated_message.properties.absolute_expiry_time else None, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + encoding=annotated_message._encoding # pylint: disable=protected-access + ) + + # pylint: disable=protected-access + amqp_body_type = annotated_message.body_type + if amqp_body_type == AmqpMessageBodyType.DATA: + amqp_body_type = MessageBodyType.Data + amqp_body = list(annotated_message._data_body) + elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: + amqp_body_type = MessageBodyType.Sequence + amqp_body = list(annotated_message._sequence_body) + else: + amqp_body_type = MessageBodyType.Value + amqp_body = annotated_message._value_body + + return Message( + body=amqp_body, + body_type=amqp_body_type, + header=message_header, + properties=message_properties, + application_properties=annotated_message.application_properties, + annotations=annotated_message.annotations, + delivery_annotations=annotated_message.delivery_annotations, + footer=annotated_message.footer + ) + + def get_batch_message_encoded_size(self, message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return message.gather()[0].get_message_encoded_size() + + def get_message_encoded_size(self, message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message message: Message to get encoded size of. + :rtype: int + """ + return message.get_message_encoded_size() + + def get_remote_max_message_size(self, handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + return handler.message_handler._link.peer_max_message_size # pylint:disable=protected-access + + def create_retry_policy(self, config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + return errors.ErrorPolicy(max_retries=config.max_retries, on_error=_error_handler) + + def create_link_properties(self, link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} + + def create_send_client(self, *, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClient( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) + + def _set_msg_timeout(self, producer, timeout_time, last_exception, logger): + if not timeout_time: + return + remaining_time = timeout_time - time.time() + if remaining_time <= 0.0: + if last_exception: + error = last_exception + else: + error = OperationTimeoutError("Send operation timed out") + logger.info("%r send operation timed out. (%r)", producer._name, error) # pylint: disable=protected-access + raise error + producer._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access + + def send_messages(self, producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + self._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + producer._handler.wait() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + # TODO: can delete this method, if data prop is added to uamqp.BatchMessage + #def get_batch_message_data(self, batch_message): + # """ + # Gets the data body of the BatchMessage. + # :param batch_message: BatchMessage to retrieve data body from. + # """ + # return batch_message._body_gen # pylint:disable=protected-access + + def set_message_partition_key(self, message, partition_key, **kwargs): # pylint:disable=unused-argument + # type: (Message, Optional[Union[bytes, str]], Any) -> Message + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: Message + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = {} + annotations[ + UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid + ] = partition_key + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + return message + + def add_batch(self, batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + batch_message._internal_events.append(event_data) + batch_message._message._body_gen.append( # pylint: disable=protected-access + outgoing_event_data._message + ) + + def create_source(self, source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + source = Source(source) + if offset is not None: + source.set_filter(selector) + return source + + def create_receive_client(self, *, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClient( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + def open_receive_client(self, *, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + # pylint:disable=protected-access + handler.open(connection=client._conn_manager.get_connection( + client._address.hostname, auth + )) + + def create_token_auth(self, auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAuth( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + token_auth.update_token() # TODO: why don't we need to update in pyamqp? + return token_auth + + def create_mgmt_client(self, address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClient( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) + + def get_updated_token(self, mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return mgmt_client.mgmt_request( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) + + def get_error(self, error, message, *, condition=None): # pylint: disable=unused-argument + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ + return error(message) + + def _create_eventhub_exception(self, exception): + if isinstance(exception, errors.AuthenticationException): + error = AuthenticationError(str(exception), exception) + elif isinstance(exception, errors.VendorLinkDetach): + error = ConnectError(str(exception), exception) + elif isinstance(exception, errors.LinkDetach): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.ConnectionClose): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.MessageHandlerError): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.AMQPConnectionError): + error_type = ( + AuthenticationError + if str(exception).startswith("Unable to open authentication session") + else ConnectError + ) + error = error_type(str(exception), exception) + elif isinstance(exception, compat.TimeoutException): + error = ConnectionLostError(str(exception), exception) + else: + error = EventHubError(str(exception), exception) + return error + + + def _handle_exception( + self, exception, closable + ): # pylint:disable=too-many-branches, too-many-statements + try: # closable is a producer/consumer object + name = closable._name # pylint: disable=protected-access + except AttributeError: # closable is an client object + name = closable._container_id # pylint: disable=protected-access + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + closable._close_connection() # pylint:disable=protected-access + raise exception + elif isinstance(exception, EventHubError): + closable._close_handler() # pylint:disable=protected-access + raise exception + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + if isinstance(exception, errors.AuthenticationException): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.LinkDetach): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + elif isinstance(exception, errors.ConnectionClose): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.MessageHandlerError): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + else: # errors.AMQPConnectionError, compat.TimeoutException + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + return self._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index e9c000e9d5e3..024e74042b29 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -10,19 +10,16 @@ import datetime import calendar import logging -from typing import TYPE_CHECKING, Type, Optional, Dict, Union, Any, Iterable, Tuple, Mapping +from typing import TYPE_CHECKING, Type, Optional, Dict, Union, Any, Iterable, Tuple, Mapping, Callable import six -from ._pyamqp.message import Header - from azure.core.settings import settings from azure.core.tracing import SpanKind, Link from .amqp import AmqpAnnotatedMessage from ._version import VERSION from ._constants import ( - PROP_PARTITION_KEY, MAX_USER_AGENT_LENGTH, USER_AGENT_PREFIX, PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, @@ -30,11 +27,26 @@ PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, PROP_LAST_ENQUEUED_OFFSET, PROP_TIMESTAMP, + PROP_PARTITION_KEY ) +try: + from uamqp import types + from uamqp.message import MessageHeader + PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) +except (ImportError, ModuleNotFoundError): + types = None + MessageHeader = None + PROP_PARTITION_KEY_AMQP_SYMBOL = None + + if TYPE_CHECKING: # pylint: disable=ungrouped-imports - from ._pyamqp.message import Message + from ._transport._base import AmqpTransport + try: + from uamqp import uamqp_types + except ImportError: + uamqp_types = None from azure.core.tracing import AbstractSpan from azure.core.credentials import AzureSasCredential from ._common import EventData @@ -76,8 +88,9 @@ def utc_from_timestamp(timestamp): return datetime.datetime.fromtimestamp(timestamp, tz=TZ_UTC) -def create_properties(user_agent=None): - # type: (Optional[str]) -> Dict[types.AMQPSymbol, str] +def create_properties( + user_agent: Optional[str] = None, *, amqp_transport: "AmqpTransport" +) -> Dict[Union["uamqp_types.AMQPSymbol", str], str]: """ Format the properties with which to instantiate the connection. This acts like a user agent over HTTP. @@ -85,66 +98,55 @@ def create_properties(user_agent=None): :rtype: dict """ properties = {} - properties["product"] = USER_AGENT_PREFIX - properties["version"] = VERSION - framework = "Python/{}.{}.{}".format( - sys.version_info[0], sys.version_info[1], sys.version_info[2] - ) - properties["framework"] = framework + properties[amqp_transport.PRODUCT_SYMBOL] = USER_AGENT_PREFIX + properties[amqp_transport.VERSION_SYMBOL] = VERSION + framework = f"Python/{sys.version_info[0]}.{sys.version_info[1]}.{sys.version_info[2]}" + properties[amqp_transport.FRAMEWORK_SYMBOL] = framework platform_str = platform.platform() - properties["platform"] = platform_str + properties[amqp_transport.PLATFORM_SYMBOL] = platform_str - final_user_agent = "{}/{} {} ({})".format( - USER_AGENT_PREFIX, VERSION, framework, platform_str - ) + final_user_agent = f"{USER_AGENT_PREFIX}/{VERSION} {framework} ({platform_str})" if user_agent: - final_user_agent = "{} {}".format(user_agent, final_user_agent) + final_user_agent = f"{user_agent} {final_user_agent}" if len(final_user_agent) > MAX_USER_AGENT_LENGTH: raise ValueError( - "The user-agent string cannot be more than {} in length." - "Current user_agent string is: {} with length: {}".format( - MAX_USER_AGENT_LENGTH, final_user_agent, len(final_user_agent) - ) + f"The user-agent string cannot be more than {MAX_USER_AGENT_LENGTH} in length." + f"Current user_agent string is: {final_user_agent} with length: {len(final_user_agent)}" ) - properties["user-agent"] = final_user_agent + properties[amqp_transport.USER_AGENT_SYMBOL] = final_user_agent return properties -def set_message_partition_key(message, partition_key, **kwargs): - # type: (Message, Optional[Union[bytes, str]]) -> Message - """Set the partition key as an annotation on a uamqp message. +@contextmanager +def send_context_manager(): + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + if span_impl_type is not None: + with span_impl_type(name="Azure.EventHubs.send", kind=SpanKind.CLIENT) as child: + yield child + else: + yield None + +# TODO: delete after async unit tests have been refactored +def set_message_partition_key(message, partition_key): + # type: (Message, Optional[Union[bytes, str]]) -> None + """Set the partition key as an annotation on a uamqp message. :param ~uamqp.Message message: The message to update. :param str partition_key: The partition key value. :rtype: None """ - encoding = kwargs.pop("encoding", 'utf-8') if partition_key: - annotations = message.message_annotations + annotations = message.annotations if annotations is None: annotations = dict() - try: - partition_key = partition_key.decode(encoding) - except AttributeError: - pass annotations[ - PROP_PARTITION_KEY + PROP_PARTITION_KEY_AMQP_SYMBOL ] = partition_key # pylint:disable=protected-access - header = Header(durable=True) - return message._replace(message_annotations=annotations, header=header) - return message - - -@contextmanager -def send_context_manager(): - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - - if span_impl_type is not None: - with span_impl_type(name="Azure.EventHubs.send", kind=SpanKind.CLIENT) as child: - yield child - else: - yield None + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header def trace_message(event, parent_span=None): @@ -209,15 +211,13 @@ def event_position_selector(value, inclusive=False): value.microsecond / 1000 ) return ( - "amqp.annotation.x-opt-enqueued-time {} '{}'".format( - operator, int(timestamp) - ) + f"amqp.annotation.x-opt-enqueued-time {operator} '{int(timestamp)}'" ).encode("utf-8") elif isinstance(value, six.integer_types): return ( - "amqp.annotation.x-opt-sequence-number {} '{}'".format(operator, value) + f"amqp.annotation.x-opt-sequence-number {operator} '{value}'" ).encode("utf-8") - return ("amqp.annotation.x-opt-offset {} '{}'".format(operator, value)).encode( + return (f"amqp.annotation.x-opt-offset {operator} '{value}'").encode( "utf-8" ) @@ -232,23 +232,23 @@ def get_last_enqueued_event_properties(event_data): if event_data._last_enqueued_event_properties: return event_data._last_enqueued_event_properties - if event_data.message.delivery_annotations: - sequence_number = event_data.message.delivery_annotations.get( + if event_data._message.delivery_annotations: + sequence_number = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, None ) - enqueued_time_stamp = event_data.message.delivery_annotations.get( + enqueued_time_stamp = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_TIME_UTC, None ) if enqueued_time_stamp: enqueued_time_stamp = utc_from_timestamp(float(enqueued_time_stamp) / 1000) - retrieval_time_stamp = event_data.message.delivery_annotations.get( + retrieval_time_stamp = event_data._message.delivery_annotations.get( PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, None ) if retrieval_time_stamp: retrieval_time_stamp = utc_from_timestamp( float(retrieval_time_stamp) / 1000 ) - offset_bytes = event_data.message.delivery_annotations.get( + offset_bytes = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_OFFSET, None ) offset = offset_bytes.decode("UTF-8") if offset_bytes else None @@ -274,8 +274,8 @@ def parse_sas_credential(credential): return (sas, expiry) -def transform_outbound_single_message(message, message_type): - # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData]) -> EventData +def transform_outbound_single_message(message, message_type, to_outgoing_amqp_message): + # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData], Callable) -> EventData """ This method serves multiple goals: 1. update the internal message to reflect any updates to settable properties on EventData @@ -287,17 +287,19 @@ def transform_outbound_single_message(message, message_type): :rtype: EventData """ try: - # EventData # pylint: disable=protected-access - return message._to_outgoing_message() # type: ignore + # EventData.message stores uamqp/pyamqp.Message during sending + # pylint: disable=protected-access + message._message = to_outgoing_amqp_message(message.raw_amqp_message) + return message # type: ignore except AttributeError: - # AmqpAnnotatedMessage # pylint: disable=protected-access + # AmqpAnnotatedMessage is converted to uamqp/pyamqp.Message during sending + amqp_message = to_outgoing_amqp_message(message) return message_type._from_message( - message=message._to_outgoing_amqp_message(), raw_amqp_message=message # type: ignore + message=amqp_message, raw_amqp_message=message # type: ignore ) - def decode_with_recurse(data, encoding="UTF-8"): # type: (Any, str) -> Any # pylint:disable=isinstance-second-argument-not-valid-type diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index ee421c3ca4f2..9868d114b79f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -121,7 +121,7 @@ async def _get_partitions(self) -> None: for p_id in cast(List[str], self._partition_ids): self._producers[p_id] = None - async def _get_max_mesage_size(self) -> None: + async def _get_max_message_size(self) -> None: # pylint: disable=protected-access,line-too-long async with self._lock: if not self._max_message_size_on_link: @@ -378,7 +378,7 @@ async def create_batch( """ if not self._max_message_size_on_link: - await self._get_max_mesage_size() + await self._get_max_message_size() if max_size_in_bytes and max_size_in_bytes > self._max_message_size_on_link: raise ValueError( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 21b56b8ba91c..79eaa16a603e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -4,75 +4,12 @@ # license information. # ------------------------------------------------------------------------- -from typing import Optional, Any, cast, Mapping, Dict +from __future__ import annotations +from typing import Optional, Any, cast, Mapping, Dict, Union +from ._amqp_utils import normalized_data_body, normalized_sequence_body from ._constants import AmqpMessageBodyType -from .._pyamqp.message import Message, Header, Properties -from .._pyamqp import utils as pyamqp_utils - - -class DictMixin(object): - def __setitem__(self, key, item): - # type: (Any, Any) -> None - self.__dict__[key] = item - - def __getitem__(self, key): - # type: (Any) -> Any - return self.__dict__[key] - - def __repr__(self): - # type: () -> str - return str(self) - - def __len__(self): - # type: () -> int - return len(self.keys()) - - def __delitem__(self, key): - # type: (Any) -> None - self.__dict__[key] = None - - def __eq__(self, other): - # type: (Any) -> bool - """Compare objects by comparing all attributes.""" - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - # type: (Any) -> bool - """Compare objects by comparing all attributes.""" - return not self.__eq__(other) - - def __str__(self): - # type: () -> str - return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) - - def has_key(self, k): - # type: (Any) -> bool - return k in self.__dict__ - - def update(self, *args, **kwargs): - # type: (Any, Any) -> None - return self.__dict__.update(*args, **kwargs) - - def keys(self): - # type: () -> list - return [k for k in self.__dict__ if not k.startswith("_")] - - def values(self): - # type: () -> list - return [v for k, v in self.__dict__.items() if not k.startswith("_")] - - def items(self): - # type: () -> list - return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] - - def get(self, key, default=None): - # type: (Any, Optional[Any]) -> Any - if key in self.__dict__: - return self.__dict__[key] - return default +from .._mixin import DictMixin class AmqpAnnotatedMessage(object): @@ -108,12 +45,15 @@ class AmqpAnnotatedMessage(object): def __init__(self, **kwargs): # type: (Any) -> None - self._message = kwargs.pop("message", None) self._encoding = kwargs.pop("encoding", "UTF-8") + self._data_body = None + self._sequence_body = None + self._value_body = None # internal usage only for Event Hub received message - if self._message: - self._from_amqp_message(self._message) + message = kwargs.pop("message", None) + if message: + self._from_amqp_message(message) return # manually constructed AMQPAnnotatedMessage @@ -124,22 +64,17 @@ def __init__(self, **kwargs): "or value_body being set as the body of the AmqpAnnotatedMessage." ) - self._body = None self._body_type = None if "data_body" in kwargs: - self._body = pyamqp_utils.normalized_data_body(kwargs.get("data_body")) - self._message = Message(data=self._body) + self._data_body = normalized_data_body(kwargs.get("data_body")) self._body_type = AmqpMessageBodyType.DATA elif "sequence_body" in kwargs: - self._body = pyamqp_utils.normalized_sequence_body(kwargs.get("sequence_body")) + self._sequence_body = normalized_sequence_body(kwargs.get("sequence_body")) self._body_type = AmqpMessageBodyType.SEQUENCE - self._message = Message(sequence=self._body) elif "value_body" in kwargs: - self._body = kwargs.get("value_body") + self._value_body = kwargs.get("value_body") self._body_type = AmqpMessageBodyType.VALUE - self._message = Message(value=self._body) - #self._message = uamqp.message.Message(body=self._body, body_type=self._body_type) header_dict = cast(Mapping, kwargs.get("header")) self._header = AmqpMessageHeader(**header_dict) if "header" in kwargs else None self._footer = kwargs.get("footer") @@ -149,34 +84,16 @@ def __init__(self, **kwargs): self._annotations = kwargs.get("annotations") self._delivery_annotations = kwargs.get("delivery_annotations") - def __str__(self): + def __str__(self) -> str: if self._body_type == AmqpMessageBodyType.DATA: - output_str = "" - for data_section in self.body: - try: - output_str += data_section.decode(self._encoding) - except AttributeError: - output_str += str(data_section) - return output_str + return "".join(d.decode(self._encoding) for d in self._data_body) elif self._body_type == AmqpMessageBodyType.SEQUENCE: - output_str = "" - for sequence_section in self.body: - for d in sequence_section: - try: - output_str += d.decode(self._encoding) - except AttributeError: - output_str += str(d) - return output_str - else: - if not self.body: - return "" - try: - return self.body.decode(self._encoding) - except AttributeError: - return str(self.body) - - def __repr__(self): - # type: () -> str + return str(self._sequence_body) + elif self._body_type == AmqpMessageBodyType.VALUE: + return str(self._value_body) + return "" + + def __repr__(self) -> str: # pylint: disable=bare-except message_repr = "body={}".format( str(self) @@ -209,8 +126,8 @@ def __repr__(self): return "AmqpAnnotatedMessage({})".format(message_repr)[:1024] def _from_amqp_message(self, message): - # populate the properties from an uamqp message - # TODO: message.properties should not be a list + # populate the properties from a amqp transport message + # TODO: pyamqp message.properties should not be a list self._properties = AmqpMessageProperties( message_id=message.properties.message_id, user_id=message.properties.user_id, @@ -228,7 +145,7 @@ def _from_amqp_message(self, message): ) if message.properties else None self._header = AmqpMessageHeader( delivery_count=message.header.delivery_count, - time_to_live=message.header.time_to_live, + time_to_live=message.header.ttl, first_acquirer=message.header.first_acquirer, durable=message.header.durable, priority=message.header.priority @@ -237,57 +154,18 @@ def _from_amqp_message(self, message): self._annotations = message.message_annotations if message.message_annotations else {} self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} self._application_properties = message.application_properties if message.application_properties else {} - - def _to_outgoing_amqp_message(self): - message_header = None - if self.header and any(self.header.values()): - message_header = Header( - delivery_count=self.header.delivery_count, - ttl=self.header.time_to_live, - first_acquirer=self.header.first_acquirer, - durable=self.header.durable, - priority=self.header.priority - ) - - message_properties = None - if self.properties and any(self.properties.values()): - message_properties = Properties( - message_id=self.properties.message_id, - user_id=self.properties.user_id, - to=self.properties.to, - subject=self.properties.subject, - reply_to=self.properties.reply_to, - correlation_id=self.properties.correlation_id, - content_type=self.properties.content_type, - content_encoding=self.properties.content_encoding, - creation_time=int(self.properties.creation_time) if self.properties.creation_time else None, - absolute_expiry_time=int(self.properties.absolute_expiry_time) - if self.properties.absolute_expiry_time else None, - group_id=self.properties.group_id, - group_sequence=self.properties.group_sequence, - reply_to_group_id=self.properties.reply_to_group_id - ) - - dict = { - "header": message_header, - "properties": message_properties, - "application_properties": self.application_properties, - "message_annotations": self.annotations, - "delivery_annotations": self.delivery_annotations, - "footer": self.footer - } - - if self.body_type == AmqpMessageBodyType.DATA: - dict["data"] = self._body - elif self.body_type == AmqpMessageBodyType.SEQUENCE: - dict["sequence"] = self._body + if message.data: + self._data_body = list(message.data) + self._body_type = AmqpMessageBodyType.DATA + elif message.sequence: + self._sequence_body = list(message.sequence) + self._body_type = AmqpMessageBodyType.SEQUENCE else: - dict["value"] = self._body - - return Message(**dict) + self._value_body = message.value + self._body_type = AmqpMessageBodyType.VALUE @property - def body(self): + def body(self) -> Any: # type: () -> Any """The body of the Message. The format may vary depending on the body type: For ~azure.eventhub.AmqpMessageBodyType.DATA, the body could be bytes or Iterable[bytes] @@ -295,24 +173,23 @@ def body(self): For ~azure.eventhub.AmqpMessageBodyType.VALUE, the body could be any type. :rtype: Any """ - return self._message.data or self._message.sequence or self._message.value + if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return + return (i for i in self._data_body) + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + return (i for i in self._sequence_body) + elif self._body_type == AmqpMessageBodyType.VALUE: + return self._value_body + return None @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ - if self._message.data: - return AmqpMessageBodyType.DATA - elif self._message.sequence: - return AmqpMessageBodyType.SEQUENCE - else: - return AmqpMessageBodyType.VALUE + return self._body_type @property - def properties(self): - # type: () -> Optional[AmqpMessageProperties] + def properties(self) -> Optional[AmqpMessageProperties]: """ Properties to add to the message. :rtype: Optional[~azure.eventhub.amqp.AmqpMessageProperties] @@ -320,13 +197,11 @@ def properties(self): return self._properties @properties.setter - def properties(self, value): - # type: (AmqpMessageProperties) -> None + def properties(self, value: AmqpMessageProperties) -> None: self._properties = value @property - def application_properties(self): - # type: () -> Optional[Dict] + def application_properties(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific application properties. @@ -335,13 +210,11 @@ def application_properties(self): return self._application_properties @application_properties.setter - def application_properties(self, value): - # type: (Dict) -> None + def application_properties(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._application_properties = value @property - def annotations(self): - # type: () -> Optional[Dict] + def annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific message annotations. @@ -350,13 +223,11 @@ def annotations(self): return self._annotations @annotations.setter - def annotations(self, value): - # type: (Dict) -> None + def annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._annotations = value @property - def delivery_annotations(self): - # type: () -> Optional[Dict] + def delivery_annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Delivery-specific non-standard properties at the head of the message. Delivery annotations convey information from the sending peer to the receiving peer. @@ -366,13 +237,11 @@ def delivery_annotations(self): return self._delivery_annotations @delivery_annotations.setter - def delivery_annotations(self, value): - # type: (Dict) -> None + def delivery_annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._delivery_annotations = value @property - def header(self): - # type: () -> Optional[AmqpMessageHeader] + def header(self) -> Optional[AmqpMessageHeader]: """ The message header. :rtype: Optional[~azure.eventhub.amqp.AmqpMessageHeader] @@ -380,13 +249,11 @@ def header(self): return self._header @header.setter - def header(self, value): - # type: (AmqpMessageHeader) -> None + def header(self, value: AmqpMessageHeader) -> None: self._header = value @property - def footer(self): - # type: () -> Optional[Dict] + def footer(self) -> Optional[Dict[Any, Any]]: """ The message footer. @@ -395,10 +262,8 @@ def footer(self): return self._footer @footer.setter - def footer(self, value): - # type: (Dict) -> None + def footer(self, value: Optional[Dict[Any, Any]]) -> None: self._footer = value - # self._message.footer = value class AmqpMessageHeader(DictMixin): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py new file mode 100644 index 000000000000..4bb676392f89 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +def encode_str(data, encoding='utf-8'): + try: + return data.encode(encoding) + except AttributeError: + return data + +def normalized_data_body(data, **kwargs): + # A helper method to normalize input into AMQP Data Body format + encoding = kwargs.get("encoding", "utf-8") + if isinstance(data, list): + return [encode_str(item, encoding) for item in data] + else: + return [encode_str(data, encoding)] + + +def normalized_sequence_body(sequence): + # A helper method to normalize input into AMQP Sequence Body format + if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): + return sequence + elif isinstance(sequence, list): + return [sequence] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index d6b0258a04d0..8001d97cea6d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -2,14 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import logging import six - -from ._constants import NO_RETRY_ERRORS -from ._pyamqp import error as errors - -_LOGGER = logging.getLogger(__name__) - +try: + from uamqp import errors, compat +except ImportError: + errors = None + compat = None class EventHubError(Exception): """Represents an error occurred in the client. @@ -101,73 +99,27 @@ class OperationTimeoutError(EventHubError): class OwnershipLostError(Exception): """Raised when `update_checkpoint` detects the ownership to a partition has been lost.""" - +# TODO: delete when async unittests have been refactored def _create_eventhub_exception(exception): if isinstance(exception, errors.AuthenticationException): error = AuthenticationError(str(exception), exception) - elif isinstance(exception, errors.AMQPLinkError): + elif isinstance(exception, errors.VendorLinkDetach): error = ConnectError(str(exception), exception) - # TODO: do we need MessageHanlderError in amqp any more - # if connection/session/link error are enough? - # elif isinstance(exception, errors.MessageHandlerError): - # error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.LinkDetach): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.ConnectionClose): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.MessageHandlerError): + error = ConnectionLostError(str(exception), exception) elif isinstance(exception, errors.AMQPConnectionError): - error = ConnectError(str(exception), exception) - elif isinstance(exception, TimeoutError): + error_type = ( + AuthenticationError + if str(exception).startswith("Unable to open authentication session") + else ConnectError + ) + error = error_type(str(exception), exception) + elif isinstance(exception, compat.TimeoutException): error = ConnectionLostError(str(exception), exception) else: error = EventHubError(str(exception), exception) return error - - -def _handle_exception( - exception, closable -): # pylint:disable=too-many-branches, too-many-statements - try: # closable is a producer/consumer object - name = closable._name # pylint: disable=protected-access - except AttributeError: # closable is an client object - name = closable._container_id # pylint: disable=protected-access - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - closable._close_connection() # pylint:disable=protected-access - raise exception - elif isinstance(exception, EventHubError): - closable._close_handler() # pylint:disable=protected-access - raise exception - # TODO: The following errors seem to be useless in EH - # elif isinstance( - # exception, - # ( - # errors.MessageAccepted, - # errors.MessageAlreadySettled, - # errors.MessageModified, - # errors.MessageRejected, - # errors.MessageReleased, - # errors.MessageContentTooLarge, - # ), - # ): - # _LOGGER.info("%r Event data error (%r)", name, exception) - # error = EventDataError(str(exception), exception) - # raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - if isinstance(exception, errors.AuthenticationException): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - elif isinstance(exception, errors.AMQPLinkError): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - elif isinstance(exception, errors.AMQPConnectionError): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - # TODO: add MessageHandlerError in amqp? - # elif isinstance(exception, errors.MessageHandlerError): - # if hasattr(closable, "_close_handler"): - # closable._close_handler() # pylint:disable=protected-access - else: # errors.AMQPConnectionError, compat.TimeoutException - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - return _create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/dev_requirements.txt b/sdk/eventhub/azure-eventhub/dev_requirements.txt index 9c91833e14d8..4035baa0ba70 100644 --- a/sdk/eventhub/azure-eventhub/dev_requirements.txt +++ b/sdk/eventhub/azure-eventhub/dev_requirements.txt @@ -5,5 +5,4 @@ azure-mgmt-eventhub==10.0.0 azure-mgmt-resource==20.0.0 aiohttp>=3.0 websocket-client --e ../../../tools/azure-devtools --e ../../servicebus/azure-servicebus \ No newline at end of file +-e ../../../tools/azure-devtools \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/setup.py b/sdk/eventhub/azure-eventhub/setup.py index 8730981bc8ea..ee939c0934f8 100644 --- a/sdk/eventhub/azure-eventhub/setup.py +++ b/sdk/eventhub/azure-eventhub/setup.py @@ -70,6 +70,7 @@ packages=find_packages(exclude=exclude_packages), install_requires=[ "azure-core<2.0.0,>=1.14.0", + "uamqp>=1.5.1,<2.0.0", "typing-extensions>=4.0.1", ] ) diff --git a/sdk/eventhub/azure-eventhub/tests/__init__.py b/sdk/eventhub/azure-eventhub/tests/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/_test_case.py b/sdk/eventhub/azure-eventhub/tests/_test_case.py new file mode 100644 index 000000000000..2f77bbf23a0c --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/_test_case.py @@ -0,0 +1,11 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +def get_decorator(): + try: + import uamqp + except (ImportError, ModuleNotFoundError): + return [False] + return [True, False] diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index c00ea84067ea..e8d9c4dd8341 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -11,20 +11,27 @@ from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential from azure.eventhub._client_base import EventHubSASTokenCredential from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_secret_credential(live_eventhub): +def test_client_secret_credential(live_eventhub, uamqp_transport): credential = EnvironmentCredential() producer_client = EventHubProducerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') batch.add(EventData(body='A single message')) @@ -50,11 +57,15 @@ def on_event(partition_context, event): assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_sas_credential(live_eventhub): +def test_client_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -67,7 +78,8 @@ def test_client_sas_credential(live_eventhub): token = credential.get_token(auth_uri).token producer_client = EventHubProducerClient(fully_qualified_namespace=hostname, eventhub_name=live_eventhub['event_hub'], - credential=EventHubSASTokenCredential(token, time.time() + 3000)) + credential=EventHubSASTokenCredential(token, time.time() + 3000), + uamqp_transport=uamqp_transport) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -77,7 +89,8 @@ def test_client_sas_credential(live_eventhub): # Finally let's do it with SAS token + conn str token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str, - eventhub_name=live_eventhub['event_hub']) + eventhub_name=live_eventhub['event_hub'], + uamqp_transport=uamqp_transport) with conn_str_producer_client: batch = conn_str_producer_client.create_batch(partition_id='0') @@ -85,11 +98,15 @@ def test_client_sas_credential(live_eventhub): conn_str_producer_client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_azure_sas_credential(live_eventhub): +def test_client_azure_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -110,14 +127,17 @@ def test_client_azure_sas_credential(live_eventhub): producer_client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_azure_named_key_credential(live_eventhub): +def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential = AzureNamedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) assert consumer_client.get_eventhub_properties() is not None diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index 9e09cd156ef8..e9320c83aed7 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -6,14 +6,24 @@ from azure.eventhub import EventHubConsumerClient from azure.eventhub._eventprocessor.in_memory_checkpoint_store import InMemoryCheckpointStore from azure.eventhub._constants import ALL_PARTITIONS +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_no_partition(connstr_senders): +def test_receive_no_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', receive_timeout=1) + client = EventHubConsumerClient.from_connection_string( + connection_str, + consumer_group='$default', + receive_timeout=1, + uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -45,11 +55,15 @@ def on_event(partition_context, event): assert len([checkpoint for checkpoint in checkpoints if checkpoint["sequence_number"] == on_event.sequence_number]) > 0 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_partition(connstr_senders): +def test_receive_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -73,17 +87,21 @@ def on_event(partition_context, event): assert on_event.eventhub_name == senders[0]._client.eventhub_name +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_load_balancing(connstr_senders): +def test_receive_load_balancing(connstr_senders, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - test code using multiple threads. Sometimes OSX aborts python process") connection_str, senders = connstr_senders cs = InMemoryCheckpointStore() client1 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) client2 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): pass @@ -105,13 +123,17 @@ def on_event(partition_context, event): assert len(client2._event_processors[("$default", ALL_PARTITIONS)]._consumers) == 1 -def test_receive_batch_no_max_wait_time(connstr_senders): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_receive_batch_no_max_wait_time(connstr_senders, uamqp_transport): '''Test whether callback is called when max_wait_time is None and max_batch_size has reached ''' connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) @@ -146,14 +168,14 @@ def on_event_batch(partition_context, event_batch): worker.join() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("max_wait_time, sleep_time, expected_result", [(3, 10, []), - (3, 2, None), - ]) -def test_receive_batch_empty_with_max_wait_time(connection_str, max_wait_time, sleep_time, expected_result): + (3, 2, None)]) +def test_receive_batch_empty_with_max_wait_time(uamqp_transport, connection_str, max_wait_time, sleep_time, expected_result): '''Test whether event handler is called when max_wait_time > 0 and no event is received ''' - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) def on_event_batch(partition_context, event_batch): on_event_batch.event_batch = event_batch @@ -168,13 +190,17 @@ def on_event_batch(partition_context, event_batch): worker.join() -def test_receive_batch_early_callback(connstr_senders): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_receive_batch_early_callback(connstr_senders, uamqp_transport): ''' Test whether the callback is called once max_batch_size reaches and before max_wait_time reaches. ''' connection_str, senders = connstr_senders for _ in range(10): senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index 35722dfbf635..b98101c0eac7 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -20,27 +20,40 @@ ) from azure.eventhub import EventHubConsumerClient from azure.eventhub import EventHubProducerClient +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_with_invalid_hostname(invalid_hostname): +def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): + amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): - batch = EventDataBatch() + batch = EventDataBatch(amqp_transport=amqp_transport) batch.add(EventData("test data")) client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_with_invalid_hostname_sync(invalid_hostname): +def test_receive_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): def on_event(partition_context, event): pass - client = EventHubConsumerClient.from_connection_string(invalid_hostname, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + invalid_hostname, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event, )) @@ -50,23 +63,26 @@ def on_event(partition_context, event): thread.join() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_with_invalid_key(invalid_key): - client = EventHubProducerClient.from_connection_string(invalid_key) +def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): + client = EventHubProducerClient.from_connection_string(invalid_key, uamqp_transport=uamqp_transport) + amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() try: with pytest.raises(ConnectError): - batch = EventDataBatch() + batch = EventDataBatch(amqp_transport=amqp_transport) batch.add(EventData("test data")) client.send_batch(batch) finally: client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_to_invalid_partitions(connection_str): +def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): partitions = ["XYZ", "-1", "1000", "-"] for p in partitions: - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ConnectError): batch = client.create_batch(partition_id=p) @@ -76,11 +92,12 @@ def test_send_batch_to_invalid_partitions(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_too_large_message(connection_str): +def test_send_batch_too_large_message(connection_str, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: data = EventData(b"A" * 1100000) batch = client.create_batch() @@ -90,9 +107,10 @@ def test_send_batch_too_large_message(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_null_body(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_batch_null_body(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ValueError): data = EventData(None) @@ -103,20 +121,22 @@ def test_send_batch_null_body(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_create_batch_with_invalid_hostname_sync(invalid_hostname): +def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): client.create_batch(max_size_in_bytes=300) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_create_batch_with_too_large_size_sync(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_create_batch_with_too_large_size_sync(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(ValueError): client.create_batch(max_size_in_bytes=5 * 1024 * 1024) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py index eb197eec44b0..678fccabc106 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py @@ -9,60 +9,78 @@ from azure.eventhub import EventHubSharedKeyCredential from azure.eventhub import EventHubConsumerClient from azure.eventhub.exceptions import AuthenticationError, ConnectError, EventHubError +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_properties(live_eventhub): +def test_get_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_eventhub_properties() assert properties['eventhub_name'] == live_eventhub['event_hub'] and properties['partition_ids'] == ['0', '1'] +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_properties_with_auth_error_sync(live_eventhub): +def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf=")) + EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf="), + uamqp_transport=uamqp_transport + ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential("invalid", live_eventhub['access_key']) + EventHubSharedKeyCredential("invalid", live_eventhub['access_key']), uamqp_transport=uamqp_transport ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_properties_with_connect_error(live_eventhub): +def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], "invalid", '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(ConnectError) as e: client.get_eventhub_properties() client = EventHubConsumerClient("invalid.servicebus.windows.net", live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(EventHubError) as e: # This can be either ConnectError or ConnectionLostError client.get_eventhub_properties() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_partition_ids(live_eventhub): +def test_get_partition_ids(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: partition_ids = client.get_partition_ids() assert partition_ids == ['0', '1'] +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_partition_properties(live_eventhub): +def test_get_partition_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_partition_properties('0') assert properties['eventhub_name'] == live_eventhub['event_hub'] \ diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 21d6e249581e..bf03cae60b37 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -9,13 +9,20 @@ import pytest import time import datetime +import uamqp from azure.eventhub import EventData, TransportType, EventHubConsumerClient from azure.eventhub.exceptions import EventHubError +from azure.eventhub._pyamqp.message import Properties +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_end_of_stream(connstr_senders): +def test_receive_end_of_stream(connstr_senders, uamqp_transport): def on_event(partition_context, event): if partition_context.partition_id == "0": assert event.body_as_str() == "Receiving only a single event" @@ -29,7 +36,9 @@ def on_event(partition_context, event): assert ", partition_key: 0" in event_str on_event.called = False connection_str, senders = connstr_senders - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "@latest"}) @@ -43,6 +52,7 @@ def on_event(partition_context, event): thread.join() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("position, inclusive, expected_result", [("offset", False, "Exclusive"), ("offset", True, "Inclusive"), @@ -50,7 +60,7 @@ def on_event(partition_context, event): ("sequence", True, "Inclusive"), ("enqueued_time", False, "Exclusive")]) @pytest.mark.liveTest -def test_receive_with_event_position_sync(connstr_senders, position, inclusive, expected_result): +def test_receive_with_event_position_sync(uamqp_transport, connstr_senders, position, inclusive, expected_result): def on_event(partition_context, event): assert partition_context.last_enqueued_event_properties.get('sequence_number') == event.sequence_number assert partition_context.last_enqueued_event_properties.get('offset') == event.offset @@ -69,7 +79,9 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders senders[0].send(EventData(b"Inclusive")) senders[1].send(EventData(b"Inclusive")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1", @@ -82,7 +94,9 @@ def on_event(partition_context, event): thread.join() senders[0].send(EventData(expected_result)) senders[1].send(EventData(expected_result)) - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client2 = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client2: thread = threading.Thread(target=client2.receive, args=(on_event,), kwargs={"starting_position": on_event.event_position, @@ -90,14 +104,44 @@ def on_event(partition_context, event): "track_last_enqueued_event_properties": True}) thread.daemon = True thread.start() - time.sleep(10) + time.sleep(15) assert on_event.event.body_as_str() == expected_result thread.join() - +# TODO: after fixing message property mutability, test +#@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +#@pytest.mark.liveTest +#def test_receive_modify_message_resend_sync(uamqp_transport, connstr_senders): +# received_modified = [False] +# def on_event(partition_context, event): +# message = event.message +# if message.properties.message_id == b'a1': +# message.properties.message_id = 'a2' +# senders[0].send(event) +# elif message.properties.message_id == b'a2': +# received_modified = [True] +# +# connection_str, senders = connstr_senders +# event = EventData("A", message_id='a1') +# senders[0].send(event) +# client = EventHubConsumerClient.from_connection_string( +# connection_str, consumer_group='$default', uamqp_transport=uamqp_transport +# ) +# with client: +# thread = threading.Thread(target=client.receive, args=(on_event,), +# kwargs={"partition_id": "0", "starting_position": "-1"}) +# thread.daemon = True +# thread.start() +# time.sleep(10) +# assert received_modified[0] +# thread.join() + + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_owner_level(connstr_senders): +def test_receive_owner_level(connstr_senders, uamqp_transport): def on_event(partition_context, event): pass def on_error(partition_context, error): @@ -105,8 +149,8 @@ def on_error(partition_context, error): on_error.error = None connection_str, senders = connstr_senders - client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) + client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) with client1, client2: thread1 = threading.Thread(target=client1.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "-1", @@ -128,8 +172,10 @@ def on_error(partition_context, error): assert isinstance(on_error.error, EventHubError) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_over_websocket_sync(connstr_senders): +def test_receive_over_websocket_sync(connstr_senders, uamqp_transport): app_prop = {"raw_prop": "raw_value"} content_type = "text/plain" message_id_base = "mess_id_sample_" @@ -143,7 +189,8 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', - transport_type=TransportType.AmqpOverWebsocket) + transport_type=TransportType.AmqpOverWebsocket, + uamqp_transport=uamqp_transport) event_list = [] for i in range(5): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index 2a6c33c5bf25..21481b89081d 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -11,21 +11,34 @@ from azure.eventhub._pyamqp.client import ReceiveClient from azure.eventhub._pyamqp import error, constants - from azure.eventhub import ( EventData, EventHubSharedKeyCredential, EventHubProducerClient, - EventHubConsumerClient + EventHubConsumerClient, ) from azure.eventhub.exceptions import OperationTimeoutError - - +from azure.eventhub._utils import transform_outbound_single_message +try: + import uamqp + from uamqp import compat + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +from ..._test_case import get_decorator + +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_with_long_interval_sync(live_eventhub, sleep): +def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): test_partition = "0" sender = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'], - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], + live_eventhub['access_key']), uamqp_transport=uamqp_transport + ) with sender: batch = sender.create_batch(partition_id=test_partition) batch.add(EventData(b"A single event")) @@ -63,51 +76,78 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): assert list(received[0].body)[0] == b"A single event" +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers): +def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers + if uamqp_transport: + amqp_transport = UamqpTransport() + retry_total = 3 + else: + amqp_transport = PyamqpTransport() + retry_total = 0 # no retry, should just raise error - client = EventHubProducerClient.from_connection_string(conn_str=connection_str, idle_timeout=10, retry_total=0) - with client: - ed = EventData('data') - sender = client._create_producer(partition_id='0') - with sender: - sender._open_with_retry() - time.sleep(11) - sender._unsent_events = [ed.message] - with pytest.raises(error.AMQPConnectionError): - sender._send_event_data() - - # with retry, should work - client = EventHubProducerClient.from_connection_string(conn_str=connection_str, idle_timeout=10) + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, retry_total=retry_total, uamqp_transport=uamqp_transport + ) with client: ed = EventData('data') sender = client._create_producer(partition_id='0') with sender: sender._open_with_retry() time.sleep(11) - sender._unsent_events = [ed.message] - sender._send_event_data() + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed._message] + if uamqp_transport: + sender._unsent_events[0].on_send_complete = sender._on_outcome + with pytest.raises((uamqp.errors.ConnectionClose, + uamqp.errors.MessageHandlerError, OperationTimeoutError)): + sender._send_event_data() + else: + with pytest.raises(error.AMQPConnectionError): + sender._send_event_data() + if uamqp_transport: + sender._send_event_data_with_retry() + + # pyamqp - with retry, should work + if not uamqp_transport: + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, uamqp_transport=uamqp_transport + ) + with client: + ed = EventData('data') + sender = client._create_producer(partition_id='0') + with sender: + sender._open_with_retry() + time.sleep(11) + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed._message] + sender._send_event_data() retry = 0 while retry < 3: try: - messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=10) + timeout = 10000 if uamqp_transport else 10 + messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=timeout) if messages: received_ed1 = EventData._from_message(messages[0]) assert received_ed1.body_as_str() == 'data' break - except TimeoutError: + except (compat.TimeoutException, TimeoutError): retry += 1 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders): +def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string( conn_str=connection_str, consumer_group='$default', - idle_timeout=10 + idle_timeout=10, + uamqp_transport=uamqp_transport ) def on_event_received(event): @@ -122,7 +162,10 @@ def on_event_received(event): senders[0].send(ed) consumer._handler.do_work() - assert consumer._handler._connection.state == constants.ConnectionState.END + if uamqp_transport: + assert consumer._handler._connection._state == uamqp.c_uamqp.ConnectionState.DISCARDING + else: + assert consumer._handler._connection.state == constants.ConnectionState.END duration = 10 now_time = time.time() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index f04cb602e90b..8f6a76e01239 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -20,11 +20,21 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_with_partition_key(connstr_receivers): +def test_send_with_partition_key(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: data_val = 0 for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: @@ -49,12 +59,14 @@ def test_send_with_partition_key(connstr_receivers): found_partition_keys[event_data.partition_key] = index +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_and_receive_large_body_size(connstr_receivers): +def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: payload = 250 * 1024 batch = client.create_batch() @@ -69,10 +81,12 @@ def test_send_and_receive_large_body_size(connstr_receivers): assert len(list(received[0].body)[0]) == payload +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_amqp_annotated_message(connstr_receivers): +def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: sequence_body = [b'message', 123.456, True] footer = {'footer_key': 'footer_value'} @@ -108,7 +122,7 @@ def test_send_amqp_annotated_message(connstr_receivers): ) body_ed = """{"json_key": "json_val"}""" - prop_ed = {"raw_prop": "raw_value"} + prop_ed = {b"raw_prop": b"raw_value"} cont_type_ed = "text/plain" corr_id_ed = "corr_id" mess_id_ed = "mess_id" @@ -116,6 +130,7 @@ def test_send_amqp_annotated_message(connstr_receivers): event_data.content_type = cont_type_ed event_data.correlation_id = corr_id_ed event_data.message_id = mess_id_ed + event_data.properties = prop_ed batch = client.create_batch() batch.add(data_message) @@ -146,6 +161,7 @@ def check_values(event): assert event.correlation_id == corr_id_ed assert event.message_id == mess_id_ed assert event.content_type == cont_type_ed + assert event.properties == prop_ed assert event.body_type == AmqpMessageBodyType.DATA received_count["normal_msg"] += 1 elif raw_amqp_message.body_type == AmqpMessageBodyType.SEQUENCE: @@ -168,7 +184,8 @@ def on_event(partition_context, event): on_event.received = [] client = EventHubConsumerClient.from_connection_string(connection_str, - consumer_group='$default') + consumer_group='$default', + uamqp_transport=uamqp_transport) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1"}) @@ -184,12 +201,13 @@ def on_event(partition_context, event): assert received_count["normal_msg"] == 2 +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("payload", - [b"", b"A single event"]) + [(b""), (b"A single event")]) @pytest.mark.liveTest -def test_send_and_receive_small_body(connstr_receivers, payload): +def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() batch.add(EventData(payload)) @@ -202,10 +220,12 @@ def test_send_and_receive_small_body(connstr_receivers, payload): assert list(received[0].body)[0] == payload +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_partition(connstr_receivers): +def test_send_partition(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() @@ -237,10 +257,12 @@ def test_send_partition(connstr_receivers): assert len(partition_0) + len(partition_1) == 2 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_non_ascii(connstr_receivers): +def test_send_non_ascii(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch(partition_id="0") batch.add(EventData(u"é,è,à,ù,â,ê,î,ô,û")) @@ -257,13 +279,15 @@ def test_send_non_ascii(connstr_receivers): assert partition_0[1].body_as_json() == {"foo": u"漢字"} +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_multiple_partitions_with_app_prop(connstr_receivers): +def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: ed0 = EventData(b"Message 0") ed0.properties = app_prop @@ -285,11 +309,14 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers): assert partition_1[0].properties[b"raw_prop"] == b"raw_value" +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_over_websocket_sync(connstr_receivers): - pytest.skip("websocket not supported") +def test_send_over_websocket_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: batch = client.create_batch(partition_id="0") @@ -302,13 +329,17 @@ def test_send_over_websocket_sync(connstr_receivers): assert len(received) == 1 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): +def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: event_data_batch = client.create_batch(max_size_in_bytes=100000) while True: @@ -326,10 +357,12 @@ def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): assert EventData._from_message(received[0]).properties[b"raw_prop"] == b"raw_value" +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_list(connstr_receivers): +def test_send_list(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)]) @@ -341,10 +374,12 @@ def test_send_list(connstr_receivers): assert received[0].body_as_str() == payload +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_list_partition(connstr_receivers): +def test_send_list_partition(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)], partition_id="0") @@ -353,23 +388,25 @@ def test_send_list_partition(connstr_receivers): assert received.body_as_str() == payload +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("to_send, exception_type", [([EventData("A"*1024)]*1100, ValueError), - ("any str", AttributeError) - ]) + ("any str", AttributeError)]) @pytest.mark.liveTest -def test_send_list_wrong_data(connection_str, to_send, exception_type): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_list_wrong_data(connection_str, to_send, exception_type, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(exception_type): client.send_batch(to_send) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("partition_id, partition_key", [("0", None), (None, "pk")]) -def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key): +def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_transport): # Use invalid_hostname because this is not a live test. - client = EventHubProducerClient.from_connection_string(invalid_hostname) - batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key) + amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) + batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key, amqp_transport=amqp_transport) with client: with pytest.raises(TypeError): - client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) + client.send_batch(batch, partition_id=partition_id, partition_key=partition_key, amqp_transport=amqp_transport) diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py b/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 612f2e8b9605..3012739a05d6 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -1,9 +1,26 @@ +# -- coding: utf-8 -- +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + import platform import pytest -from packaging import version -from azure.eventhub.amqp import AmqpAnnotatedMessage +try: + import uamqp + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except ImportError: + UamqpTransport = None + pass +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +from azure.eventhub.amqp import AmqpAnnotatedMessage, AmqpMessageHeader, AmqpMessageProperties from azure.eventhub import _common -from azure.eventhub._pyamqp.message import Message, Properties +from azure.eventhub._pyamqp.message import Message, Properties, Header +from azure.eventhub._utils import transform_outbound_single_message +from .._test_case import get_decorator + +uamqp_transport_vals = get_decorator() pytestmark = pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="This is ignored for PyPy") @@ -55,24 +72,44 @@ def test_app_properties(): assert event_data.properties["a"] == "b" -def test_sys_properties(): - properties = Properties( - message_id="message_id", - user_id="user_id", - to="to", - subject="subject", - reply_to="reply_to", - correlation_id="correlation_id", - content_type="content_type", - content_encoding="content_encoding", - absolute_expiry_time=1, - creation_time=1, - group_id="group_id", - group_sequence=1, - reply_to_group_id="reply_to_group_id" - ) - message_annotations = {_common.PROP_OFFSET: "@latest"} - message = Message(properties=properties, message_annotations=message_annotations) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_sys_properties(uamqp_transport): + if uamqp_transport: + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + properties = Properties( + message_id="message_id", + user_id="user_id", + to="to", + subject="subject", + reply_to="reply_to", + correlation_id="correlation_id", + content_type="content_type", + content_encoding="content_encoding", + absolute_expiry_time=1, + creation_time=1, + group_id="group_id", + group_sequence=1, + reply_to_group_id="reply_to_group_id" + ) + message_annotations = {_common.PROP_OFFSET: "@latest"} + message = Message(properties=properties, message_annotations=message_annotations) ed = EventData._from_message(message) # type: EventData assert ed.system_properties[_common.PROP_OFFSET] == "@latest" @@ -91,22 +128,37 @@ def test_sys_properties(): assert ed.system_properties[_common.PROP_REPLY_TO_GROUP_ID] == properties.reply_to_group_id -def test_event_data_batch(): - batch = EventDataBatch(max_size_in_bytes=110, partition_key="par") +# TODO: see why pyamqp went from 99 to 87 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_event_data_batch(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport() + expected_result = 101 + else: + amqp_transport = PyamqpTransport() + expected_result = 87 + batch = EventDataBatch(max_size_in_bytes=110, partition_key="par", amqp_transport=amqp_transport) batch.add(EventData("A")) assert str(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" assert repr(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" # TODO: uamqp uses 93 bytes for encode, while python amqp uses 99 bytes # we should understand why extra bytes are needed to encode the content and how it could be improved - assert batch.size_in_bytes == 99 and len(batch) == 1 + assert batch.size_in_bytes == expected_result and len(batch) == 1 with pytest.raises(ValueError): batch.add(EventData("A")) -def test_event_data_from_message(): - message = Message(data=b'A') +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +def test_event_data_from_message(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport() + else: + amqp_transport = PyamqpTransport() + annotated_message = AmqpAnnotatedMessage(data_body=b'A') + message = amqp_transport.to_outgoing_amqp_message(annotated_message) event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None @@ -118,7 +170,7 @@ def test_event_data_from_message(): assert event.content_type == 'content_type' assert event.correlation_id == 'correlation_id' assert event.message_id == 'message_id' - assert event.body == b'A' + assert list(event.body) == [b'A'] def test_amqp_message_str_repr(): @@ -126,3 +178,483 @@ def test_amqp_message_str_repr(): message = AmqpAnnotatedMessage(data_body=data_body) assert str(message) == 'A' assert 'AmqpAnnotatedMessage(body=A, body_type=data' in repr(message) + + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_amqp_message_from_message(uamqp_transport): + if uamqp_transport: + header = uamqp.message.MessageHeader() + header.delivery_count = 1 + header.time_to_live = 10000 + header.first_acquirer = True + header.durable = True + header.priority = 1 + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(header=header, properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + header = Header( + delivery_count=1, + ttl=10000, + first_acquirer=True, + durable=True, + priority=1 + ) + properties = Properties( + message_id="message_id", + user_id="user_id", + to="to", + subject="subject", + reply_to="reply_to", + correlation_id="correlation_id", + content_type="content_type", + content_encoding="content_encoding", + absolute_expiry_time=1, + creation_time=1, + group_id="group_id", + group_sequence=1, + reply_to_group_id="reply_to_group_id" + ) + message_annotations = {_common.PROP_OFFSET: "@latest"} + message = Message(properties=properties, header=header, message_annotations=message_annotations) + + amqp_message = AmqpAnnotatedMessage(message=message) + assert amqp_message.properties.message_id == message.properties.message_id + assert amqp_message.properties.user_id == message.properties.user_id + assert amqp_message.properties.to == message.properties.to + assert amqp_message.properties.subject == message.properties.subject + assert amqp_message.properties.reply_to == message.properties.reply_to + assert amqp_message.properties.correlation_id == message.properties.correlation_id + assert amqp_message.properties.content_type == message.properties.content_type + assert amqp_message.properties.absolute_expiry_time == message.properties.absolute_expiry_time + assert amqp_message.properties.creation_time == message.properties.creation_time + assert amqp_message.properties.group_id == message.properties.group_id + assert amqp_message.properties.group_sequence == message.properties.group_sequence + assert amqp_message.properties.reply_to_group_id == message.properties.reply_to_group_id + assert amqp_message.header.time_to_live == message.header.ttl + assert amqp_message.header.delivery_count == message.header.delivery_count + assert amqp_message.header.first_acquirer == message.header.first_acquirer + assert amqp_message.header.durable == message.header.durable + assert amqp_message.header.priority == message.header.priority + assert amqp_message.annotations == message.message_annotations + +# TODO: ADD MESSAGE BACKCOMPAT TESTS +#class EventDataMessageBackcompatTests: +# +# def test_message_backcompat_receive_and_delete_databody(): +# outgoing_event_data = EventData(body="hello") +# outgoing_event_data.application_properties = {'prop': 'test'} +# outgoing_event_data.session_id = "id_session" +# outgoing_event_data.message_id = "id_message" +# outgoing_event_data.time_to_live = timedelta(seconds=30) +# outgoing_event_data.content_type = "content type" +# outgoing_event_data.correlation_id = "correlation" +# outgoing_event_data.subject = "github" +# outgoing_event_data.partition_key = "id_session" +# outgoing_event_data.to = "forward to" +# outgoing_event_data.reply_to = "reply to" +# outgoing_event_data.reply_to_session_id = "reply to session" +# +# # TODO: Attribute shouldn't exist until after message has been sent. +# # with pytest.raises(AttributeError): +# # outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=True) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# assert outgoing_message.message +# with pytest.raises(TypeError): +# outgoing_message.message.accept() +# with pytest.raises(TypeError): +# outgoing_message.message.release() +# with pytest.raises(TypeError): +# outgoing_message.message.reject() +# with pytest.raises(TypeError): +# outgoing_message.message.modify(True, True) +# assert outgoing_message.message.state == uamqp.constants.MessageState.SendComplete +# assert outgoing_message.message.settled +# assert outgoing_message.message.delivery_annotations is None +# assert outgoing_message.message.delivery_no is None +# assert outgoing_message.message.delivery_tag is None +# assert outgoing_message.message.on_send_complete is None +# assert outgoing_message.message.footer is None +# assert outgoing_message.message.retries >= 0 +# assert outgoing_message.message.idle_time >= 0 +# with pytest.raises(Exception): +# outgoing_message.message.gather() +# assert isinstance(outgoing_message.message.encode_message(), bytes) +# assert outgoing_message.message.get_message_encoded_size() == 208 +# assert list(outgoing_message.message.get_data()) == [b'hello'] +# assert outgoing_message.message.application_properties == {'prop': 'test'} +# assert outgoing_message.message.get_message() # C instance. +# assert len(outgoing_message.message.annotations) == 1 +# assert list(outgoing_message.message.annotations.values())[0] == 'id_session' +# assert str(outgoing_message.message.header) == str({'delivery_count': None, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) +# assert outgoing_message.message.header.get_header_obj().delivery_count is None +# assert outgoing_message.message.properties.message_id == b'id_message' +# assert outgoing_message.message.properties.user_id is None +# assert outgoing_message.message.properties.to == b'forward to' +# assert outgoing_message.message.properties.subject == b'github' +# assert outgoing_message.message.properties.reply_to == b'reply to' +# assert outgoing_message.message.properties.correlation_id == b'correlation' +# assert outgoing_message.message.properties.content_type == b'content type' +# assert outgoing_message.message.properties.content_encoding is None +# assert outgoing_message.message.properties.absolute_expiry_time +# assert outgoing_message.message.properties.creation_time +# assert outgoing_message.message.properties.group_id == b'id_session' +# assert outgoing_message.message.properties.group_sequence is None +# assert outgoing_message.message.properties.reply_to_group_id == b'reply to session' +# assert outgoing_message.message.properties.get_properties_obj().message_id +# +# # TODO: Test updating message and resending +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert incoming_message.message.delivery_annotations == {} +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag is None +# assert incoming_message.message.on_send_complete is None +# assert incoming_message.message.footer is None +# assert incoming_message.message.retries >= 0 +# assert incoming_message.message.idle_time == 0 +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert isinstance(incoming_message.message.encode_message(), bytes) +# # TODO: Pyamqp has size at 266 +# # assert incoming_message.message.get_message_encoded_size() == 267 +# assert list(incoming_message.message.get_data()) == [b'hello'] +# assert incoming_message.message.application_properties == {b'prop': b'test'} +# assert incoming_message.message.get_message() # C instance. +# assert len(incoming_message.message.annotations) == 3 +# assert incoming_message.message.annotations[b'x-opt-enqueued-time'] > 0 +# assert incoming_message.message.annotations[b'x-opt-sequence-number'] > 0 +# assert incoming_message.message.annotations[b'x-opt-partition-key'] == b'id_session' +# # TODO: Pyamqp has header {'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None} +# # assert str(incoming_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': True, 'durable': True, 'priority': 4}) +# assert incoming_message.message.header.get_header_obj().delivery_count == 0 +# assert incoming_message.message.properties.message_id == b'id_message' +# assert incoming_message.message.properties.user_id is None +# assert incoming_message.message.properties.to == b'forward to' +# assert incoming_message.message.properties.subject == b'github' +# assert incoming_message.message.properties.reply_to == b'reply to' +# assert incoming_message.message.properties.correlation_id == b'correlation' +# assert incoming_message.message.properties.content_type == b'content type' +# assert incoming_message.message.properties.content_encoding is None +# assert incoming_message.message.properties.absolute_expiry_time +# assert incoming_message.message.properties.creation_time +# assert incoming_message.message.properties.group_id == b'id_session' +# assert incoming_message.message.properties.group_sequence is None +# assert incoming_message.message.properties.reply_to_group_id == b'reply to session' +# assert incoming_message.message.properties.get_properties_obj().message_id +# assert not incoming_message.message.accept() +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# # TODO: Test updating message and resending +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_peek_lock_databody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = ServiceBusMessage( +# body="hello", +# application_properties={'prop': 'test'}, +# session_id="id_session", +# message_id="id_message", +# time_to_live=timedelta(seconds=30), +# content_type="content type", +# correlation_id="correlation", +# subject="github", +# partition_key="id_session", +# to="forward to", +# reply_to="reply to", +# reply_to_session_id="reply to session" +# ) +# +# # TODO: Attribute shouldn't exist until after message has been sent. +# # with pytest.raises(AttributeError): +# # outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=True) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# assert outgoing_message.message +# with pytest.raises(TypeError): +# outgoing_message.message.accept() +# with pytest.raises(TypeError): +# outgoing_message.message.release() +# with pytest.raises(TypeError): +# outgoing_message.message.reject() +# with pytest.raises(TypeError): +# outgoing_message.message.modify(True, True) +# assert outgoing_message.message.state == uamqp.constants.MessageState.SendComplete +# assert outgoing_message.message.settled +# assert outgoing_message.message.delivery_annotations is None +# assert outgoing_message.message.delivery_no is None +# assert outgoing_message.message.delivery_tag is None +# assert outgoing_message.message.on_send_complete is None +# assert outgoing_message.message.footer is None +# assert outgoing_message.message.retries >= 0 +# assert outgoing_message.message.idle_time >= 0 +# with pytest.raises(Exception): +# outgoing_message.message.gather() +# assert isinstance(outgoing_message.message.encode_message(), bytes) +# assert outgoing_message.message.get_message_encoded_size() == 208 +# assert list(outgoing_message.message.get_data()) == [b'hello'] +# assert outgoing_message.message.application_properties == {'prop': 'test'} +# assert outgoing_message.message.get_message() # C instance. +# assert len(outgoing_message.message.annotations) == 1 +# assert list(outgoing_message.message.annotations.values())[0] == 'id_session' +# assert str(outgoing_message.message.header) == str({'delivery_count': None, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) +# assert outgoing_message.message.header.get_header_obj().delivery_count is None +# assert outgoing_message.message.properties.message_id == b'id_message' +# assert outgoing_message.message.properties.user_id is None +# assert outgoing_message.message.properties.to == b'forward to' +# assert outgoing_message.message.properties.subject == b'github' +# assert outgoing_message.message.properties.reply_to == b'reply to' +# assert outgoing_message.message.properties.correlation_id == b'correlation' +# assert outgoing_message.message.properties.content_type == b'content type' +# assert outgoing_message.message.properties.content_encoding is None +# assert outgoing_message.message.properties.absolute_expiry_time +# assert outgoing_message.message.properties.creation_time +# assert outgoing_message.message.properties.group_id == b'id_session' +# assert outgoing_message.message.properties.group_sequence is None +# assert outgoing_message.message.properties.reply_to_group_id == b'reply to session' +# assert outgoing_message.message.properties.get_properties_obj().message_id +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.PEEK_LOCK, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled +# assert not incoming_message.message.settled +# assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag +# assert incoming_message.message.on_send_complete is None +# assert incoming_message.message.footer is None +# assert incoming_message.message.retries >= 0 +# assert incoming_message.message.idle_time == 0 +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert isinstance(incoming_message.message.encode_message(), bytes) +# # TODO: Pyamqp has size at 336 +# # assert incoming_message.message.get_message_encoded_size() == 334 +# assert list(incoming_message.message.get_data()) == [b'hello'] +# assert incoming_message.message.application_properties == {b'prop': b'test'} +# assert incoming_message.message.get_message() # C instance. +# assert len(incoming_message.message.annotations) == 4 +# assert incoming_message.message.annotations[b'x-opt-enqueued-time'] > 0 +# assert incoming_message.message.annotations[b'x-opt-sequence-number'] > 0 +# assert incoming_message.message.annotations[b'x-opt-partition-key'] == b'id_session' +# assert incoming_message.message.annotations[b'x-opt-locked-until'] +# # TODO: Pyamqp has header {'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None} +# # assert str(incoming_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': True, 'durable': True, 'priority': 4}) +# assert str(incoming_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) +# assert incoming_message.message.header.get_header_obj().delivery_count == 0 +# assert incoming_message.message.properties.message_id == b'id_message' +# assert incoming_message.message.properties.user_id is None +# assert incoming_message.message.properties.to == b'forward to' +# assert incoming_message.message.properties.subject == b'github' +# assert incoming_message.message.properties.reply_to == b'reply to' +# assert incoming_message.message.properties.correlation_id == b'correlation' +# assert incoming_message.message.properties.content_type == b'content type' +# assert incoming_message.message.properties.content_encoding is None +# assert incoming_message.message.properties.absolute_expiry_time +# assert incoming_message.message.properties.creation_time +# assert incoming_message.message.properties.group_id == b'id_session' +# assert incoming_message.message.properties.group_sequence is None +# assert incoming_message.message.properties.reply_to_group_id == b'reply to session' +# assert incoming_message.message.properties.get_properties_obj().message_id +# assert incoming_message.message.accept() +# # TODO: State isn't updated if settled correctly via the receiver. +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_receive_and_delete_valuebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(value_body={b"key": b"value"}) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert incoming_message.message.get_data() == {b"key": b"value"} +# assert not incoming_message.message.accept() +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_peek_lock_valuebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(value_body={b"key": b"value"}) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.PEEK_LOCK, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled +# assert not incoming_message.message.settled +# assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert incoming_message.message.get_data() == {b"key": b"value"} +# assert incoming_message.message.accept() +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_receive_and_delete_sequencebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(sequence_body=[1, 2, 3]) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert list(incoming_message.message.get_data()) == [[1, 2, 3]] +# assert not incoming_message.message.accept() +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_peek_lock_sequencebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(sequence_body=[1, 2, 3]) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.PEEK_LOCK, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled +# assert not incoming_message.message.settled +# assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert list(incoming_message.message.get_data()) == [[1, 2, 3]] +# assert incoming_message.message.accept() +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# # TODO: Add batch message backcompat tests