diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 811230155bbb..7fee9a8442fe 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -10,6 +10,8 @@ - The `text` property on `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` has changed to a method, which also takes an `encoding` parameter. +- `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` are now abstract base classes. They should not be initialized directly, instead +your transport responses should inherit from them and implement them. ### Bugs Fixed diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index 44e5995c5f84..7bc1f0d1fb3b 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -39,7 +39,6 @@ RequestIdPolicy, RetryPolicy, ) -from .pipeline._tools import to_rest_response as _to_rest_response try: from typing import TYPE_CHECKING @@ -65,17 +64,6 @@ _LOGGER = logging.getLogger(__name__) -def _prepare_request(request): - # returns the request ready to run through pipelines - # and a bool telling whether we ended up converting it - rest_request = False - try: - request_to_run = request._to_pipeline_transport_request() # pylint: disable=protected-access - rest_request = True - except AttributeError: - request_to_run = request - return rest_request, request_to_run - class PipelineClient(PipelineClientBase): """Service client core methods. @@ -203,22 +191,9 @@ def send_request(self, request, **kwargs): :keyword bool stream: Whether the response payload will be streamed. Defaults to False. :return: The response of your network call. Does not do error handling on your response. :rtype: ~azure.core.rest.HttpResponse - # """ - rest_request, request_to_run = _prepare_request(request) + """ return_pipeline_response = kwargs.pop("_return_pipeline_response", False) - pipeline_response = self._pipeline.run(request_to_run, **kwargs) # pylint: disable=protected-access - response = pipeline_response.http_response - if rest_request: - response = _to_rest_response(response) - try: - if not kwargs.get("stream", False): - response.read() - response.close() - except Exception as exc: - response.close() - raise exc + pipeline_response = self._pipeline.run(request, **kwargs) # pylint: disable=protected-access if return_pipeline_response: - pipeline_response.http_response = response - pipeline_response.http_request = request return pipeline_response - return response + return pipeline_response.http_response diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index 357b3d9b917d..423d0efa45de 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -25,7 +25,7 @@ # -------------------------------------------------------------------------- import logging -from collections.abc import Iterable +import collections.abc from typing import Any, Awaitable from .configuration import Configuration from .pipeline import AsyncPipeline @@ -37,8 +37,6 @@ RequestIdPolicy, AsyncRetryPolicy, ) -from ._pipeline_client import _prepare_request -from .pipeline._tools_async import to_rest_response as _to_rest_response try: from typing import TYPE_CHECKING, TypeVar @@ -63,6 +61,26 @@ _LOGGER = logging.getLogger(__name__) +class _AsyncContextManager(collections.abc.Awaitable): + + def __init__(self, wrapped: collections.abc.Awaitable): + super().__init__() + self.wrapped = wrapped + self.response = None + + def __await__(self): + return self.wrapped.__await__() + + async def __aenter__(self): + self.response = await self + return self.response + + async def __aexit__(self, *args): + await self.response.__aexit__(*args) + + async def close(self): + await self.response.close() + class AsyncPipelineClient(PipelineClientBase): """Service client core methods. @@ -126,7 +144,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use config.proxy_policy, ContentDecodePolicy(**kwargs) ] - if isinstance(per_call_policies, Iterable): + if isinstance(per_call_policies, collections.abc.Iterable): policies.extend(per_call_policies) else: policies.append(per_call_policies) @@ -135,7 +153,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use config.retry_policy, config.authentication_policy, config.custom_hook_policy]) - if isinstance(per_retry_policies, Iterable): + if isinstance(per_retry_policies, collections.abc.Iterable): policies.extend(per_retry_policies) else: policies.append(per_retry_policies) @@ -144,13 +162,13 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use DistributedTracingPolicy(**kwargs), config.http_logging_policy or HttpLoggingPolicy(**kwargs)]) else: - if isinstance(per_call_policies, Iterable): + if isinstance(per_call_policies, collections.abc.Iterable): per_call_policies_list = list(per_call_policies) else: per_call_policies_list = [per_call_policies] per_call_policies_list.extend(policies) policies = per_call_policies_list - if isinstance(per_retry_policies, Iterable): + if isinstance(per_retry_policies, collections.abc.Iterable): per_retry_policies_list = list(per_retry_policies) else: per_retry_policies_list = [per_retry_policies] @@ -175,30 +193,13 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use return AsyncPipeline(transport, policies) async def _make_pipeline_call(self, request, **kwargs): - rest_request, request_to_run = _prepare_request(request) return_pipeline_response = kwargs.pop("_return_pipeline_response", False) pipeline_response = await self._pipeline.run( - request_to_run, **kwargs # pylint: disable=protected-access + request, **kwargs # pylint: disable=protected-access ) - response = pipeline_response.http_response - if rest_request: - rest_response = _to_rest_response(response) - if not kwargs.get("stream"): - try: - # in this case, the pipeline transport response already called .load_body(), so - # the body is loaded. instead of doing response.read(), going to set the body - # to the internal content - rest_response._content = response.body() # pylint: disable=protected-access - await rest_response.close() - except Exception as exc: - await rest_response.close() - raise exc - response = rest_response if return_pipeline_response: - pipeline_response.http_response = response - pipeline_response.http_request = request return pipeline_response - return response + return pipeline_response.http_response def send_request( self, @@ -223,6 +224,5 @@ def send_request( :return: The response of your network call. Does not do error handling on your response. :rtype: ~azure.core.rest.AsyncHttpResponse """ - from .rest._rest_py3 import _AsyncContextManager wrapped = self._make_pipeline_call(request, stream=stream, **kwargs) return _AsyncContextManager(wrapped=wrapped) diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index a8beebd75b99..d77ec1cb5490 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -33,39 +33,11 @@ def await_result(func, *args, **kwargs): ) return result -def to_rest_request(pipeline_transport_request): - from ..rest import HttpRequest as RestHttpRequest - return RestHttpRequest( - method=pipeline_transport_request.method, - url=pipeline_transport_request.url, - headers=pipeline_transport_request.headers, - files=pipeline_transport_request.files, - data=pipeline_transport_request.data - ) - -def to_rest_response(pipeline_transport_response): - from .transport._requests_basic import RequestsTransportResponse - from ..rest._requests_basic import RestRequestsTransportResponse - from ..rest import HttpResponse - if isinstance(pipeline_transport_response, RequestsTransportResponse): - response_type = RestRequestsTransportResponse - else: - response_type = HttpResponse - response = response_type( - request=to_rest_request(pipeline_transport_response.request), - internal_response=pipeline_transport_response.internal_response, - ) - response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access - return response - -def get_block_size(response): - try: - return response._connection_data_block_size # pylint: disable=protected-access - except AttributeError: - return response.block_size - -def get_internal_response(response): +def read_in_response(response, is_stream_response): try: - return response._internal_response # pylint: disable=protected-access - except AttributeError: - return response.internal_response + if not is_stream_response: + response.read() + response.close() + except Exception as exc: + response.close() + raise exc diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index de59dfdd86ed..ac42895a5452 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -23,7 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from ._tools import to_rest_request +from typing import Optional async def await_result(func, *args, **kwargs): """If func returns an awaitable, await it.""" @@ -33,36 +33,11 @@ async def await_result(func, *args, **kwargs): return await result # type: ignore return result -def _get_response_type(pipeline_transport_response): +async def read_in_response(response, is_stream_response: Optional[bool]) -> None: try: - from .transport import AioHttpTransportResponse - from ..rest._aiohttp import RestAioHttpTransportResponse - if isinstance(pipeline_transport_response, AioHttpTransportResponse): - return RestAioHttpTransportResponse - except ImportError: - pass - try: - from .transport import AsyncioRequestsTransportResponse - from ..rest._requests_asyncio import RestAsyncioRequestsTransportResponse - if isinstance(pipeline_transport_response, AsyncioRequestsTransportResponse): - return RestAsyncioRequestsTransportResponse - except ImportError: - pass - try: - from .transport import TrioRequestsTransportResponse - from ..rest._requests_trio import RestTrioRequestsTransportResponse - if isinstance(pipeline_transport_response, TrioRequestsTransportResponse): - return RestTrioRequestsTransportResponse - except ImportError: - pass - from ..rest import AsyncHttpResponse - return AsyncHttpResponse - -def to_rest_response(pipeline_transport_response): - response_type = _get_response_type(pipeline_transport_response) - response = response_type( - request=to_rest_request(pipeline_transport_response.request), - internal_response=pipeline_transport_response.internal_response, - ) - response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access - return response + if not is_stream_response: + await response.read() + await response.close() + except Exception as exc: + await response.close() + raise exc diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py index 1e8d712bc12d..d92d1252d6d1 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py @@ -576,6 +576,8 @@ def deserialize_from_http_generics( mime_type = "application/json" # Rely on transport implementation to give me "text()" decoded correctly + if hasattr(response, "read"): + response.read() return cls.deserialize_from_text(response.text(encoding), mime_type, response=response) def on_request(self, request): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index c5cd6816f2f5..d8766966bdf0 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -25,12 +25,12 @@ # -------------------------------------------------------------------------- import sys from typing import Any, Optional, AsyncIterator as AsyncIteratorType -from collections.abc import AsyncIterator +import collections.abc try: import cchardet as chardet except ImportError: # pragma: no cover import chardet # type: ignore - +from itertools import groupby import logging import asyncio import codecs @@ -41,17 +41,109 @@ from azure.core.exceptions import ServiceRequestError, ServiceResponseError from azure.core.pipeline import Pipeline -from ._base import HttpRequest +from ._base import HttpRequest as PipelineTransportHttpRequest from ._base_async import ( AsyncHttpTransport, AsyncHttpResponse, - _ResponseStopIteration) -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response + _ResponseStopIteration, + RestAsyncHttpResponseImpl, +) + +from .._tools_async import read_in_response as _read_in_response +from ...rest import ( + HttpRequest as RestHttpRequest, AsyncHttpResponse as RestAsyncHttpResponse +) # Matching requests, because why not? CONTENT_CHUNK_SIZE = 10 * 1024 _LOGGER = logging.getLogger(__name__) +class _ItemsView(collections.abc.ItemsView): + def __init__(self, ref): + super().__init__(ref) + self._ref = ref + + def __iter__(self): + for key, groups in groupby(self._ref.__iter__(), lambda x: x[0]): + yield tuple([key, ", ".join(group[1] for group in groups)]) + + def __contains__(self, item): + if not (isinstance(item, (list, tuple)) and len(item) == 2): + return False + for k, v in self.__iter__(): + if item[0].lower() == k.lower() and item[1] == v: + return True + return False + + def __repr__(self): + return f"dict_items({list(self.__iter__())})" + +class _KeysView(collections.abc.KeysView): + def __init__(self, items): + super().__init__(items) + self._items = items + + def __iter__(self): + for key, _ in self._items: + yield key + + def __contains__(self, key): + for k in self.__iter__(): + if key.lower() == k.lower(): + return True + return False + def __repr__(self): + return f"dict_keys({list(self.__iter__())})" + +class _ValuesView(collections.abc.ValuesView): + def __init__(self, items): + super().__init__(items) + self._items = items + + def __iter__(self): + for _, value in self._items: + yield value + + def __contains__(self, value): + for v in self.__iter__(): + if value == v: + return True + return False + + def __repr__(self): + return f"dict_values({list(self.__iter__())})" + + +class _CIMultiDict(CIMultiDict): + """Dictionary with the support for duplicate case-insensitive keys.""" + + def __iter__(self): + return iter(self.keys()) + + def keys(self): + """Return a new view of the dictionary's keys.""" + return _KeysView(self.items()) + + def items(self): + """Return a new view of the dictionary's items.""" + return _ItemsView(super().items()) + + def values(self): + """Return a new view of the dictionary's values.""" + return _ValuesView(self.items()) + + def __getitem__(self, key: str) -> str: + return ", ".join(self.getall(key, [])) + + def __setitem__(self, key, value) -> None: + self.update({key: value}) + + def get(self, key, default=None): + values = self.getall(key, None) + if values: + values = ", ".join(values) + return values or default + class AioHttpTransport(AsyncHttpTransport): """AioHttp HTTP sender implementation. @@ -135,17 +227,17 @@ def _get_request_data(self, request): #pylint: disable=no-self-use return form_data return request.data - async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpResponse]: + async def send(self, request: RestHttpRequest, **config: Any) -> Optional[RestAsyncHttpResponse]: """Send the request using this HTTP sender. Will pre-load the body into memory to be available with a sync method. Pass stream=True to avoid this behavior. :param request: The HttpRequest object - :type request: ~azure.core.pipeline.transport.HttpRequest + :type request: ~azure.core.rest.HttpRequest :param config: Any keyword arguments :return: The AsyncHttpResponse - :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse :keyword bool stream: Defaults to False. :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url) @@ -192,11 +284,14 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR allow_redirects=False, **config ) - response = AioHttpTransportResponse(request, result, - self.connection_config.data_block_size, - decompress=not auto_decompress) - if not stream_response: - await response.load_body() + response = RestAioHttpTransportResponse( + request=request, + internal_response=result, + block_size=self.connection_config.data_block_size, + decompress=not auto_decompress, + ) + await _read_in_response(response, stream_response) + except aiohttp.client_exceptions.ClientResponseError as err: raise ServiceResponseError(err, error=err) from err except aiohttp.client_exceptions.ClientError as err: @@ -205,7 +300,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR raise ServiceResponseError(err, error=err) from err return response -class AioHttpStreamDownloadGenerator(AsyncIterator): +class AioHttpStreamDownloadGenerator(collections.abc.AsyncIterator): """Streams the response body data. :param pipeline: The pipeline object @@ -217,17 +312,16 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompres self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size self._decompress = decompress - internal_response = _get_internal_response(response) - self.content_length = int(internal_response.headers.get('Content-Length', 0)) + self.content_length = int(response.internal_response.headers.get('Content-Length', 0)) self._decompressor = None def __len__(self): return self.content_length async def __anext__(self): - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: chunk = await internal_response.content.read(self.block_size) if not chunk: @@ -265,7 +359,7 @@ class AioHttpTransportResponse(AsyncHttpResponse): :param bool decompress: If True which is default, will attempt to decode the body based on the *content-encoding* header. """ - def __init__(self, request: HttpRequest, + def __init__(self, request: PipelineTransportHttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None, *, decompress=True) -> None: super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size) @@ -360,3 +454,70 @@ def __getstate__(self): state['internal_response'] = None # aiohttp response are not pickable (see headers comments) state['headers'] = CIMultiDict(self.headers) # MultiDictProxy is not pickable return state + + +##################### REST ##################### + +class _RestAioHttpTransportResponseBackcompatMixin(): + + async def _load_body(self) -> None: + """Load in memory the body, so it could be accessible from sync methods.""" + self._content = await self.read() # type: ignore + +class RestAioHttpTransportResponse(RestAsyncHttpResponseImpl, _RestAioHttpTransportResponseBackcompatMixin): + def __init__( + self, + *, + internal_response, + decompress: bool = False, + **kwargs + ): + super().__init__( + internal_response=internal_response, + status_code=internal_response.status, + headers=CIMultiDict(internal_response.headers), + reason=internal_response.reason, + content_type=internal_response.headers.get('content-type'), + stream_download_generator=AioHttpStreamDownloadGenerator, + **kwargs + ) + self._decompress = decompress + self._content = internal_response._body + + def __getattr__(self, attr): + backcompat_attrs = ["load_body"] + if attr in backcompat_attrs: + attr = "_" + attr + return super().__getattr__(attr) + + def __getstate__(self): + + state = self.__dict__.copy() + # Remove the unpicklable entries. + state['_internal_response'] = None # aiohttp response are not pickable (see headers comments) + state['headers'] = CIMultiDict(self.headers) # MultiDictProxy is not pickable + return state + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + self._is_closed = True + self._internal_response.close() + await asyncio.sleep(0) + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + iterator = self.iter_bytes() if self._decompress else self.iter_raw() + if self._content is None: + parts = [] + async for part in iterator: + parts.append(part) + self._content = b"".join(parts) + return self._content diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 761c61caa9ae..bbece30073aa 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -26,11 +26,8 @@ from __future__ import absolute_import import abc from email.message import Message - -try: - from email import message_from_bytes as message_parser -except ImportError: # 2.7 - from email import message_from_string as message_parser # type: ignore +import cgi +import codecs from io import BytesIO import json import logging @@ -54,6 +51,7 @@ IO, List, Union, + Callable, Any, Mapping, Dict, @@ -63,9 +61,9 @@ Type ) -from six.moves.http_client import HTTPConnection, HTTPResponse as _HTTPResponse +from six.moves.http_client import HTTPResponse as _HTTPResponse -from azure.core.exceptions import HttpResponseError +from azure.core.exceptions import HttpResponseError, ResponseNotReadError, StreamClosedError, StreamConsumedError from azure.core.pipeline import ( ABC, AbstractContextManager, @@ -73,20 +71,42 @@ PipelineResponse, PipelineContext, ) -from .._tools import await_result as _await_result from ...utils._utils import _case_insensitive_dict - +from ...utils._pipeline_transport_rest_shared import ( + _format_parameters_helper, + _pad_attr_name, + _prepare_multipart_body_helper, + _serialize_request, +) +try: + from ...rest._rest_py3 import ( + _HttpResponseBase as _RestHttpResponseBase, + HttpResponse as RestHttpResponse, + HttpRequest as RestHttpRequest + ) +except (SyntaxError, ImportError): + from ...rest._rest import ( # type: ignore + _HttpResponseBase as _RestHttpResponseBase, + HttpResponse as RestHttpResponse, + HttpRequest as RestHttpRequest + ) if TYPE_CHECKING: from ..policies import SansIOHTTPPolicy from collections.abc import MutableMapping +try: + from email import message_from_bytes as message_parser +except ImportError: # 2.7 + from email import message_from_string as message_parser # type: ignore + HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") PipelineType = TypeVar("PipelineType") _LOGGER = logging.getLogger(__name__) +from ...pipeline._tools import await_result as _await_result def _format_url_section(template, **kwargs): """String format the template with the kwargs, auto-skip sections of the template that are NOT in the kwargs. @@ -127,35 +147,82 @@ def _urljoin(base_url, stub_url): parsed = parsed._replace(path=parsed.path.rstrip("/") + "/" + stub_url) return parsed.geturl() +def _decode_parts_helper( + http_response, message, http_response_type, requests, deserialize_response_callable +): + responses = [] + for index, raw_reponse in enumerate(message.get_payload()): + content_type = raw_reponse.get_content_type() + if content_type == "application/http": + responses.append( + deserialize_response_callable( + raw_reponse.get_payload(decode=True), + requests[index], + http_response_type=http_response_type, + ) + ) + elif content_type == "multipart/mixed" and requests[index].multipart_mixed_info: + # The message batch contains one or more change sets + changeset_requests = requests[index].multipart_mixed_info[0] # type: ignore + changeset_responses = http_response._decode_parts( # pylint: disable=protected-access + raw_reponse, + http_response_type, + changeset_requests + ) + responses.extend(changeset_responses) + else: + raise ValueError( + "Multipart doesn't support part other than application/http for now" + ) + return responses + +def _get_raw_parts_helper(http_response, http_response_type, default_http_response_type): + if http_response_type is None: + http_response_type = default_http_response_type + + body_as_bytes = http_response.body() + # In order to use email.message parser, I need full HTTP bytes. Faking something to make the parser happy + http_body = ( + b"Content-Type: " + + http_response.content_type.encode("ascii") + + b"\r\n\r\n" + + body_as_bytes + ) + message = message_parser(http_body) # type: Message + requests = http_response.request.multipart_mixed_info[0] # type: List[HttpRequest] + return http_response._decode_parts(message, http_response_type, requests) # pylint: disable=protected-access + +def _parts_helper(http_response): + if not http_response.content_type or not http_response.content_type.startswith("multipart/mixed"): + raise ValueError( + "You can't get parts if the response is not multipart/mixed" + ) -class _HTTPSerializer(HTTPConnection, object): - """Hacking the stdlib HTTPConnection to serialize HTTP request as strings. - """ - - def __init__(self, *args, **kwargs): - self.buffer = b"" - kwargs.setdefault("host", "fakehost") - super(_HTTPSerializer, self).__init__(*args, **kwargs) + responses = http_response._get_raw_parts() # pylint: disable=protected-access + if http_response.request.multipart_mixed_info: + policies = http_response.request.multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] - def putheader(self, header, *values): - if header in ["Host", "Accept-Encoding"]: - return - super(_HTTPSerializer, self).putheader(header, *values) + # Apply on_response concurrently to all requests + import concurrent.futures - def send(self, data): - self.buffer += data + def parse_responses(response): + http_request = response.request + context = PipelineContext(None) + pipeline_request = PipelineRequest(http_request, context) + pipeline_response = PipelineResponse( + http_request, response, context=context + ) + for policy in policies: + _await_result(policy.on_response, pipeline_request, pipeline_response) -def _serialize_request(http_request): - serializer = _HTTPSerializer() - serializer.request( - method=http_request.method, - url=http_request.url, - body=http_request.body, - headers=http_request.headers, - ) - return serializer.buffer + with concurrent.futures.ThreadPoolExecutor() as executor: + # List comprehension to raise exceptions if happened + [ # pylint: disable=expression-not-assigned, unnecessary-comprehension + _ for _ in executor.map(parse_responses, responses) + ] + return responses class HttpTransport( AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType] @@ -272,26 +339,7 @@ def format_parameters(self, params): :param dict params: A dictionary of parameters. """ - query = urlparse(self.url).query - if query: - self.url = self.url.partition("?")[0] - existing_params = { - p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] - } - params.update(existing_params) - query_params = [] - for k, v in params.items(): - if isinstance(v, list): - for w in v: - if w is None: - raise ValueError("Query parameter {} cannot be None".format(k)) - query_params.append("{}={}".format(k, w)) - else: - if v is None: - raise ValueError("Query parameter {} cannot be None".format(k)) - query_params.append("{}={}".format(k, v)) - query = "?" + "&".join(query_params) - self.url = self.url + query + return _format_parameters_helper(self, params) def set_streamed_data_body(self, data): """Set a streamable data body. @@ -416,54 +464,7 @@ def prepare_multipart_body(self, content_index=0): :returns: The updated index after all parts in this request have been added. :rtype: int """ - if not self.multipart_mixed_info: - return 0 - - requests = self.multipart_mixed_info[0] # type: List[HttpRequest] - boundary = self.multipart_mixed_info[2] # type: Optional[str] - - # Update the main request with the body - main_message = Message() - main_message.add_header("Content-Type", "multipart/mixed") - if boundary: - main_message.set_boundary(boundary) - - for req in requests: - part_message = Message() - if req.multipart_mixed_info: - content_index = req.prepare_multipart_body(content_index=content_index) - part_message.add_header("Content-Type", req.headers['Content-Type']) - payload = req.serialize() - # We need to remove the ~HTTP/1.1 prefix along with the added content-length - payload = payload[payload.index(b'--'):] - else: - part_message.add_header("Content-Type", "application/http") - part_message.add_header("Content-Transfer-Encoding", "binary") - part_message.add_header("Content-ID", str(content_index)) - payload = req.serialize() - content_index += 1 - part_message.set_payload(payload) - main_message.attach(part_message) - - try: - from email.policy import HTTP - - full_message = main_message.as_bytes(policy=HTTP) - eol = b"\r\n" - except ImportError: # Python 2.7 - # Right now we decide to not support Python 2.7 on serialization, since - # it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it) - raise NotImplementedError( - "Multipart request are not supported on Python 2.7" - ) - # full_message = main_message.as_string() - # eol = b'\n' - _, _, body = full_message.split(eol, 2) - self.set_bytes_body(body) - self.headers["Content-Type"] = ( - "multipart/mixed; boundary=" + main_message.get_boundary() - ) - return content_index + return _prepare_multipart_body_helper(self, content_index) def serialize(self): # type: () -> bytes @@ -516,27 +517,9 @@ def text(self, encoding=None): def _decode_parts(self, message, http_response_type, requests): # type: (Message, Type[_HttpResponseBase], List[HttpRequest]) -> List[HttpResponse] """Rebuild an HTTP response from pure string.""" - responses = [] - for index, raw_reponse in enumerate(message.get_payload()): - content_type = raw_reponse.get_content_type() - if content_type == "application/http": - responses.append( - _deserialize_response( - raw_reponse.get_payload(decode=True), - requests[index], - http_response_type=http_response_type, - ) - ) - elif content_type == "multipart/mixed" and requests[index].multipart_mixed_info: - # The message batch contains one or more change sets - changeset_requests = requests[index].multipart_mixed_info[0] # type: ignore - changeset_responses = self._decode_parts(raw_reponse, http_response_type, changeset_requests) - responses.extend(changeset_responses) - else: - raise ValueError( - "Multipart doesn't support part other than application/http for now" - ) - return responses + return _decode_parts_helper( + self, message, http_response_type, requests, _deserialize_response + ) def _get_raw_parts(self, http_response_type=None): # type (Optional[Type[_HttpResponseBase]]) -> Iterator[HttpResponse] @@ -545,20 +528,9 @@ def _get_raw_parts(self, http_response_type=None): If parts are application/http use http_response_type or HttpClientTransportResponse as enveloppe. """ - if http_response_type is None: - http_response_type = HttpClientTransportResponse - - body_as_bytes = self.body() - # In order to use email.message parser, I need full HTTP bytes. Faking something to make the parser happy - http_body = ( - b"Content-Type: " - + self.content_type.encode("ascii") - + b"\r\n\r\n" - + body_as_bytes + return _get_raw_parts_helper( + self, http_response_type, HttpClientTransportResponse ) - message = message_parser(http_body) # type: Message - requests = self.request.multipart_mixed_info[0] # type: List[HttpRequest] - return self._decode_parts(message, http_response_type, requests) def raise_for_status(self): # type () -> None @@ -591,42 +563,7 @@ def stream_download(self, pipeline, **kwargs): def parts(self): # type: () -> Iterator[HttpResponse] - """Assuming the content-type is multipart/mixed, will return the parts as an iterator. - - :rtype: iterator[HttpResponse] - :raises ValueError: If the content is not multipart/mixed - """ - if not self.content_type or not self.content_type.startswith("multipart/mixed"): - raise ValueError( - "You can't get parts if the response is not multipart/mixed" - ) - - responses = self._get_raw_parts() - if self.request.multipart_mixed_info: - policies = self.request.multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] - - # Apply on_response concurrently to all requests - import concurrent.futures - - def parse_responses(response): - http_request = response.request - context = PipelineContext(None) - pipeline_request = PipelineRequest(http_request, context) - pipeline_response = PipelineResponse( - http_request, response, context=context - ) - - for policy in policies: - _await_result(policy.on_response, pipeline_request, pipeline_response) - - with concurrent.futures.ThreadPoolExecutor() as executor: - # List comprehension to raise exceptions if happened - [ # pylint: disable=expression-not-assigned, unnecessary-comprehension - _ for _ in executor.map(parse_responses, responses) - ] - - return responses - + return _parts_helper(self) class _HttpClientTransportResponse(_HttpResponseBase): """Create a HTTPResponse from an http.client response. @@ -657,7 +594,6 @@ class HttpClientTransportResponse(_HttpClientTransportResponse, HttpResponse): Body will NOT be read by the constructor. Call "body()" to load the body in memory if necessary. """ - class BytesIOSocket(object): """Mocking the "makefile" of socket for HTTPResponse. @@ -673,7 +609,7 @@ def makefile(self, *_): def _deserialize_response( - http_response_as_bytes, http_request, http_response_type=HttpClientTransportResponse + http_response_as_bytes, http_request, http_response_type=HttpClientTransportResponse, ): local_socket = BytesIOSocket(http_response_as_bytes) response = _HTTPResponse(local_socket, method=http_request.method) @@ -948,3 +884,420 @@ def options(self, url, params=None, headers=None, **kwargs): "OPTIONS", url, params, headers, content, form_content, None ) return request + + +####################### REST ####################### + +def _lookup_encoding(encoding): + # type: (str) -> bool + # including check for whether encoding is known taken from httpx + try: + codecs.lookup(encoding) + return True + except LookupError: + return False + +class _RestHttpResponseBackcompatMixinBase(object): + + def __getattr__(self, attr): + backcompat_attrs = [ + "body", + "internal_response", + "block_size", + "stream_download", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + return self.__getattribute__(attr) + + def __setattr__(self, attr, value): + backcompat_attrs = [ + "block_size", + "internal_response", + "request", + "status_code", + "headers", + "reason", + "content_type", + "stream_download", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + super(_RestHttpResponseBackcompatMixinBase, self).__setattr__(attr, value) + + def _body(self): + """DEPRECATED: Get the response body. + + This is deprecated and will be removed in a later release. + You should get it through the `content` property instead + """ + self.read() + return self.content # pylint: disable=no-member + + def _decode_parts(self, message, http_response_type, requests): + def _deserialize_response( + http_response_as_bytes, http_request, http_response_type + ): + local_socket = BytesIOSocket(http_response_as_bytes) + response = _HTTPResponse(local_socket, method=http_request.method) + response.begin() + return http_response_type(request=http_request, internal_response=response) + return _decode_parts_helper( + self, message, http_response_type, requests, _deserialize_response + ) + + def _get_raw_parts(self, http_response_type=None): + return _get_raw_parts_helper( + self, http_response_type, RestHttpClientTransportResponse + ) + + def _stream_download(self, pipeline, **kwargs): + """DEPRECATED: Generator for streaming request body data. + + This is deprecated and will be removed in a later release. + You should use `iter_bytes` or `iter_raw` instead. + + :rtype: iterator[bytes] + """ + return self._stream_download_generator(pipeline, self, **kwargs) + +class _RestHttpResponseBackcompatMixin(_RestHttpResponseBackcompatMixinBase): + + def __getattr__(self, attr): + backcompat_attrs = ["parts"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super(_RestHttpResponseBackcompatMixin, self).__getattr__(attr) + + def parts(self): + """DEPRECATED: Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + + This is deprecated and will be removed in a later release. + + :rtype: Iterator + :raises ValueError: If the content is not multipart/mixed + """ + return _parts_helper(self) + +class _RestHttpResponseBaseImpl( + _RestHttpResponseBase, _RestHttpResponseBackcompatMixin +): + + def __init__(self, **kwargs): + # type: (Any) -> None + super(_RestHttpResponseBaseImpl, self).__init__() + self._request = kwargs.pop("request") # type: RestHttpRequest + self._internal_response = kwargs.pop("internal_response") + self._is_closed = False + self._is_stream_consumed = False + self._block_size = kwargs.get("block_size", None) or 4096 # type: int + self._status_code = kwargs.pop("status_code") # type: int + self._reason = kwargs.pop("reason") # type: str + self._content_type = kwargs.pop("content_type") # type: str + self._headers = kwargs.pop("headers") # type: Optional[HeadersType] + self._stream_download_generator = kwargs.pop("stream_download_generator") # type: Callable + self._json = None # this is filled in ContentDecodePolicy, when we deserialize + self._content = None # type: Optional[bytes] + self._text = None # type: Optional[str] + + + @property + def status_code(self): + # type: (...) -> int + """The status code of this response""" + return self._status_code + + @property + def headers(self): + # type: (...) -> Optional[HeadersType] + return self._headers + + @property + def reason(self): + # type: (...) -> str + """The response headers""" + return self._reason + + @property + def content_type(self): + # type: (...) -> Optional[str] + """The content type of the response""" + return self._content_type + + @property + def request(self): + # type: (...) -> RestHttpRequest + return self._request + + @property + def url(self): + # type: (...) -> str + """Returns the URL that resulted in this response""" + return self.request.url + + @property + def is_closed(self): + # type: (...) -> bool + """Whether the network connection has been closed yet""" + return self._is_closed + + @property + def is_stream_consumed(self): + # type: (...) -> bool + """Whether the stream has been fully consumed""" + return self._is_stream_consumed + + def _get_charset_encoding(self): + # type: (...) -> Optional[str] + content_type = self.headers.get("Content-Type") + + if not content_type: + return None + _, params = cgi.parse_header(content_type) + encoding = params.get('charset') # -> utf-8 + if encoding is None or not _lookup_encoding(encoding): + return None + return encoding + + @property + def encoding(self): + # type: (...) -> Optional[str] + """Returns the response encoding. + + :return: The response encoding. We either return the encoding set by the user, + or try extracting the encoding from the response's content type. If all fails, + we return `None`. + :rtype: optional[str] + """ + try: + return self._encoding + except AttributeError: + self._encoding = self._get_charset_encoding() # type: Optional[str] + return self._encoding + + @encoding.setter + def encoding(self, value): + # type: (str) -> None + """Sets the response encoding""" + self._encoding = value + self._text = None # clear text cache + + def _decode_to_text(self, encoding): + # type: (Optional[str]) -> str + if not self.content: + return "" + if encoding == "utf-8": + encoding = "utf-8-sig" + if encoding: + return self.content.decode(encoding) + return codecs.getincrementaldecoder("utf-8-sig")(errors="replace").decode(self.content) + + def text(self, encoding=None): + # type: (Optional[str]) -> str + """Returns the response body as a string + + :param optional[str] encoding: The encoding you want to decode the text with. Can + also be set independently through our encoding property + :return: The response's content decoded as a string. + """ + if self._text is None or encoding: + encoding_to_pass = encoding or self.encoding + self._text = self._decode_to_text(encoding_to_pass) + return self._text + + def json(self): + # type: (...) -> Any + """Returns the whole body as a json object. + + :return: The JSON deserialized response body + :rtype: any + :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: + """ + # this will trigger errors if response is not read in + self.content # pylint: disable=pointless-statement + if not self._json: + self._json = json.loads(self.text()) + return self._json + + def raise_for_status(self): + # type: (...) -> None + """Raises an HttpResponseError if the response has an error status code. + + If response is good, does nothing. + """ + if cast(int, self.status_code) >= 400: + raise HttpResponseError(response=self) + + @property + def content(self): + # type: (...) -> bytes + """Return the response's content in bytes.""" + if self._content is None: + raise ResponseNotReadError(self) + return self._content + + def _has_content(self): + try: + self.content # pylint: disable=pointless-statement + return True + except ResponseNotReadError: + return False + + def __repr__(self): + # type: (...) -> str + content_type_str = ( + ", Content-Type: {}".format(self.content_type) if self.content_type else "" + ) + return "".format( + self.status_code, self.reason, content_type_str + ) + + @staticmethod + def _parse_lines_from_text(text): + # largely taken from httpx's LineDecoder code + lines = [] + last_chunk_of_text = "" + while text: + text_length = len(text) + for idx in range(text_length): + curr_char = text[idx] + next_char = None if idx == len(text) - 1 else text[idx + 1] + if curr_char == "\n": + lines.append(text[: idx + 1]) + text = text[idx + 1: ] + break + if curr_char == "\r" and next_char == "\n": + # if it ends with \r\n, we only do \n + lines.append(text[:idx] + "\n") + text = text[idx + 2:] + break + if curr_char == "\r" and next_char is not None: + # if it's \r then a normal character, we switch \r to \n + lines.append(text[:idx] + "\n") + text = text[idx + 1:] + break + if next_char is None: + last_chunk_of_text += text + text = "" + break + if last_chunk_of_text.endswith("\r"): + # if ends with \r, we switch \r to \n + lines.append(last_chunk_of_text[:-1] + "\n") + elif last_chunk_of_text: + lines.append(last_chunk_of_text) + return lines + + def _stream_download_check(self): + if self.is_stream_consumed: + raise StreamConsumedError(self) + if self.is_closed: + raise StreamClosedError(self) + + self._is_stream_consumed = True + +class RestHttpResponseImpl( + _RestHttpResponseBaseImpl, RestHttpResponse, _RestHttpResponseBackcompatMixinBase +): + """HttpResponseImpl built on top of our HttpResponse protocol class. + + Helper impl for creating our transport responses + """ + + def __enter__(self): + # type: (...) -> RestHttpResponseImpl + return self + + def close(self): + # type: (...) -> None + if not self.is_closed: + self._is_closed = True + self._internal_response.close() + + def __exit__(self, *args): + # type: (...) -> None + self.close() + + def read(self): + # type: (...) -> bytes + """ + Read the response's bytes. + + """ + if self._content is None: + self._content = b"".join(self.iter_bytes()) + return self.content + + def iter_bytes(self): + # type: () -> Iterator[bytes] + """Iterates over the response's bytes. Will decompress in the process + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + if self._has_content(): + chunk_size = cast(int, self._block_size) + for i in range(0, len(self.content), chunk_size): + yield self.content[i : i + chunk_size] + else: + self._stream_download_check() + for part in self._stream_download_generator( + response=self, + pipeline=None, + decompress=True, + ): + yield part + self.close() + + def iter_raw(self): + # type: () -> Iterator[bytes] + """Iterates over the response's bytes. Will not decompress in the process + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + self._stream_download_check() + for part in self._stream_download_generator( + response=self, pipeline=None, decompress=False + ): + yield part + self.close() + + def iter_text(self): + # type: () -> Iterator[str] + """Iterate over the response text + """ + for byte in self.iter_bytes(): + text = byte.decode(self.encoding or "utf-8") + yield text + + def iter_lines(self): + # type: () -> Iterator[str] + for text in self.iter_text(): + lines = self._parse_lines_from_text(text) + for line in lines: + yield line + +class _RestHttpClientTransportResponseBase(_RestHttpResponseBaseImpl): + + def __init__(self, **kwargs): + internal_response = kwargs.pop("internal_response") + headers = _case_insensitive_dict(internal_response.getheaders()) + super(_RestHttpClientTransportResponseBase, self).__init__( + internal_response=internal_response, + status_code=internal_response.status, + reason=internal_response.reason, + headers=headers, + content_type=headers.get("Content-Type"), + stream_download_generator=None, + **kwargs + ) + +class RestHttpClientTransportResponse(_RestHttpClientTransportResponseBase, RestHttpResponseImpl): + """Create a Rest HTTPResponse from an http.client response. + """ + + def iter_bytes(self): + raise TypeError("We do not support iter_bytes for this transport response") + + def iter_raw(self): + raise TypeError("We do not support iter_raw for this transport response") + + def read(self): + if not self._has_content(): + self._content = self._internal_response.read() + return self._content diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index 73fcd51bf957..4a1d70a3b217 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -28,16 +28,20 @@ import abc from collections.abc import AsyncIterator -from typing import AsyncIterator as AsyncIteratorType, TypeVar, Generic, Any +from typing import AsyncIterator as AsyncIteratorType, TypeVar, Generic, Any, Callable, Optional from ._base import ( _HttpResponseBase, _HttpClientTransportResponse, + _RestHttpClientTransportResponseBase, PipelineContext, PipelineRequest, PipelineResponse, + _RestHttpResponseBaseImpl, + _RestHttpResponseBackcompatMixinBase, ) from .._tools_async import await_result as _await_result - +from ...utils._pipeline_transport_rest_shared import _pad_attr_name +from ...rest import AsyncHttpResponse as RestAsyncHttpResponse try: from contextlib import AbstractAsyncContextManager # type: ignore except ImportError: # Python <= 3.7 @@ -159,7 +163,6 @@ class AsyncHttpClientTransportResponse(_HttpClientTransportResponse, AsyncHttpRe :param httpclient_response: The object returned from an HTTP(S)Connection from http.client """ - class AsyncHttpTransport( AbstractAsyncContextManager, abc.ABC, @@ -183,3 +186,137 @@ async def close(self): async def sleep(self, duration): await asyncio.sleep(duration) + +####################### REST ####################### + +class _RestAsyncHttpResponseBackcompatMixin(_RestHttpResponseBackcompatMixinBase): + def __getattr__(self, attr): + backcompat_attrs = ["parts"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super().__getattr__(attr) + + def parts(self): + """DEPRECATED: Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + + This is deprecated and will be removed in a later release. + + :rtype: AsyncIterator + :raises ValueError: If the content is not multipart/mixed + """ + if not self.content_type or not self.content_type.startswith("multipart/mixed"): + raise ValueError( + "You can't get parts if the response is not multipart/mixed" + ) + + return _PartGenerator(self) + +class RestAsyncHttpResponseImpl( + RestAsyncHttpResponse, _RestHttpResponseBaseImpl, _RestAsyncHttpResponseBackcompatMixin +): + """AsyncHttpResponseImpl built on top of our HttpResponse protocol class. + + Helper impl for creating our transport responses + """ + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + if self._content is None: + parts = [] + async for part in self.iter_bytes(): + parts.append(part) + self._content = b"".join(parts) + self._is_stream_consumed = True + return self._content + + async def iter_text(self) -> AsyncIteratorType[str]: + """Asynchronously iterates over the text in the response. + + :return: An async iterator of string. Each string chunk will be a text from the response + :rtype: AsyncIterator[str] + """ + async for byte in self.iter_bytes(): # type: ignore + text = byte.decode(self.encoding or "utf-8") + yield text + + async def iter_lines(self) -> AsyncIteratorType[str]: + """Asynchronously iterates over the lines in the response. + + :return: An async iterator of string. Each string chunk will be a line from the response + :rtype: AsyncIterator[str] + """ + async for text in self.iter_text(): + lines = self._parse_lines_from_text(text) + for line in lines: + yield line + + async def iter_raw(self) -> AsyncIteratorType[bytes]: + """Asynchronously iterates over the response's bytes. Will not decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + self._stream_download_check() + async for part in self._stream_download_generator( + response=self, pipeline=None, decompress=False + ): + yield part + await self.close() + + async def iter_bytes(self) -> AsyncIteratorType[bytes]: + """Asynchronously iterates over the response's bytes. Will decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + if self._has_content(): + for i in range(0, len(self.content), self._block_size): + yield self.content[i : i + self._block_size] + else: + self._stream_download_check() + async for part in self._stream_download_generator( + response=self, + pipeline=None, + decompress=True + ): + yield part + await self.close() + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + if not self.is_closed: + self._is_closed = True + await self._internal_response.close() + + async def __aexit__(self, *args) -> None: + await self.close() + + def __repr__(self) -> str: + content_type_str = ( + ", Content-Type: {}".format(self.content_type) if self.content_type else "" + ) + return "".format( + self.status_code, self.reason, content_type_str + ) + +class RestAsyncHttpClientTransportResponse(_RestHttpClientTransportResponseBase, RestAsyncHttpResponseImpl): + """Create a Rest HTTPResponse from an http.client response. + """ + + async def iter_bytes(self): + raise TypeError("We do not support iter_bytes for this transport response") + + async def iter_raw(self): + raise TypeError("We do not support iter_raw for this transport response") + + async def read(self): + if self._content is None: + self._content = self._internal_response.read() + return self._content diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index e41e4de91325..f022edec60d0 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -37,15 +37,21 @@ ServiceResponseError ) from azure.core.pipeline import Pipeline -from ._base import HttpRequest from ._base_async import ( AsyncHttpResponse, _ResponseStopIteration, - _iterate_response_content) -from ._requests_basic import RequestsTransportResponse, _read_raw_stream + _iterate_response_content, + RestAsyncHttpResponseImpl, +) +from ._requests_basic import ( + _RestRequestsTransportResponseBase, + RequestsTransportResponse, + _read_raw_stream, +) from ._base_requests_async import RequestsAsyncTransportBase -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response +from .._tools_async import read_in_response as _read_in_response +from ...rest import HttpRequest, AsyncHttpResponse as RestAsyncHttpResponse _LOGGER = logging.getLogger(__name__) @@ -83,13 +89,13 @@ async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ async def sleep(self, duration): # pylint:disable=invalid-overridden-method await asyncio.sleep(duration) - async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: # type: ignore # pylint:disable=invalid-overridden-method + async def send(self, request: HttpRequest, **kwargs: Any) -> RestAsyncHttpResponse: # type: ignore # pylint:disable=invalid-overridden-method """Send the request using this HTTP sender. :param request: The HttpRequest - :type request: ~azure.core.pipeline.transport.HttpRequest + :type request: ~azure.core.rest.HttpRequest :return: The AsyncHttpResponse - :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse :keyword requests.Session session: will override the driver session and use yours. Should NOT be done unless really required. Anything else is sent straight to requests. @@ -131,8 +137,13 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: if error: raise error - return AsyncioRequestsTransportResponse(request, response, self.connection_config.data_block_size) - + retval = RestAsyncioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size + ) + await _read_in_response(retval, kwargs.get("stream")) + return retval class AsyncioStreamDownloadGenerator(AsyncIterator): """Streams the response body data. @@ -146,11 +157,11 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) - internal_response = _get_internal_response(response) + internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: @@ -162,7 +173,7 @@ def __len__(self): async def __anext__(self): loop = _get_running_loop() - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: chunk = await loop.run_in_executor( None, @@ -182,10 +193,47 @@ async def __anext__(self): internal_response.close() raise - class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore """Asynchronous streaming of data from the response. """ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore """Generator for streaming request body data.""" return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs) # type: ignore + +##################### REST ##################### + +class RestAsyncioRequestsTransportResponse( + _RestRequestsTransportResponseBase, + RestAsyncHttpResponseImpl, +): # type: ignore + """Asynchronous streaming of data from the response. + """ + + def __init__(self, **kwargs): + super().__init__( + stream_download_generator=AsyncioStreamDownloadGenerator, + **kwargs + ) + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + self._is_closed = True + self._internal_response.close() + await asyncio.sleep(0) + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + if not self._has_content(): + parts = [] + async for part in self.iter_bytes(): # type: ignore + parts.append(part) + self._internal_response._content = b"".join(parts) # pylint: disable=protected-access + return self.content diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index 28b81d705c16..a957469fbf32 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -24,34 +24,68 @@ # # -------------------------------------------------------------------------- from __future__ import absolute_import +try: + import collections.abc as collections +except ImportError: + import collections # type: ignore import logging -from typing import Iterator, Optional, Any, Union, TypeVar +from typing import Iterator, Optional, Any, Union, TypeVar, cast import urllib3 # type: ignore from urllib3.util.retry import Retry # type: ignore from urllib3.exceptions import ( DecodeError, ReadTimeoutError, ProtocolError ) import requests +from requests.structures import CaseInsensitiveDict from azure.core.configuration import ConnectionConfiguration from azure.core.exceptions import ( ServiceRequestError, - ServiceResponseError + ServiceResponseError, + ResponseNotReadError, + StreamConsumedError, + StreamClosedError, ) -from . import HttpRequest # pylint: disable=unused-import from ._base import ( HttpTransport, HttpResponse, - _HttpResponseBase + _HttpResponseBase, + RestHttpResponseImpl, + _RestHttpResponseBaseImpl, ) from ._bigger_block_size_http_adapters import BiggerBlockSizeHTTPAdapter -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response +from .._tools import ( + read_in_response as _read_in_response, +) +from ...rest import HttpRequest as RestHttpRequest, HttpResponse as RestHttpResponse PipelineType = TypeVar("PipelineType") _LOGGER = logging.getLogger(__name__) +class _ItemsView(collections.ItemsView): + def __contains__(self, item): + if not (isinstance(item, (list, tuple)) and len(item) == 2): + return False # requests raises here, we just return False + for k, v in self.__iter__(): + if item[0].lower() == k.lower() and item[1] == v: + return True + return False + + def __repr__(self): + return 'ItemsView({})'.format(dict(self.__iter__())) + +class _CaseInsensitiveDict(CaseInsensitiveDict): + """Overriding default requests dict so we can unify + to not raise if users pass in incorrect items to contains. + Instead, we return False + """ + + def items(self): + """Return a new view of the dictionary's items.""" + return _ItemsView(self) + def _read_raw_stream(response, chunk_size=1): # Special case for urllib3. if hasattr(response.raw, 'stream'): @@ -132,11 +166,11 @@ def __init__(self, pipeline, response, **kwargs): self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) - internal_response = _get_internal_response(response) + internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: @@ -150,7 +184,7 @@ def __iter__(self): return self def __next__(self): - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: chunk = next(self.iter_content_func) if not chunk: @@ -243,13 +277,13 @@ def close(self): self.session = None def send(self, request, **kwargs): # type: ignore - # type: (HttpRequest, Any) -> HttpResponse + # type: (RestHttpRequest, Any) -> RestHttpResponse """Send request object according to configuration. :param request: The request object to be sent. - :type request: ~azure.core.pipeline.transport.HttpRequest + :type request: ~azure.core.rest.HttpRequest :return: An HTTPResponse object. - :rtype: ~azure.core.pipeline.transport.HttpResponse + :rtype: ~azure.core.rest.HttpResponse :keyword requests.Session session: will override the driver session and use yours. Should NOT be done unless really required. Anything else is sent straight to requests. @@ -296,4 +330,58 @@ def send(self, request, **kwargs): # type: ignore if error: raise error - return RequestsTransportResponse(request, response, self.connection_config.data_block_size) + retval = RestRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size + ) + _read_in_response(retval, kwargs.get("stream")) + return retval + +##################### REST ##################### + +class _RestRequestsTransportResponseBase(_RestHttpResponseBaseImpl): + def __init__(self, **kwargs): + internal_response = kwargs.pop("internal_response") + super(_RestRequestsTransportResponseBase, self).__init__( + internal_response=internal_response, + status_code=internal_response.status_code, + headers=internal_response.headers, + reason=internal_response.reason, + content_type=internal_response.headers.get('content-type'), + **kwargs + ) + + @property + def content(self): + # type: () -> bytes + if not self._internal_response._content_consumed: # pylint: disable=protected-access + # if we just call .content, requests will read in the content. + # we want to read it in our own way + raise ResponseNotReadError(self) + + try: + return self._internal_response.content + except RuntimeError: + # requests throws a RuntimeError if the content for a response is already consumed + raise ResponseNotReadError(self) + +class RestRequestsTransportResponse( + _RestRequestsTransportResponseBase, RestHttpResponseImpl +): + def __init__(self, **kwargs): + super(RestRequestsTransportResponse, self).__init__( + stream_download_generator=StreamDownloadGenerator, + **kwargs + ) + + def read(self): + # type: () -> bytes + """Read the response's bytes. + + :return: The read in bytes + :rtype: bytes + """ + if not self._has_content(): + self._internal_response._content = b"".join(self.iter_bytes()) # pylint: disable=protected-access + return self.content diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index e21ee5115327..7f8d45683bfa 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -37,15 +37,21 @@ ServiceResponseError ) from azure.core.pipeline import Pipeline -from ._base import HttpRequest from ._base_async import ( AsyncHttpResponse, _ResponseStopIteration, - _iterate_response_content) -from ._requests_basic import RequestsTransportResponse, _read_raw_stream + _iterate_response_content, + RestAsyncHttpResponseImpl, +) +from ._requests_basic import ( + RequestsTransportResponse, + _RestRequestsTransportResponseBase, + _read_raw_stream, +) from ._base_requests_async import RequestsAsyncTransportBase -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response +from .._tools_async import read_in_response as _read_in_response +from ...rest import HttpRequest, AsyncHttpResponse as RestAsyncHttpResponse _LOGGER = logging.getLogger(__name__) @@ -62,11 +68,11 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) - internal_response = _get_internal_response(response) + internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: @@ -77,7 +83,7 @@ def __len__(self): return self.content_length async def __anext__(self): - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: try: chunk = await trio.to_thread.run_sync( @@ -110,7 +116,6 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # ty """ return TrioStreamDownloadGenerator(pipeline, self, **kwargs) - class TrioRequestsTransport(RequestsAsyncTransportBase): # type: ignore """Identical implementation as the synchronous RequestsTransport wrapped in a class with asynchronous methods. Uses the third party trio event loop. @@ -133,13 +138,13 @@ async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ async def sleep(self, duration): # pylint:disable=invalid-overridden-method await trio.sleep(duration) - async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: # type: ignore # pylint:disable=invalid-overridden-method + async def send(self, request: HttpRequest, **kwargs: Any) -> RestAsyncHttpResponse: # type: ignore # pylint:disable=invalid-overridden-method """Send the request using this HTTP sender. :param request: The HttpRequest - :type request: ~azure.core.pipeline.transport.HttpRequest + :type request: ~azure.core.rest.HttpRequest :return: The AsyncHttpResponse - :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse :keyword requests.Session session: will override the driver session and use yours. Should NOT be done unless really required. Anything else is sent straight to requests. @@ -196,5 +201,42 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: if error: raise error + retval = RestTrioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size + ) + await _read_in_response(retval, kwargs.get("stream")) + return retval + +##################### REST ##################### + +class RestTrioRequestsTransportResponse( + _RestRequestsTransportResponseBase, + RestAsyncHttpResponseImpl, +): # type: ignore + """Asynchronous streaming of data from the response. + """ + def __init__(self, **kwargs): + super().__init__( + stream_download_generator=TrioStreamDownloadGenerator, + **kwargs + ) + + async def read(self) -> bytes: + """Read the response's bytes into memory. - return TrioRequestsTransportResponse(request, response, self.connection_config.data_block_size) + :return: The response's bytes + :rtype: bytes + """ + if not self._has_content(): + parts = [] + async for part in self.iter_bytes(): # type: ignore + parts.append(part) + self._internal_response._content = b"".join(parts) # pylint: disable=protected-access + return self.content + + async def close(self) -> None: + self._is_closed = True + self._internal_response.close() + await trio.sleep(0) diff --git a/sdk/core/azure-core/azure/core/polling/async_base_polling.py b/sdk/core/azure-core/azure/core/polling/async_base_polling.py index 1294a30704ef..1b75ca69c9b8 100644 --- a/sdk/core/azure-core/azure/core/polling/async_base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/async_base_polling.py @@ -32,6 +32,7 @@ LROBasePolling, _raise_if_bad_http_status_and_method, ) +from ..rest import HttpRequest as RestHttpRequest __all__ = ["AsyncLROBasePolling"] @@ -116,7 +117,7 @@ async def request_status(self, status_link): # pylint:disable=invalid-overridde """ if self._path_format_arguments: status_link = self._client.format_url(status_link, **self._path_format_arguments) - request = self._client.get(status_link) + request = RestHttpRequest("GET", url=status_link) # Re-inject 'x-ms-client-request-id' while polling if "request_id" not in self._operation_config: self._operation_config["request_id"] = self._get_request_id() diff --git a/sdk/core/azure-core/azure/core/polling/base_polling.py b/sdk/core/azure-core/azure/core/polling/base_polling.py index 5fa2ea3863fa..9255fc67764f 100644 --- a/sdk/core/azure-core/azure/core/polling/base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/base_polling.py @@ -31,13 +31,13 @@ from ..exceptions import HttpResponseError, DecodeError from . import PollingMethod from ..pipeline.policies._utils import get_retry_after +from ..rest import HttpRequest if TYPE_CHECKING: from azure.core.pipeline import PipelineResponse from azure.core.pipeline.transport import ( HttpResponse, AsyncHttpResponse, - HttpRequest, ) ResponseType = Union[HttpResponse, AsyncHttpResponse] @@ -574,7 +574,7 @@ def request_status(self, status_link): """ if self._path_format_arguments: status_link = self._client.format_url(status_link, **self._path_format_arguments) - request = self._client.get(status_link) + request = HttpRequest(method="GET", url=status_link) # Re-inject 'x-ms-client-request-id' while polling if "request_id" not in self._operation_config: self._operation_config["request_id"] = self._get_request_id() diff --git a/sdk/core/azure-core/azure/core/rest/_aiohttp.py b/sdk/core/azure-core/azure/core/rest/_aiohttp.py deleted file mode 100644 index f25d9f7679b0..000000000000 --- a/sdk/core/azure-core/azure/core/rest/_aiohttp.py +++ /dev/null @@ -1,87 +0,0 @@ -# -------------------------------------------------------------------------- -# -# Copyright (c) Microsoft Corporation. All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the ""Software""), to -# deal in the Software without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -# sell copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -# IN THE SOFTWARE. -# -# -------------------------------------------------------------------------- - -import asyncio -from typing import AsyncIterator -from multidict import CIMultiDict -from . import HttpRequest, AsyncHttpResponse -from ._helpers_py3 import iter_raw_helper, iter_bytes_helper -from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator - - -class RestAioHttpTransportResponse(AsyncHttpResponse): - def __init__( - self, - *, - request: HttpRequest, - internal_response, - ): - super().__init__(request=request, internal_response=internal_response) - self.status_code = internal_response.status - self.headers = CIMultiDict(internal_response.headers) # type: ignore - self.reason = internal_response.reason - self.content_type = internal_response.headers.get('content-type') - - async def iter_raw(self) -> AsyncIterator[bytes]: - """Asynchronously iterates over the response's bytes. Will not decompress in the process - - :return: An async iterator of bytes from the response - :rtype: AsyncIterator[bytes] - """ - async for part in iter_raw_helper(AioHttpStreamDownloadGenerator, self): - yield part - await self.close() - - async def iter_bytes(self) -> AsyncIterator[bytes]: - """Asynchronously iterates over the response's bytes. Will decompress in the process - - :return: An async iterator of bytes from the response - :rtype: AsyncIterator[bytes] - """ - async for part in iter_bytes_helper( - AioHttpStreamDownloadGenerator, - self, - content=self._content - ): - yield part - await self.close() - - def __getstate__(self): - state = self.__dict__.copy() - # Remove the unpicklable entries. - state['internal_response'] = None # aiohttp response are not pickable (see headers comments) - state['headers'] = CIMultiDict(self.headers) # MultiDictProxy is not pickable - return state - - async def close(self) -> None: - """Close the response. - - :return: None - :rtype: None - """ - self.is_closed = True - self._internal_response.close() - await asyncio.sleep(0) diff --git a/sdk/core/azure-core/azure/core/rest/_helpers.py b/sdk/core/azure-core/azure/core/rest/_helpers.py index 1a011689a238..e21db0ec1ad4 100644 --- a/sdk/core/azure-core/azure/core/rest/_helpers.py +++ b/sdk/core/azure-core/azure/core/rest/_helpers.py @@ -24,33 +24,39 @@ # # -------------------------------------------------------------------------- import os -import codecs -import cgi from enum import Enum from json import dumps -import collections +try: + import collections.abc as collections +except ImportError: + import collections # type: ignore from typing import ( Optional, Union, Mapping, Sequence, - List, Tuple, IO, Any, Dict, Iterable, - Iterator, cast, - Callable, ) import xml.etree.ElementTree as ET import six try: + binary_type = str from urlparse import urlparse # type: ignore except ImportError: + binary_type = bytes # type: ignore from urllib.parse import urlparse from azure.core.serialization import AzureJSONEncoder +from ..utils._pipeline_transport_rest_shared import ( + _format_parameters_helper, + _pad_attr_name, + _prepare_multipart_body_helper, + _serialize_request, +) ################################### TYPES SECTION ######################### @@ -82,10 +88,6 @@ class HttpVerbs(str, Enum): DELETE = "DELETE" MERGE = "MERGE" -########################### ERRORS SECTION ################################# - - - ########################### HELPER SECTION ################################# def _verify_data_object(name, value): @@ -218,86 +220,209 @@ def format_parameters(url, params): url += query return url -def lookup_encoding(encoding): - # type: (str) -> bool - # including check for whether encoding is known taken from httpx - try: - codecs.lookup(encoding) - return True - except LookupError: - return False - -def parse_lines_from_text(text): - # largely taken from httpx's LineDecoder code - lines = [] - last_chunk_of_text = "" - while text: - text_length = len(text) - for idx in range(text_length): - curr_char = text[idx] - next_char = None if idx == len(text) - 1 else text[idx + 1] - if curr_char == "\n": - lines.append(text[: idx + 1]) - text = text[idx + 1: ] - break - if curr_char == "\r" and next_char == "\n": - # if it ends with \r\n, we only do \n - lines.append(text[:idx] + "\n") - text = text[idx + 2:] - break - if curr_char == "\r" and next_char is not None: - # if it's \r then a normal character, we switch \r to \n - lines.append(text[:idx] + "\n") - text = text[idx + 1:] - break - if next_char is None: - last_chunk_of_text += text - text = "" - break - if last_chunk_of_text.endswith("\r"): - # if ends with \r, we switch \r to \n - lines.append(last_chunk_of_text[:-1] + "\n") - elif last_chunk_of_text: - lines.append(last_chunk_of_text) - return lines - -def to_pipeline_transport_request_helper(rest_request): - from ..pipeline.transport import HttpRequest as PipelineTransportHttpRequest - return PipelineTransportHttpRequest( - method=rest_request.method, - url=rest_request.url, - headers=rest_request.headers, - files=rest_request._files, # pylint: disable=protected-access - data=rest_request._data # pylint: disable=protected-access - ) -def from_pipeline_transport_request_helper(request_class, pipeline_transport_request): - return request_class( - method=pipeline_transport_request.method, - url=pipeline_transport_request.url, - headers=pipeline_transport_request.headers, - files=pipeline_transport_request.files, - data=pipeline_transport_request.data - ) +class HttpRequestBackcompatMixin(object): + + def __getattr__(self, attr): + backcompat_attrs = [ + "files", + "data", + "multipart_mixed_info", + "query", + "body", + "format_parameters", + "set_streamed_data_body", + "set_text_body", + "set_xml_body", + "set_json_body", + "set_formdata_body", + "set_bytes_body", + "set_multipart_mixed", + "prepare_multipart_body", + "serialize", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + return self.__getattribute__(attr) + + def __setattr__(self, attr, value): + backcompat_attrs = [ + "multipart_mixed_info", + "files", + "data", + "body", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + super(HttpRequestBackcompatMixin, self).__setattr__(attr, value) + + @property + def _multipart_mixed_info(self): + """DEPRECATED: Information used to make multipart mixed requests. + + This is deprecated and will be removed in a later release. + """ + try: + return self.__multipart_mixed_info + except AttributeError: + return None + + @_multipart_mixed_info.setter + def _multipart_mixed_info(self, val): + """DEPRECATED: Set information to make multipart mixed requests. + + This is deprecated and will be removed in a later release. + """ + self.__multipart_mixed_info = val + + @property + def _query(self): + """DEPRECATED: Query parameters passed in by user + + This is deprecated and will be removed in a later release. + """ + query = urlparse(self.url).query + if query: + return {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} + return {} + + @property + def _body(self): + """DEPRECATED: Body of the request. You should use the `content` property instead + + This is deprecated and will be removed in a later release. + """ + return self._data + + @_body.setter + def _body(self, val): + """DEPRECATED: Set the body of the request + + This is deprecated and will be removed in a later release. + """ + self._data = val + + @staticmethod + def _format_data(data): + from ..pipeline.transport._base import HttpRequest as PipelineTransportHttpRequest + return PipelineTransportHttpRequest._format_data(data) # pylint: disable=protected-access + + def _format_parameters(self, params): + """DEPRECATED: Format the query parameters + + This is deprecated and will be removed in a later release. + You should pass the query parameters through the kwarg `params` + instead. + """ + return _format_parameters_helper(self, params) + + def _set_streamed_data_body(self, data): + """DEPRECATED: Set the streamed request body. + + This is deprecated and will be removed in a later release. + You should pass your stream content through the `content` kwarg instead + """ + if not isinstance(data, binary_type) and not any( + hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"] + ): + raise TypeError( + "A streamable data source must be an open file-like object or iterable." + ) + self._data = data + self._files = None + + def _set_text_body(self, data): + """DEPRECATED: Set the text body + + This is deprecated and will be removed in a later release. + You should pass your text content through the `content` kwarg instead + """ + if data is None: + self._data = None + else: + self._data = data + self.headers["Content-Length"] = str(len(self._data)) + self._files = None + + def _set_xml_body(self, data): + """DEPRECATED: Set the xml body. + + This is deprecated and will be removed in a later release. + You should pass your xml content through the `content` kwarg instead + """ + if data is None: + self._data = None + else: + bytes_data = ET.tostring(data, encoding="utf8") + self._data = bytes_data.replace(b"encoding='utf8'", b"encoding='utf-8'") + self.headers["Content-Length"] = str(len(self._data)) + self._files = None + + def _set_json_body(self, data): + """DEPRECATED: Set the json request body. + + This is deprecated and will be removed in a later release. + You should pass your json content through the `json` kwarg instead + """ + if data is None: + self._data = None + else: + self._data = dumps(data) + self.headers["Content-Length"] = str(len(self._data)) + self._files = None + + def _set_formdata_body(self, data=None): + """DEPRECATED: Set the formrequest body. + + This is deprecated and will be removed in a later release. + You should pass your stream content through the `files` kwarg instead + """ + if data is None: + data = {} + content_type = self.headers.pop("Content-Type", None) if self.headers else None + + if content_type and content_type.lower() == "application/x-www-form-urlencoded": + self._data = {f: d for f, d in data.items() if d is not None} + self._files = None + else: # Assume "multipart/form-data" + self._files = { + f: self._format_data(d) for f, d in data.items() if d is not None + } + self._data = None + + def _set_bytes_body(self, data): + """DEPRECATED: Set the bytes request body. + + This is deprecated and will be removed in a later release. + You should pass your bytes content through the `content` kwarg instead + """ + if data: + self.headers["Content-Length"] = str(len(data)) + self._data = data + self._files = None + + def _set_multipart_mixed(self, *requests, **kwargs): + """DEPRECATED: Set the multipart mixed info. + + This is deprecated and will be removed in a later release. + """ + self.multipart_mixed_info = ( + requests, + kwargs.pop("policies", []), + kwargs.pop("boundary", None), + kwargs + ) + + def _prepare_multipart_body(self, content_index=0): + """DEPRECATED: Prepare your request body for multipart requests. + + This is deprecated and will be removed in a later release. + """ + return _prepare_multipart_body_helper(self, content_index) + + def _serialize(self): + """DEPRECATED: Serialize this request using application/http spec. + + This is deprecated and will be removed in a later release. -def get_charset_encoding(response): - # type: (...) -> Optional[str] - content_type = response.headers.get("Content-Type") - - if not content_type: - return None - _, params = cgi.parse_header(content_type) - encoding = params.get('charset') # -> utf-8 - if encoding is None or not lookup_encoding(encoding): - return None - return encoding - -def decode_to_text(encoding, content): - # type: (Optional[str], bytes) -> str - if not content: - return "" - if encoding == "utf-8": - encoding = "utf-8-sig" - if encoding: - return content.decode(encoding) - return codecs.getincrementaldecoder("utf-8-sig")(errors="replace").decode(content) + :rtype: bytes + """ + return _serialize_request(self) diff --git a/sdk/core/azure-core/azure/core/rest/_helpers_py3.py b/sdk/core/azure-core/azure/core/rest/_helpers_py3.py deleted file mode 100644 index 90948012db2a..000000000000 --- a/sdk/core/azure-core/azure/core/rest/_helpers_py3.py +++ /dev/null @@ -1,101 +0,0 @@ -# -------------------------------------------------------------------------- -# -# Copyright (c) Microsoft Corporation. All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the ""Software""), to -# deal in the Software without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -# sell copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -# IN THE SOFTWARE. -# -# -------------------------------------------------------------------------- -import collections.abc -from typing import ( - AsyncIterable, - Dict, - Iterable, - Tuple, - Union, - Callable, - Optional, - AsyncIterator as AsyncIteratorType -) -from ..exceptions import StreamConsumedError, StreamClosedError - -from ._helpers import ( - _shared_set_content_body, - HeadersType -) -ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] - -def set_content_body(content: ContentType) -> Tuple[ - HeadersType, ContentType -]: - headers, body = _shared_set_content_body(content) - if body is not None: - return headers, body - if isinstance(content, collections.abc.AsyncIterable): - return {}, content - raise TypeError( - "Unexpected type for 'content': '{}'. ".format(type(content)) + - "We expect 'content' to either be str, bytes, or an Iterable / AsyncIterable" - ) - -def _stream_download_helper( - decompress: bool, - stream_download_generator: Callable, - response, -) -> AsyncIteratorType[bytes]: - if response.is_stream_consumed: - raise StreamConsumedError(response) - if response.is_closed: - raise StreamClosedError(response) - - response.is_stream_consumed = True - return stream_download_generator( - pipeline=None, - response=response, - decompress=decompress, - ) - -async def iter_bytes_helper( - stream_download_generator: Callable, - response, - content: Optional[bytes], -) -> AsyncIteratorType[bytes]: - if content: - chunk_size = response._connection_data_block_size # pylint: disable=protected-access - for i in range(0, len(content), chunk_size): - yield content[i : i + chunk_size] - else: - async for part in _stream_download_helper( - decompress=True, - stream_download_generator=stream_download_generator, - response=response, - ): - yield part - -async def iter_raw_helper( - stream_download_generator: Callable, - response, -) -> AsyncIteratorType[bytes]: - async for part in _stream_download_helper( - decompress=False, - stream_download_generator=stream_download_generator, - response=response, - ): - yield part diff --git a/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py b/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py deleted file mode 100644 index b21545a79804..000000000000 --- a/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py +++ /dev/null @@ -1,83 +0,0 @@ -# -------------------------------------------------------------------------- -# -# Copyright (c) Microsoft Corporation. All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the ""Software""), to -# deal in the Software without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -# sell copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -# IN THE SOFTWARE. -# -# -------------------------------------------------------------------------- -from typing import AsyncIterator -import asyncio -from ._helpers_py3 import iter_bytes_helper, iter_raw_helper -from . import AsyncHttpResponse -from ._requests_basic import _RestRequestsTransportResponseBase, _has_content -from ..pipeline.transport._requests_asyncio import AsyncioStreamDownloadGenerator - -class RestAsyncioRequestsTransportResponse(AsyncHttpResponse, _RestRequestsTransportResponseBase): # type: ignore - """Asynchronous streaming of data from the response. - """ - - async def iter_raw(self) -> AsyncIterator[bytes]: - """Asynchronously iterates over the response's bytes. Will not decompress in the process - - :return: An async iterator of bytes from the response - :rtype: AsyncIterator[bytes] - """ - - async for part in iter_raw_helper(AsyncioStreamDownloadGenerator, self): - yield part - await self.close() - - async def iter_bytes(self) -> AsyncIterator[bytes]: - """Asynchronously iterates over the response's bytes. Will decompress in the process - - :return: An async iterator of bytes from the response - :rtype: AsyncIterator[bytes] - """ - async for part in iter_bytes_helper( - AsyncioStreamDownloadGenerator, - self, - content=self.content if _has_content(self) else None - ): - yield part - await self.close() - - async def close(self) -> None: - """Close the response. - - :return: None - :rtype: None - """ - self.is_closed = True - self._internal_response.close() - await asyncio.sleep(0) - - async def read(self) -> bytes: - """Read the response's bytes into memory. - - :return: The response's bytes - :rtype: bytes - """ - if not _has_content(self): - parts = [] - async for part in self.iter_bytes(): # type: ignore - parts.append(part) - self._internal_response._content = b"".join(parts) # pylint: disable=protected-access - return self.content diff --git a/sdk/core/azure-core/azure/core/rest/_requests_basic.py b/sdk/core/azure-core/azure/core/rest/_requests_basic.py deleted file mode 100644 index bd83dc29bd39..000000000000 --- a/sdk/core/azure-core/azure/core/rest/_requests_basic.py +++ /dev/null @@ -1,121 +0,0 @@ -# -------------------------------------------------------------------------- -# -# Copyright (c) Microsoft Corporation. All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the ""Software""), to -# deal in the Software without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -# sell copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -# IN THE SOFTWARE. -# -# -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, cast - -from ..exceptions import ResponseNotReadError, StreamConsumedError, StreamClosedError -from ._rest import _HttpResponseBase, HttpResponse -from ..pipeline.transport._requests_basic import StreamDownloadGenerator - -if TYPE_CHECKING: - from typing import Iterator, Optional - -def _has_content(response): - try: - response.content # pylint: disable=pointless-statement - return True - except ResponseNotReadError: - return False - -class _RestRequestsTransportResponseBase(_HttpResponseBase): - def __init__(self, **kwargs): - super(_RestRequestsTransportResponseBase, self).__init__(**kwargs) - self.status_code = self._internal_response.status_code - self.headers = self._internal_response.headers - self.reason = self._internal_response.reason - self.content_type = self._internal_response.headers.get('content-type') - - @property - def content(self): - # type: () -> bytes - if not self._internal_response._content_consumed: # pylint: disable=protected-access - # if we just call .content, requests will read in the content. - # we want to read it in our own way - raise ResponseNotReadError(self) - - try: - return self._internal_response.content - except RuntimeError: - # requests throws a RuntimeError if the content for a response is already consumed - raise ResponseNotReadError(self) - -def _stream_download_helper(decompress, response): - if response.is_stream_consumed: - raise StreamConsumedError(response) - if response.is_closed: - raise StreamClosedError(response) - - response.is_stream_consumed = True - stream_download = StreamDownloadGenerator( - pipeline=None, - response=response, - decompress=decompress, - ) - for part in stream_download: - yield part - -class RestRequestsTransportResponse(HttpResponse, _RestRequestsTransportResponseBase): - - def iter_bytes(self): - # type: () -> Iterator[bytes] - """Iterates over the response's bytes. Will decompress in the process - :return: An iterator of bytes from the response - :rtype: Iterator[str] - """ - if _has_content(self): - chunk_size = cast(int, self._connection_data_block_size) - for i in range(0, len(self.content), chunk_size): - yield self.content[i : i + chunk_size] - else: - for part in _stream_download_helper( - decompress=True, - response=self, - ): - yield part - self.close() - - def iter_raw(self): - # type: () -> Iterator[bytes] - """Iterates over the response's bytes. Will not decompress in the process - :return: An iterator of bytes from the response - :rtype: Iterator[str] - """ - for raw_bytes in _stream_download_helper( - decompress=False, - response=self, - ): - yield raw_bytes - self.close() - - def read(self): - # type: () -> bytes - """Read the response's bytes. - - :return: The read in bytes - :rtype: bytes - """ - if not _has_content(self): - self._internal_response._content = b"".join(self.iter_bytes()) # pylint: disable=protected-access - return self.content diff --git a/sdk/core/azure-core/azure/core/rest/_requests_trio.py b/sdk/core/azure-core/azure/core/rest/_requests_trio.py deleted file mode 100644 index 9806380ef04f..000000000000 --- a/sdk/core/azure-core/azure/core/rest/_requests_trio.py +++ /dev/null @@ -1,77 +0,0 @@ -# -------------------------------------------------------------------------- -# -# Copyright (c) Microsoft Corporation. All rights reserved. -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the ""Software""), to -# deal in the Software without restriction, including without limitation the -# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -# sell copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -# IN THE SOFTWARE. -# -# -------------------------------------------------------------------------- -from typing import AsyncIterator -import trio -from . import AsyncHttpResponse -from ._requests_basic import _RestRequestsTransportResponseBase, _has_content -from ._helpers_py3 import iter_bytes_helper, iter_raw_helper -from ..pipeline.transport._requests_trio import TrioStreamDownloadGenerator - -class RestTrioRequestsTransportResponse(AsyncHttpResponse, _RestRequestsTransportResponseBase): # type: ignore - """Asynchronous streaming of data from the response. - """ - async def iter_raw(self) -> AsyncIterator[bytes]: - """Asynchronously iterates over the response's bytes. Will not decompress in the process - - :return: An async iterator of bytes from the response - :rtype: AsyncIterator[bytes] - """ - async for part in iter_raw_helper(TrioStreamDownloadGenerator, self): - yield part - await self.close() - - async def iter_bytes(self) -> AsyncIterator[bytes]: - """Asynchronously iterates over the response's bytes. Will decompress in the process - - :return: An async iterator of bytes from the response - :rtype: AsyncIterator[bytes] - """ - - async for part in iter_bytes_helper( - TrioStreamDownloadGenerator, - self, - content=self.content if _has_content(self) else None - ): - yield part - await self.close() - - async def read(self) -> bytes: - """Read the response's bytes into memory. - - :return: The response's bytes - :rtype: bytes - """ - if not _has_content(self): - parts = [] - async for part in self.iter_bytes(): # type: ignore - parts.append(part) - self._internal_response._content = b"".join(parts) # pylint: disable=protected-access - return self.content - - async def close(self) -> None: - self.is_closed = True - self._internal_response.close() - await trio.sleep(0) diff --git a/sdk/core/azure-core/azure/core/rest/_rest.py b/sdk/core/azure-core/azure/core/rest/_rest.py index 10a8486a2c64..d8562521a682 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest.py +++ b/sdk/core/azure-core/azure/core/rest/_rest.py @@ -23,26 +23,20 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +import abc import copy -from json import loads -from typing import TYPE_CHECKING, cast - -from azure.core.exceptions import HttpResponseError +from typing import TYPE_CHECKING from ..utils._utils import _case_insensitive_dict from ._helpers import ( FilesType, - parse_lines_from_text, set_content_body, set_json_body, set_multipart_body, set_urlencoded_body, format_parameters, - to_pipeline_transport_request_helper, - from_pipeline_transport_request_helper, - get_charset_encoding, - decode_to_text, + HttpRequestBackcompatMixin, ) from ..exceptions import ResponseNotReadError if TYPE_CHECKING: @@ -60,11 +54,14 @@ from ._helpers import HeadersType, ContentTypeBase as ContentType - +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore ################################## CLASSES ###################################### -class HttpRequest(object): +class HttpRequest(HttpRequestBackcompatMixin): """Provisional object that represents an HTTP request. **This object is provisional**, meaning it may be changed in a future release. @@ -184,37 +181,52 @@ def __deepcopy__(self, memo=None): except (ValueError, TypeError): return copy.copy(self) - def _to_pipeline_transport_request(self): - return to_pipeline_transport_request_helper(self) - - @classmethod - def _from_pipeline_transport_request(cls, pipeline_transport_request): - return from_pipeline_transport_request_helper(cls, pipeline_transport_request) - -class _HttpResponseBase(object): # pylint: disable=too-many-instance-attributes - - def __init__(self, **kwargs): - # type: (Any) -> None - self.request = kwargs.pop("request") - self._internal_response = kwargs.pop("internal_response") - self.status_code = None - self.headers = {} # type: HeadersType - self.reason = None - self.is_closed = False - self.is_stream_consumed = False - self.content_type = None - self._json = None # this is filled in ContentDecodePolicy, when we deserialize - self._connection_data_block_size = None # type: Optional[int] - self._content = None # type: Optional[bytes] - self._text = None # type: Optional[str] +class _HttpResponseBase(ABC): @property - def url(self): + @abc.abstractmethod + def request(self): + # type: (...) -> HttpRequest + """The request that resulted in this response.""" + + @property + @abc.abstractmethod + def status_code(self): + # type: (...) -> int + """The status code of this response""" + + @property + @abc.abstractmethod + def headers(self): + # type: (...) -> Optional[HeadersType] + """The response headers""" + + @property + @abc.abstractmethod + def reason(self): # type: (...) -> str - """Returns the URL that resulted in this response""" - return self.request.url + """The reason phrase for this response""" + + @property + @abc.abstractmethod + def content_type(self): + # type: (...) -> str + """The content type of the response""" + + @property + @abc.abstractmethod + def is_closed(self): + # type: (...) -> bool + """Whether the network connection has been closed yet""" + + @property + @abc.abstractmethod + def is_stream_consumed(self): + # type: (...) -> bool + """Whether the stream has been fully consumed""" @property + @abc.abstractmethod def encoding(self): # type: (...) -> Optional[str] """Returns the response encoding. @@ -224,19 +236,19 @@ def encoding(self): we return `None`. :rtype: optional[str] """ - try: - return self._encoding - except AttributeError: - self._encoding = get_charset_encoding(self) # type: Optional[str] - return self._encoding @encoding.setter def encoding(self, value): # type: (str) -> None """Sets the response encoding""" - self._encoding = value - self._text = None # clear text cache + @property + @abc.abstractmethod + def url(self): + # type: (...) -> str + """Returns the URL that resulted in this response""" + + @abc.abstractmethod def text(self, encoding=None): # type: (Optional[str]) -> str """Returns the response body as a string @@ -245,11 +257,8 @@ def text(self, encoding=None): also be set independently through our encoding property :return: The response's content decoded as a string. """ - if self._text is None or encoding: - encoding_to_pass = encoding or self.encoding - self._text = decode_to_text(encoding_to_pass, self.content) - return self._text + @abc.abstractmethod def json(self): # type: (...) -> Any """Returns the whole body as a json object. @@ -258,42 +267,27 @@ def json(self): :rtype: any :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: """ - # this will trigger errors if response is not read in - self.content # pylint: disable=pointless-statement - if not self._json: - self._json = loads(self.text()) - return self._json + @abc.abstractmethod def raise_for_status(self): # type: (...) -> None """Raises an HttpResponseError if the response has an error status code. If response is good, does nothing. """ - if cast(int, self.status_code) >= 400: - raise HttpResponseError(response=self) @property + @abc.abstractmethod def content(self): # type: (...) -> bytes """Return the response's content in bytes.""" - if self._content is None: - raise ResponseNotReadError(self) - return self._content - def __repr__(self): - # type: (...) -> str - content_type_str = ( - ", Content-Type: {}".format(self.content_type) if self.content_type else "" - ) - return "".format( - self.status_code, self.reason, content_type_str - ) -class HttpResponse(_HttpResponseBase): # pylint: disable=too-many-instance-attributes - """**Provisional** object that represents an HTTP response. +class HttpResponse(_HttpResponseBase): + """**Provisional** abstract base class for HTTP responses. **This object is provisional**, meaning it may be changed in a future release. + Use this abstract base class to create your own transport responses. It is returned from your client's `send_request` method if you pass in an :class:`~azure.core.rest.HttpRequest` @@ -304,16 +298,15 @@ class HttpResponse(_HttpResponseBase): # pylint: disable=too-many-instance-attr >>> response = client.send_request(request) - :keyword request: The request that resulted in this response. - :paramtype request: ~azure.core.rest.HttpRequest :ivar int status_code: The status code of this response - :ivar mapping headers: The response headers + :ivar mapping headers: The case-insensitive response headers. + While looking up headers is case-insensitive, when looking up + keys in `header.keys()`, we recommend using lowercase. :ivar str reason: The reason phrase for this response :ivar bytes content: The response content in bytes. :ivar str url: The URL that resulted in this response :ivar str encoding: The response encoding. Is settable, by default is the response Content-Type header - :ivar str text: The response body as a string. :ivar request: The request that resulted in this response. :vartype request: ~azure.core.rest.HttpRequest :ivar str content_type: The content type of the response @@ -322,57 +315,62 @@ class HttpResponse(_HttpResponseBase): # pylint: disable=too-many-instance-attr whether the stream has been fully consumed """ + @abc.abstractmethod def __enter__(self): # type: (...) -> HttpResponse - return self + """Enter this response""" + @abc.abstractmethod def close(self): # type: (...) -> None - self.is_closed = True - self._internal_response.close() + """Close this response""" + @abc.abstractmethod def __exit__(self, *args): # type: (...) -> None - self.close() + """Exit this response""" + @abc.abstractmethod def read(self): # type: (...) -> bytes - """ - Read the response's bytes. + """Read the response's bytes. + :return: The read in bytes + :rtype: bytes """ - if self._content is None: - self._content = b"".join(self.iter_bytes()) - return self.content + @abc.abstractmethod def iter_raw(self): # type: () -> Iterator[bytes] - """Iterate over the raw response bytes + """Iterates over the response's bytes. Will not decompress in the process + + :return: An iterator of bytes from the response + :rtype: Iterator[str] """ - raise NotImplementedError() + @abc.abstractmethod def iter_bytes(self): # type: () -> Iterator[bytes] - """Iterate over the response bytes + """Iterates over the response's bytes. Will decompress in the process + + :return: An iterator of bytes from the response + :rtype: Iterator[str] """ - raise NotImplementedError() + @abc.abstractmethod def iter_text(self): # type: () -> Iterator[str] - """Iterate over the response text + """Iterates over the text in the response. + + :return: An iterator of string. Each string chunk will be a text from the response + :rtype: Iterator[str] """ - for byte in self.iter_bytes(): - text = byte.decode(self.encoding or "utf-8") - yield text + @abc.abstractmethod def iter_lines(self): # type: () -> Iterator[str] - for text in self.iter_text(): - lines = parse_lines_from_text(text) - for line in lines: - yield line + """Iterates over the lines in the response. - def _close_stream(self): - # type: (...) -> None - self.is_stream_consumed = True - self.close() + :return: An iterator of string. Each string chunk will be a line from the response + :rtype: Iterator[str] + """ diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py index 21e42f46b044..61cf5badeb68 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest_py3.py +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -23,69 +23,51 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +import abc import copy -import collections -import collections.abc -from json import loads from typing import ( Any, AsyncIterable, AsyncIterator, - Dict, - Iterable, Iterator, + Iterable, + Iterator, Optional, - Type, Union, + Tuple, ) - - -from azure.core.exceptions import HttpResponseError - +import collections.abc from ..utils._utils import _case_insensitive_dict from ._helpers import ( ParamsType, FilesType, HeadersType, - cast, - parse_lines_from_text, set_json_body, set_multipart_body, set_urlencoded_body, format_parameters, - to_pipeline_transport_request_helper, - from_pipeline_transport_request_helper, - get_charset_encoding, - decode_to_text, + _shared_set_content_body, + HttpRequestBackcompatMixin, ) -from ._helpers_py3 import set_content_body -from ..exceptions import ResponseNotReadError ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] -class _AsyncContextManager(collections.abc.Awaitable): - - def __init__(self, wrapped: collections.abc.Awaitable): - super().__init__() - self.wrapped = wrapped - self.response = None - - def __await__(self): - return self.wrapped.__await__() - - async def __aenter__(self): - self.response = await self - return self.response - - async def __aexit__(self, *args): - await self.response.__aexit__(*args) - - async def close(self): - await self.response.close() - ################################## CLASSES ###################################### -class HttpRequest: +def set_content_body(content: ContentType) -> Tuple[ + HeadersType, ContentType +]: + headers, body = _shared_set_content_body(content) + if body is not None: + return headers, body + if isinstance(content, collections.abc.AsyncIterable): + return {}, content + raise TypeError( + "Unexpected type for 'content': '{}'. ".format(type(content)) + + "We expect 'content' to either be str, bytes, or an Iterable / AsyncIterable" + ) + +class HttpRequest(HttpRequestBackcompatMixin): """**Provisional** object that represents an HTTP request. **This object is provisional**, meaning it may be changed in a future release. @@ -210,40 +192,54 @@ def __deepcopy__(self, memo=None) -> "HttpRequest": except (ValueError, TypeError): return copy.copy(self) - def _to_pipeline_transport_request(self): - return to_pipeline_transport_request_helper(self) +class _HttpResponseBase(abc.ABC): + """Base abstract base class for HttpResponses + """ + + @property + @abc.abstractmethod + def request(self) -> HttpRequest: + """The request that resulted in this response.""" + ... + + @property + @abc.abstractmethod + def status_code(self) -> int: + """The status code of this response""" + ... - @classmethod - def _from_pipeline_transport_request(cls, pipeline_transport_request): - return from_pipeline_transport_request_helper(cls, pipeline_transport_request) + @property + @abc.abstractmethod + def headers(self) -> Optional[HeadersType]: + """The response headers""" + ... -class _HttpResponseBase: # pylint: disable=too-many-instance-attributes + @property + @abc.abstractmethod + def reason(self) -> str: + """The reason phrase for this response""" + ... - def __init__( - self, - *, - request: HttpRequest, - **kwargs - ): - self.request = request - self._internal_response = kwargs.pop("internal_response") - self.status_code = None - self.headers = {} # type: HeadersType - self.reason = None - self.is_closed = False - self.is_stream_consumed = False - self.content_type = None - self._connection_data_block_size = None - self._json = None # this is filled in ContentDecodePolicy, when we deserialize - self._content = None # type: Optional[bytes] - self._text = None # type: Optional[str] + @property + @abc.abstractmethod + def content_type(self) -> str: + """The content type of the response""" + ... @property - def url(self) -> str: - """Returns the URL that resulted in this response""" - return self.request.url + @abc.abstractmethod + def is_closed(self) -> bool: + """Whether the network connection has been closed yet""" + ... @property + @abc.abstractmethod + def is_stream_consumed(self) -> bool: + """Whether the stream has been fully consumed""" + ... + + @property + @abc.abstractmethod def encoding(self) -> Optional[str]: """Returns the response encoding. @@ -252,18 +248,25 @@ def encoding(self) -> Optional[str]: we return `None`. :rtype: optional[str] """ - try: - return self._encoding - except AttributeError: - self._encoding: Optional[str] = get_charset_encoding(self) - return self._encoding + ... @encoding.setter - def encoding(self, value: str) -> None: + def encoding(self, value: Optional[str]) -> None: """Sets the response encoding""" - self._encoding = value - self._text = None # clear text cache + @property + @abc.abstractmethod + def url(self) -> str: + """The URL that resulted in this response""" + ... + + @property + @abc.abstractmethod + def content(self) -> bytes: + """Return the response's content in bytes.""" + ... + + @abc.abstractmethod def text(self, encoding: Optional[str] = None) -> str: """Returns the response body as a string @@ -271,11 +274,9 @@ def text(self, encoding: Optional[str] = None) -> str: also be set independently through our encoding property :return: The response's content decoded as a string. """ - if self._text is None or encoding: - encoding_to_pass = encoding or self.encoding - self._text = decode_to_text(encoding_to_pass, self.content) - return self._text + ... + @abc.abstractmethod def json(self) -> Any: """Returns the whole body as a json object. @@ -283,31 +284,21 @@ def json(self) -> Any: :rtype: any :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: """ - # this will trigger errors if response is not read in - self.content # pylint: disable=pointless-statement - if not self._json: - self._json = loads(self.text()) - return self._json + ... + @abc.abstractmethod def raise_for_status(self) -> None: """Raises an HttpResponseError if the response has an error status code. If response is good, does nothing. """ - if cast(int, self.status_code) >= 400: - raise HttpResponseError(response=self) - - @property - def content(self) -> bytes: - """Return the response's content in bytes.""" - if self._content is None: - raise ResponseNotReadError(self) - return self._content + ... class HttpResponse(_HttpResponseBase): - """**Provisional** object that represents an HTTP response. + """**Provisional** abstract base class for HTTP responses. **This object is provisional**, meaning it may be changed in a future release. + Use this abstract base class to create your own transport responses. It is returned from your client's `send_request` method if you pass in an :class:`~azure.core.rest.HttpRequest` @@ -318,16 +309,15 @@ class HttpResponse(_HttpResponseBase): >>> response = client.send_request(request) - :keyword request: The request that resulted in this response. - :paramtype request: ~azure.core.rest.HttpRequest :ivar int status_code: The status code of this response - :ivar mapping headers: The response headers + :ivar mapping headers: The case-insensitive response headers. + While looking up headers is case-insensitive, when looking up + keys in `header.keys()`, we recommend using lowercase. :ivar str reason: The reason phrase for this response :ivar bytes content: The response content in bytes. :ivar str url: The URL that resulted in this response :ivar str encoding: The response encoding. Is settable, by default is the response Content-Type header - :ivar str text: The response body as a string. :ivar request: The request that resulted in this response. :vartype request: ~azure.core.rest.HttpRequest :ivar str content_type: The content type of the response @@ -336,80 +326,68 @@ class HttpResponse(_HttpResponseBase): whether the stream has been fully consumed """ + @abc.abstractmethod def __enter__(self) -> "HttpResponse": - return self - - def close(self) -> None: - """Close the response - - :return: None - :rtype: None - """ - self.is_closed = True - self._internal_response.close() + ... + @abc.abstractmethod def __exit__(self, *args) -> None: - self.close() + ... + + @abc.abstractmethod + def close(self) -> None: + ... + @abc.abstractmethod def read(self) -> bytes: """Read the response's bytes. :return: The read in bytes :rtype: bytes """ - if self._content is None: - self._content = b"".join(self.iter_bytes()) - return self.content + ... + @abc.abstractmethod def iter_raw(self) -> Iterator[bytes]: """Iterates over the response's bytes. Will not decompress in the process :return: An iterator of bytes from the response :rtype: Iterator[str] """ - raise NotImplementedError() + ... + @abc.abstractmethod def iter_bytes(self) -> Iterator[bytes]: """Iterates over the response's bytes. Will decompress in the process :return: An iterator of bytes from the response :rtype: Iterator[str] """ - raise NotImplementedError() + ... + @abc.abstractmethod def iter_text(self) -> Iterator[str]: """Iterates over the text in the response. :return: An iterator of string. Each string chunk will be a text from the response :rtype: Iterator[str] """ - for byte in self.iter_bytes(): - text = byte.decode(self.encoding or "utf-8") - yield text + ... + @abc.abstractmethod def iter_lines(self) -> Iterator[str]: """Iterates over the lines in the response. :return: An iterator of string. Each string chunk will be a line from the response :rtype: Iterator[str] """ - for text in self.iter_text(): - lines = parse_lines_from_text(text) - for line in lines: - yield line - - def __repr__(self) -> str: - content_type_str = ( - ", Content-Type: {}".format(self.content_type) if self.content_type else "" - ) - return "".format( - self.status_code, self.reason, content_type_str - ) + ... class AsyncHttpResponse(_HttpResponseBase): - """**Provisional** object that represents an Async HTTP response. + """**Provisional** abstract base class for Async HTTP responses. **This object is provisional**, meaning it may be changed in a future release. + Use this abstract base class to create your own transport responses. It is returned from your async client's `send_request` method if you pass in an :class:`~azure.core.rest.HttpRequest` @@ -420,8 +398,6 @@ class AsyncHttpResponse(_HttpResponseBase): >>> response = await client.send_request(request) - :keyword request: The request that resulted in this response. - :paramtype request: ~azure.core.rest.HttpRequest :ivar int status_code: The status code of this response :ivar mapping headers: The response headers :ivar str reason: The reason phrase for this response @@ -429,7 +405,6 @@ class AsyncHttpResponse(_HttpResponseBase): :ivar str url: The URL that resulted in this response :ivar str encoding: The response encoding. Is settable, by default is the response Content-Type header - :ivar str text: The response body as a string. :ivar request: The request that resulted in this response. :vartype request: ~azure.core.rest.HttpRequest :ivar str content_type: The content type of the response @@ -438,76 +413,55 @@ class AsyncHttpResponse(_HttpResponseBase): whether the stream has been fully consumed """ + @abc.abstractmethod async def read(self) -> bytes: """Read the response's bytes into memory. :return: The response's bytes :rtype: bytes """ - if self._content is None: - parts = [] - async for part in self.iter_bytes(): - parts.append(part) - self._content = b"".join(parts) - return self._content + ... + @abc.abstractmethod async def iter_raw(self) -> AsyncIterator[bytes]: """Asynchronously iterates over the response's bytes. Will not decompress in the process :return: An async iterator of bytes from the response :rtype: AsyncIterator[bytes] """ - raise NotImplementedError() - # getting around mypy behavior, see https://github.com/python/mypy/issues/10732 - yield # pylint: disable=unreachable + ... + @abc.abstractmethod async def iter_bytes(self) -> AsyncIterator[bytes]: """Asynchronously iterates over the response's bytes. Will decompress in the process :return: An async iterator of bytes from the response :rtype: AsyncIterator[bytes] """ - raise NotImplementedError() - # getting around mypy behavior, see https://github.com/python/mypy/issues/10732 - yield # pylint: disable=unreachable + ... + @abc.abstractmethod async def iter_text(self) -> AsyncIterator[str]: """Asynchronously iterates over the text in the response. :return: An async iterator of string. Each string chunk will be a text from the response :rtype: AsyncIterator[str] """ - async for byte in self.iter_bytes(): # type: ignore - text = byte.decode(self.encoding or "utf-8") - yield text + ... + @abc.abstractmethod async def iter_lines(self) -> AsyncIterator[str]: """Asynchronously iterates over the lines in the response. :return: An async iterator of string. Each string chunk will be a line from the response :rtype: AsyncIterator[str] """ - async for text in self.iter_text(): - lines = parse_lines_from_text(text) - for line in lines: - yield line + ... + @abc.abstractmethod async def close(self) -> None: - """Close the response. - - :return: None - :rtype: None - """ - self.is_closed = True - await self._internal_response.close() + ... + @abc.abstractmethod async def __aexit__(self, *args) -> None: - await self.close() - - def __repr__(self) -> str: - content_type_str = ( - ", Content-Type: {}".format(self.content_type) if self.content_type else "" - ) - return "".format( - self.status_code, self.reason, content_type_str - ) + ... diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py new file mode 100644 index 000000000000..545f6a430e03 --- /dev/null +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import absolute_import +from email.message import Message +from six.moves.http_client import HTTPConnection +try: + binary_type = str + from urlparse import urlparse # type: ignore +except ImportError: + binary_type = bytes # type: ignore + from urllib.parse import urlparse + + +def _format_parameters_helper(http_request, params): + query = urlparse(http_request.url).query + if query: + http_request.url = http_request.url.partition("?")[0] + existing_params = { + p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] + } + params.update(existing_params) + query_params = [] + for k, v in params.items(): + if isinstance(v, list): + for w in v: + if w is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, w)) + else: + if v is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, v)) + query = "?" + "&".join(query_params) + http_request.url = http_request.url + query + +def _pad_attr_name(attr, backcompat_attrs): + return "_{}".format(attr) if attr in backcompat_attrs else attr + +def _prepare_multipart_body_helper(http_response, content_index=0): + if not http_response.multipart_mixed_info: + return 0 + + requests = http_response.multipart_mixed_info[0] # type: List[HttpRequest] + boundary = http_response.multipart_mixed_info[2] # type: Optional[str] + + # Update the main request with the body + main_message = Message() + main_message.add_header("Content-Type", "multipart/mixed") + if boundary: + main_message.set_boundary(boundary) + + for req in requests: + part_message = Message() + if req.multipart_mixed_info: + content_index = req.prepare_multipart_body(content_index=content_index) + part_message.add_header("Content-Type", req.headers['Content-Type']) + payload = req.serialize() + # We need to remove the ~HTTP/1.1 prefix along with the added content-length + payload = payload[payload.index(b'--'):] + else: + part_message.add_header("Content-Type", "application/http") + part_message.add_header("Content-Transfer-Encoding", "binary") + part_message.add_header("Content-ID", str(content_index)) + payload = req.serialize() + content_index += 1 + part_message.set_payload(payload) + main_message.attach(part_message) + + try: + from email.policy import HTTP + + full_message = main_message.as_bytes(policy=HTTP) + eol = b"\r\n" + except ImportError: # Python 2.7 + # Right now we decide to not support Python 2.7 on serialization, since + # it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it) + raise NotImplementedError( + "Multipart request are not supported on Python 2.7" + ) + # full_message = main_message.as_string() + # eol = b'\n' + _, _, body = full_message.split(eol, 2) + http_response.set_bytes_body(body) + http_response.headers["Content-Type"] = ( + "multipart/mixed; boundary=" + main_message.get_boundary() + ) + return content_index + +class _HTTPSerializer(HTTPConnection, object): + """Hacking the stdlib HTTPConnection to serialize HTTP request as strings. + """ + + def __init__(self, *args, **kwargs): + self.buffer = b"" + kwargs.setdefault("host", "fakehost") + super(_HTTPSerializer, self).__init__(*args, **kwargs) + + def putheader(self, header, *values): + if header in ["Host", "Accept-Encoding"]: + return + super(_HTTPSerializer, self).putheader(header, *values) + + def send(self, data): + self.buffer += data + +def _serialize_request(http_request): + serializer = _HTTPSerializer() + serializer.request( + method=http_request.method, + url=http_request.url, + body=http_request.body, + headers=http_request.headers, + ) + return serializer.buffer diff --git a/sdk/core/azure-core/doc/azure.core.rst b/sdk/core/azure-core/doc/azure.core.rst index c6f0ed92463b..68af546c7bda 100644 --- a/sdk/core/azure-core/doc/azure.core.rst +++ b/sdk/core/azure-core/doc/azure.core.rst @@ -83,4 +83,7 @@ This module is ***provisional***, meaning any of the objects and methods in this :members: :undoc-members: :inherited-members: + :exclude-members: files,data,multipart_mixed_info,query,body,format_parameters,set_streamed_data_body, + set_text_body,set_xml_body,set_json_body,set_formdata_body,set_bytes_body,set_multipart_mixed,prepare_multipart_body, + serialize,internal_response,block_size,stream_download,parts diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index 7230018aa37f..4a1deaf13a55 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -11,13 +11,13 @@ from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, SansIOHTTPPolicy -from azure.core.pipeline.transport import HttpRequest +from utils import HTTP_REQUESTS import pytest pytestmark = pytest.mark.asyncio - -async def test_bearer_policy_adds_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_adds_header(http_request): """The bearer token policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) expected_token = AccessToken("expected_token", 2524608000) @@ -37,17 +37,17 @@ async def get_token(_): policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = AsyncPipeline(transport=Mock(), policies=policies) - await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None) + await pipeline.run(http_request("GET", "https://spam.eggs"), context=None) assert get_token_calls == 1 - await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None) + await pipeline.run(http_request("GET", "https://spam.eggs"), context=None) # Didn't need a new token assert get_token_calls == 1 - -async def test_bearer_policy_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_send(http_request): """The bearer token policy should invoke the next policy's send method and return the result""" - expected_request = HttpRequest("GET", "https://spam.eggs") + expected_request = http_request("GET", "https://spam.eggs") expected_response = Mock() async def verify_request(request): @@ -60,8 +60,8 @@ async def verify_request(request): assert response is expected_response - -async def test_bearer_policy_token_caching(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_token_caching(http_request): good_for_one_hour = AccessToken("token", time.time() + 3600) expected_token = good_for_one_hour get_token_calls = 0 @@ -78,10 +78,10 @@ async def get_token(_): ] pipeline = AsyncPipeline(transport=Mock, policies=policies) - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 1 # policy has no token at first request -> it should call get_token - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 1 # token is good for an hour -> policy should return it from cache expired_token = AccessToken("token", time.time()) @@ -93,14 +93,14 @@ async def get_token(_): ] pipeline = AsyncPipeline(transport=Mock(), policies=policies) - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 1 - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 2 # token expired -> policy should call get_token - -async def test_bearer_policy_optionally_enforces_https(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_optionally_enforces_https(http_request): """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" async def assert_option_popped(request, **kwargs): @@ -114,20 +114,20 @@ async def assert_option_popped(request, **kwargs): # by default and when enforce_https=True, the policy should raise when given an insecure request with pytest.raises(ServiceRequestError): - await pipeline.run(HttpRequest("GET", "http://not.secure")) + await pipeline.run(http_request("GET", "http://not.secure")) with pytest.raises(ServiceRequestError): - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + await pipeline.run(http_request("GET", "http://not.secure"), enforce_https=True) # when enforce_https=False, an insecure request should pass - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + await pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) # https requests should always pass - await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) - await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) - await pipeline.run(HttpRequest("GET", "https://secure")) + await pipeline.run(http_request("GET", "https://secure"), enforce_https=False) + await pipeline.run(http_request("GET", "https://secure"), enforce_https=True) + await pipeline.run(http_request("GET", "https://secure")) - -async def test_bearer_policy_preserves_enforce_https_opt_out(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_preserves_enforce_https_opt_out(http_request): """The policy should use request context to preserve an opt out from https enforcement""" class ContextValidator(SansIOHTTPPolicy): @@ -140,10 +140,10 @@ def on_request(self, request): policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies) - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) - + await pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) -async def test_bearer_policy_context_unmodified_by_default(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_context_unmodified_by_default(http_request): """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" class ContextValidator(SansIOHTTPPolicy): @@ -156,10 +156,10 @@ def on_request(self, request): policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies) - await pipeline.run(HttpRequest("GET", "https://secure")) - + await pipeline.run(http_request("GET", "https://secure")) -async def test_bearer_policy_calls_sansio_methods(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_calls_sansio_methods(http_request): """AsyncBearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOAsyncHTTPPolicyRunner""" class TestPolicy(AsyncBearerTokenCredentialPolicy): @@ -179,7 +179,7 @@ async def send(self, request): transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200)))) pipeline = AsyncPipeline(transport=transport, policies=[policy]) - await pipeline.run(HttpRequest("GET", "https://localhost")) + await pipeline.run(http_request("GET", "https://localhost")) policy.on_request.assert_called_once_with(policy.request) policy.on_response.assert_called_once_with(policy.request, policy.response) @@ -193,7 +193,7 @@ class TestException(Exception): policy = TestPolicy(credential, "scope") pipeline = AsyncPipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - await pipeline.run(HttpRequest("GET", "https://localhost")) + await pipeline.run(http_request("GET", "https://localhost")) policy.on_exception.assert_called_once_with(policy.request) # ...or the second @@ -209,7 +209,7 @@ async def fake_send(*args, **kwargs): transport = Mock(send=Mock(wraps=fake_send)) pipeline = AsyncPipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - await pipeline.run(HttpRequest("GET", "https://localhost")) + await pipeline.run(http_request("GET", "https://localhost")) assert transport.send.call_count == 2 policy.on_challenge.assert_called_once() policy.on_exception.assert_called_once_with(policy.request) diff --git a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py index f010abaaddf4..e5f00a29bcac 100644 --- a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py @@ -44,12 +44,20 @@ from azure.core.exceptions import DecodeError, HttpResponseError from azure.core import AsyncPipelineClient from azure.core.pipeline import PipelineResponse, AsyncPipeline, PipelineContext -from azure.core.pipeline.transport import AsyncioRequestsTransportResponse, AsyncHttpTransport +from azure.core.pipeline.transport import AsyncHttpTransport from azure.core.polling.async_base_polling import ( AsyncLROBasePolling, ) +from utils import( + HTTP_REQUESTS, + ASYNCIO_REQUESTS_TRANSPORT_RESPONSES, + create_http_response, + pipeline_transport_and_rest_product, + is_rest_http_request, +) + class SimpleResource: """An implementation of Python 3 SimpleNamespace. @@ -84,7 +92,9 @@ class BadEndpointError(Exception): CLIENT = AsyncPipelineClient("http://example.org") async def mock_run(client_self, request, **kwargs): - return TestBasePolling.mock_update(request.url) + return TestBasePolling.mock_update(client_self.http_request_type, client_self.http_response_type, request.url) +CLIENT.http_request_type = None +CLIENT.http_response_type = None CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -127,27 +137,29 @@ def cb(pipeline_response): @pytest.fixture def polling_response(): - polling = AsyncLROBasePolling() - headers = {} + def _callback(http_response): + polling = AsyncLROBasePolling() + headers = {} - response = Response() - response.headers = headers - response.status_code = 200 + response = Response() + response.headers = headers + response.status_code = 200 - polling._pipeline_response = PipelineResponse( - None, - AsyncioRequestsTransportResponse( + polling._pipeline_response = PipelineResponse( None, - response, - ), - PipelineContext(None) - ) - polling._initial_response = polling._pipeline_response - return polling, headers - - -def test_base_polling_continuation_token(client, polling_response): - polling, _ = polling_response + create_http_response( + http_response, + None, + response, + ), + PipelineContext(None) + ) + polling._initial_response = polling._pipeline_response + return polling, headers + return _callback +@pytest.mark.parametrize("http_response", ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +def test_base_polling_continuation_token(client, polling_response, http_response): + polling, _ = polling_response(http_response) continuation_token = polling.get_continuation_token() assert isinstance(continuation_token, str) @@ -162,12 +174,15 @@ def test_base_polling_continuation_token(client, polling_response): @pytest.mark.asyncio -async def test_post(async_pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_post(async_pipeline_client_builder, deserialization_cb, http_request, http_response): # Test POST LRO with both Location and Operation-Location # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, { @@ -182,12 +197,16 @@ async def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -213,12 +232,16 @@ async def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body=None ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -238,12 +261,15 @@ async def send(request, **kwargs): @pytest.mark.asyncio -async def test_post_resource_location(async_pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_post_resource_location(async_pipeline_client_builder, deserialization_cb, http_request, http_response): # ResourceLocation # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, { @@ -257,12 +283,16 @@ async def send(request, **kwargs): if request.url == 'http://example.org/resource_location': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} @@ -285,7 +315,7 @@ class TestBasePolling(object): convert = re.compile('([a-z0-9])([A-Z])') @staticmethod - def mock_send(method, status, headers=None, body=RESPONSE_BODY): + def mock_send(http_request, http_response, method, status, headers=None, body=RESPONSE_BODY): if headers is None: headers = {} response = Response() @@ -302,19 +332,28 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" - request = CLIENT._request( - response.request.method, - response.request.url, - None, # params - response.request.headers, - body, - None, # form_content - None # stream_content - ) + if is_rest_http_request(http_request): + request = http_request( + response.request.method, + response.request.url, + headers=response.request.headers, + content=body, + ) + else: + request = CLIENT._request( + response.request.method, + response.request.url, + None, # params + response.request.headers, + body, + None, # form_content + None # stream_content + ) return PipelineResponse( request, - AsyncioRequestsTransportResponse( + create_http_response( + http_response, request, response, ), @@ -322,7 +361,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): ) @staticmethod - def mock_update(url, headers=None): + def mock_update(http_request, http_response, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) @@ -354,19 +393,27 @@ def mock_update(url, headers=None): else: raise Exception('URL does not match') - request = CLIENT._request( - response.request.method, - response.request.url, - None, # params - {}, # request has no headers - None, # Request has no body - None, # form_content - None # stream_content - ) + if is_rest_http_request(http_request): + request = http_request( + response.request.method, + response.request.url, + headers={}, + ) + else: + request = CLIENT._request( + response.request.method, + response.request.url, + None, # params + {}, # request has no headers + None, # Request has no body + None, # form_content + None # stream_content + ) return PipelineResponse( request, - AsyncioRequestsTransportResponse( + create_http_response( + http_response, request, response, ), @@ -404,11 +451,15 @@ def mock_deserialization_no_body(pipeline_response): return None @pytest.mark.asyncio -async def test_long_running_put(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_put(http_request, http_response): #TODO: Test custom header field - + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test throw on non LRO related status code - response = TestBasePolling.mock_send('PUT', 1000, {}) + response = TestBasePolling.mock_send( + http_request, + http_response,'PUT', 1000, {}) with pytest.raises(HttpResponseError): await async_poller(CLIENT, response, TestBasePolling.mock_outputs, @@ -420,6 +471,8 @@ async def test_long_running_put(): 'name': TEST_NAME } response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {}, response_body ) @@ -435,6 +488,8 @@ def no_update_allowed(url, headers=None): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'operation-location': ASYNC_URL}) polling_method = AsyncLROBasePolling(0) @@ -446,6 +501,8 @@ def no_update_allowed(url, headers=None): # Test polling location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}) polling_method = AsyncLROBasePolling(0) @@ -458,6 +515,8 @@ def no_update_allowed(url, headers=None): # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}, response_body) polling_method = AsyncLROBasePolling(0) @@ -469,6 +528,8 @@ def no_update_allowed(url, headers=None): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -478,6 +539,8 @@ def no_update_allowed(url, headers=None): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -486,10 +549,14 @@ def no_update_allowed(url, headers=None): AsyncLROBasePolling(0)) @pytest.mark.asyncio -async def test_long_running_patch(): - +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_patch(http_request, http_response): + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -502,6 +569,8 @@ async def test_long_running_patch(): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -514,6 +583,8 @@ async def test_long_running_patch(): # Test polling from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 200, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -526,6 +597,8 @@ async def test_long_running_patch(): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 200, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -538,6 +611,8 @@ async def test_long_running_patch(): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -547,6 +622,8 @@ async def test_long_running_patch(): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -555,9 +632,14 @@ async def test_long_running_patch(): AsyncLROBasePolling(0)) @pytest.mark.asyncio -async def test_long_running_delete(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_delete(http_request, http_response): + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'DELETE', 202, {'operation-location': ASYNC_URL}, body="" @@ -570,10 +652,14 @@ async def test_long_running_delete(): assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None @pytest.mark.asyncio -async def test_long_running_post(): - +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_post(http_request, http_response): + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 201, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -585,6 +671,8 @@ async def test_long_running_post(): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -596,6 +684,8 @@ async def test_long_running_post(): # Test polling from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -608,6 +698,8 @@ async def test_long_running_post(): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -617,6 +709,8 @@ async def test_long_running_post(): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -625,13 +719,17 @@ async def test_long_running_post(): AsyncLROBasePolling(0)) @pytest.mark.asyncio -async def test_long_running_negative(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_negative(http_request, http_response): global LOCATION_BODY global POLLING_STATUS - + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test LRO PUT throws for invalid json LOCATION_BODY = '{' response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller( @@ -645,6 +743,8 @@ async def test_long_running_negative(): LOCATION_BODY = '{\'"}' response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller(CLIENT, response, @@ -656,6 +756,8 @@ async def test_long_running_negative(): LOCATION_BODY = '{' POLLING_STATUS = 203 response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller(CLIENT, response, diff --git a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py index 1ae80db6f60c..b0640c4eee87 100644 --- a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py @@ -3,10 +3,24 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from azure.core.pipeline.transport import HttpRequest, AsyncHttpResponse, AsyncHttpTransport, AioHttpTransport +from six.moves.http_client import HTTPConnection +import time + +try: + from unittest import mock +except ImportError: + import mock + +from azure.core.pipeline.transport import ( + AsyncHttpResponse as PipelineTransportAsyncHttpResponse, AsyncHttpTransport, AioHttpTransport +) +from azure.core.rest import ( + AsyncHttpResponse as RestAsyncHttpResponse, +) from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import AsyncPipeline from azure.core.exceptions import HttpResponseError +from utils import HTTP_REQUESTS, pipeline_transport_and_rest_product import pytest @@ -23,20 +37,32 @@ async def close(self): pass async def send(self, request, **kwargs): pass -class MockResponse(AsyncHttpResponse): +class PipelineTransportMockResponse(PipelineTransportAsyncHttpResponse): def __init__(self, request, body, content_type): - super(MockResponse, self).__init__(request, None) + super(PipelineTransportMockResponse, self).__init__(request, None) self._body = body self.content_type = content_type def body(self): return self._body +class RestMockResponse(RestAsyncHttpResponse): + def __init__(self, request, body, content_type): + super(RestMockResponse, self).__init__(request=request, internal_response=None) + self._content = body + self.content_type = content_type + + @property + def content(self): + return self._content + +MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] @pytest.mark.asyncio -async def test_basic_options_aiohttp(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_options_aiohttp(port, http_request): - request = HttpRequest("OPTIONS", "http://localhost:{}/basic/string".format(port)) + request = http_request("OPTIONS", "http://localhost:{}/basic/string".format(port)) async with AsyncPipeline(AioHttpTransport(), policies=[]) as pipeline: response = await pipeline.run(request) @@ -45,7 +71,8 @@ async def test_basic_options_aiohttp(port): @pytest.mark.asyncio -async def test_multipart_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send(http_request): transport = MockAsyncHttpTransport() class RequestPolicy(object): @@ -53,10 +80,10 @@ async def on_request(self, request): # type: (PipelineRequest) -> None request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -91,7 +118,8 @@ async def on_request(self, request): @pytest.mark.asyncio -async def test_multipart_send_with_context(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_context(http_request): transport = MockAsyncHttpTransport() header_policy = HeadersPolicy() @@ -101,10 +129,10 @@ async def on_request(self, request): # type: (PipelineRequest) -> None request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -142,19 +170,20 @@ async def on_request(self, request): @pytest.mark.asyncio -async def test_multipart_send_with_one_changeset(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_one_changeset(http_request): transport = MockAsyncHttpTransport() requests = [ - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ] - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( *requests, boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" @@ -190,22 +219,23 @@ async def test_multipart_send_with_one_changeset(): @pytest.mark.asyncio -async def test_multipart_send_with_multiple_changesets(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_multiple_changesets(http_request): transport = MockAsyncHttpTransport() - changeset1 = HttpRequest("", "") + changeset1 = http_request("", "") changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - changeset2 = HttpRequest("", "") + changeset2 = http_request("", "") changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3"), + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3"), boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset1, changeset2, @@ -263,19 +293,20 @@ async def test_multipart_send_with_multiple_changesets(): @pytest.mark.asyncio -async def test_multipart_send_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_combination_changeset_first(http_request): transport = MockAsyncHttpTransport() - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -317,17 +348,18 @@ async def test_multipart_send_with_combination_changeset_first(): @pytest.mark.asyncio -async def test_multipart_send_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_combination_changeset_last(http_request): transport = MockAsyncHttpTransport() - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -370,18 +402,19 @@ async def test_multipart_send_with_combination_changeset_last(): @pytest.mark.asyncio -async def test_multipart_send_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_combination_changeset_middle(http_request): transport = MockAsyncHttpTransport() - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container1/blob1"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -423,7 +456,8 @@ async def test_multipart_send_with_combination_changeset_middle(): @pytest.mark.asyncio -async def test_multipart_receive(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_multipart_receive(http_request, mock_response): class ResponsePolicy(object): def on_response(self, request, response): @@ -435,10 +469,10 @@ async def on_response(self, request, response): # type: (PipelineRequest, PipelineResponse) -> None response.http_response.headers['x-ms-async-fun'] = 'true' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -472,7 +506,7 @@ async def on_response(self, request, response): "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -496,14 +530,15 @@ async def on_response(self, request, response): @pytest.mark.asyncio -async def test_multipart_receive_with_one_changeset(): - changeset = HttpRequest("", "") +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_multipart_receive_with_one_changeset(http_request, mock_response): + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -534,7 +569,7 @@ async def test_multipart_receive_with_one_changeset(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -550,20 +585,21 @@ async def test_multipart_receive_with_one_changeset(): @pytest.mark.asyncio -async def test_multipart_receive_with_multiple_changesets(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_multipart_receive_with_multiple_changesets(http_request, mock_response): - changeset1 = HttpRequest("", "") + changeset1 = http_request("", "") changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - changeset2 = HttpRequest("", "") + changeset2 = http_request("", "") changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3") + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset1, changeset2) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -619,7 +655,7 @@ async def test_multipart_receive_with_multiple_changesets(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -636,16 +672,17 @@ async def test_multipart_receive_with_multiple_changesets(): @pytest.mark.asyncio -async def test_multipart_receive_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_multipart_receive_with_combination_changeset_first(http_request, mock_response): - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(changeset, HttpRequest("DELETE", "/container2/blob2")) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(changeset, http_request("DELETE", "/container2/blob2")) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' @@ -685,7 +722,7 @@ async def test_multipart_receive_with_combination_changeset_first(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -699,29 +736,32 @@ async def test_multipart_receive_with_combination_changeset_first(): assert parts[1].status_code == 202 assert parts[2].status_code == 404 -def test_raise_for_status_bad_response(): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_bad_response(mock_response): + response = mock_response( request=None, body=None, content_type=None) response.status_code = 400 with pytest.raises(HttpResponseError): response.raise_for_status() -def test_raise_for_status_good_response(): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_good_response(mock_response): + response = mock_response( request=None, body=None, content_type=None) response.status_code = 200 response.raise_for_status() @pytest.mark.asyncio -async def test_multipart_receive_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_multipart_receive_with_combination_changeset_middle(http_request, mock_response): - changeset = HttpRequest("", "") - changeset.set_multipart_mixed(HttpRequest("DELETE", "/container1/blob1")) + changeset = http_request("", "") + changeset.set_multipart_mixed(http_request("DELETE", "/container1/blob1")) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container2/blob2") ) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -762,7 +802,7 @@ async def test_multipart_receive_with_combination_changeset_middle(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -778,16 +818,17 @@ async def test_multipart_receive_with_combination_changeset_middle(): @pytest.mark.asyncio -async def test_multipart_receive_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_multipart_receive_with_combination_changeset_last(http_request, mock_response): - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(HttpRequest("DELETE", "/container0/blob0"), changeset) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(http_request("DELETE", "/container0/blob0"), changeset) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -828,7 +869,7 @@ async def test_multipart_receive_with_combination_changeset_last(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -844,11 +885,12 @@ async def test_multipart_receive_with_combination_changeset_last(): @pytest.mark.asyncio -async def test_multipart_receive_with_bom(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_multipart_receive_with_bom(http_request, mock_response): - req0 = HttpRequest("DELETE", "/container0/blob0") + req0 = http_request("DELETE", "/container0/blob0") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) body_as_bytes = ( b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\n" @@ -866,7 +908,7 @@ async def test_multipart_receive_with_bom(): b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -883,12 +925,13 @@ async def test_multipart_receive_with_bom(): @pytest.mark.asyncio -async def test_recursive_multipart_receive(): - req0 = HttpRequest("DELETE", "/container0/blob0") - internal_req0 = HttpRequest("DELETE", "/container0/blob0") +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +async def test_recursive_multipart_receive(http_request, mock_response): + req0 = http_request("DELETE", "/container0/blob0") + internal_req0 = http_request("DELETE", "/container0/blob0") req0.set_multipart_mixed(internal_req0) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) internal_body_as_str = ( "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" @@ -914,7 +957,7 @@ async def test_recursive_multipart_receive(): "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6--" ).format(internal_body_as_str) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" diff --git a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py index 0c4e931bc4e2..632ec69a8d5e 100644 --- a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py @@ -14,16 +14,13 @@ PipelineRequest, PipelineContext ) -from azure.core.pipeline.transport import ( - HttpRequest, - HttpResponse, -) from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) +from utils import create_http_response, HTTP_REQUESTS, HTTP_RESPONSES, pipeline_transport_and_rest_product - -def test_http_logger(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -41,8 +38,8 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + universal_request = http_request('GET', 'http://localhost/') + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -136,8 +133,8 @@ def emit(self, record): mock_handler.reset() - -def test_http_logger_operation_level(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger_operation_level(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -156,8 +153,8 @@ def emit(self, record): policy = HttpLoggingPolicy() kwargs={'logger': logger} - universal_request = HttpRequest('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + universal_request = http_request('GET', 'http://localhost/') + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -207,8 +204,8 @@ def emit(self, record): mock_handler.reset() - -def test_http_logger_with_body(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger_with_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -226,9 +223,9 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') universal_request.body = "testbody" - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -249,7 +246,8 @@ def emit(self, record): @pytest.mark.skipif(sys.version_info < (3, 6), reason="types.AsyncGeneratorType does not exist in 3.5") -def test_http_logger_with_generator_body(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger_with_generator_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -267,11 +265,11 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') mock = Mock() mock.__class__ = types.AsyncGeneratorType universal_request.body = mock - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) diff --git a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py index a84d53093ca3..6645c075fba2 100644 --- a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py @@ -38,7 +38,6 @@ ) from azure.core.pipeline.transport import ( AsyncHttpTransport, - HttpRequest, AsyncioRequestsTransport, TrioRequestsTransport, AioHttpTransport @@ -50,7 +49,7 @@ from azure.core.configuration import Configuration from azure.core import AsyncPipelineClient from azure.core.exceptions import AzureError - +from utils import HTTP_REQUESTS import aiohttp import trio @@ -58,7 +57,8 @@ @pytest.mark.asyncio -async def test_sans_io_exception(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_sans_io_exception(http_request): class BrokenSender(AsyncHttpTransport): async def send(self, request, **config): raise ValueError("Broken") @@ -75,7 +75,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): pipeline = AsyncPipeline(BrokenSender(), [SansIOHTTPPolicy()]) - req = HttpRequest('GET', '/') + req = http_request('GET', '/') with pytest.raises(ValueError): await pipeline.run(req) @@ -89,9 +89,10 @@ def on_exception(self, requests, **kwargs): await pipeline.run(req) @pytest.mark.asyncio -async def test_basic_aiohttp(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_aiohttp(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -104,10 +105,11 @@ async def test_basic_aiohttp(port): assert isinstance(response.http_response.status_code, int) @pytest.mark.asyncio -async def test_basic_aiohttp_separate_session(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_aiohttp_separate_session(port, http_request): session = aiohttp.ClientSession() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -123,9 +125,10 @@ async def test_basic_aiohttp_separate_session(port): await transport.session.close() @pytest.mark.asyncio -async def test_basic_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -185,9 +188,10 @@ def test_pass_in_http_logging_policy(): assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) @pytest.mark.asyncio -async def test_conf_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_conf_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -197,10 +201,11 @@ async def test_conf_async_requests(port): assert isinstance(response.http_response.status_code, int) -def test_conf_async_trio_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_conf_async_trio_requests(port, http_request): async def do(): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -212,7 +217,8 @@ async def do(): assert isinstance(response.http_response.status_code, int) @pytest.mark.asyncio -async def test_retry_without_http_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_retry_without_http_response(http_request): class NaughtyPolicy(AsyncHTTPPolicy): def send(*args): raise AzureError('boo') @@ -220,7 +226,7 @@ def send(*args): policies = [AsyncRetryPolicy(), NaughtyPolicy()] pipeline = AsyncPipeline(policies=policies, transport=None) with pytest.raises(AzureError): - await pipeline.run(HttpRequest('GET', url='https://foo.bar')) + await pipeline.run(http_request('GET', url='https://foo.bar')) @pytest.mark.asyncio async def test_add_custom_policy(): diff --git a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py index 55856ae6ca94..8a8199e65429 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py @@ -5,13 +5,14 @@ # ------------------------------------------------------------------------- import json -from azure.core.pipeline.transport import AsyncioRequestsTransport, HttpRequest - +from azure.core.pipeline.transport import AsyncioRequestsTransport +from utils import HTTP_REQUESTS import pytest @pytest.mark.asyncio -async def test_async_gen_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_async_gen_data(port, http_request): class AsyncGen: def __init__(self): self._range = iter([b"azerty"]) @@ -26,13 +27,14 @@ async def __anext__(self): raise StopAsyncIteration async with AsyncioRequestsTransport() as transport: - req = HttpRequest('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) + req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" @pytest.mark.asyncio -async def test_send_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_send_data(port, http_request): async with AsyncioRequestsTransport() as transport: - req = HttpRequest('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") + req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" diff --git a/sdk/core/azure-core/tests/async_tests/test_request_trio.py b/sdk/core/azure-core/tests/async_tests/test_request_trio.py index 11a0058404f5..7fafa0d41c28 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_trio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_trio.py @@ -5,13 +5,15 @@ # ------------------------------------------------------------------------- import json -from azure.core.pipeline.transport import TrioRequestsTransport, HttpRequest +from azure.core.pipeline.transport import TrioRequestsTransport +from utils import HTTP_REQUESTS import pytest @pytest.mark.trio -async def test_async_gen_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_async_gen_data(port, http_request): class AsyncGen: def __init__(self): self._range = iter([b"azerty"]) @@ -26,14 +28,15 @@ async def __anext__(self): raise StopAsyncIteration async with TrioRequestsTransport() as transport: - req = HttpRequest('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) + req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" @pytest.mark.trio -async def test_send_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_send_data(port, http_request): async with TrioRequestsTransport() as transport: - req = HttpRequest('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") + req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py b/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py index 9f5052cdb4bd..bc9507aa8daa 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py @@ -5,10 +5,11 @@ # ------------------------------------------------------------------------- from azure.core.pipeline.transport import AsyncioRequestsTransport from azure.core.rest import HttpRequest +from azure.core.pipeline.transport._requests_asyncio import RestAsyncioRequestsTransportResponse from rest_client_async import AsyncTestRestClient import pytest - +from utils import readonly_checks @pytest.mark.asyncio async def test_async_gen_data(port): @@ -39,3 +40,14 @@ async def test_send_data(port): response = await client.send_request(request) assert response.json()['data'] == "azerty" + +@pytest.mark.asyncio +async def test_readonly(port): + """Make sure everything that is readonly is readonly""" + async with AsyncioRequestsTransport() as transport: + client = AsyncTestRestClient(port, transport=transport) + response = await client.send_request(HttpRequest("GET", "/health")) + response.raise_for_status() + + assert isinstance(response, RestAsyncioRequestsTransportResponse) + readonly_checks(response) diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py index 40587252a14f..81c8e4188774 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py @@ -8,8 +8,11 @@ # Thank you httpx for your wonderful tests! import io import pytest -from azure.core.rest import HttpRequest +import zlib +from azure.core.rest import HttpRequest, AsyncHttpResponse +from azure.core.pipeline.transport._aiohttp import RestAioHttpTransportResponse from azure.core.exceptions import HttpResponseError +from utils import readonly_checks @pytest.fixture def send_request(client): @@ -169,18 +172,6 @@ async def test_response_no_charset_with_iso_8859_1_content(send_request): assert response.text() == "Accented: �sterreich" assert response.encoding is None -# NOTE: aiohttp isn't liking this -# @pytest.mark.asyncio -# async def test_response_set_explicit_encoding(send_request): -# response = await send_request( -# request=HttpRequest("GET", "/encoding/latin-1-with-utf-8"), -# ) -# assert response.headers["Content-Type"] == "text/plain; charset=utf-8" -# response.encoding = "latin-1" -# await response.read() -# assert response.text() == "Latin 1: ÿ" -# assert response.encoding == "latin-1" - @pytest.mark.asyncio async def test_json(send_request): response = await send_request( @@ -269,6 +260,28 @@ async def test_text_and_encoding(send_request): assert response.text("latin-1") == 'ð\x9f\x91©' == response.content.decode("latin-1") assert response.encoding == "utf-16" +@pytest.mark.asyncio +async def test_aiohttp_response_decompression(send_request): + response = await send_request(HttpRequest("GET", "/decompression/gzip/pass")) + + expected = b'{"id":"e7877039-1376-4dcd-9b0a-192897cff780","createdDateTimeUtc":' \ + b'"2021-05-07T17:35:36.3121065Z","lastActionDateTimeUtc":' \ + b'"2021-05-07T17:35:36.3121069Z","status":"NotStarted",' \ + b'"summary":{"total":0,"failed":0,"success":0,"inProgress":0,' \ + b'"notYetStarted":0,"cancelled":0,"totalCharacterCharged":0}}' + assert response.content == expected + assert response.body() == expected + +@pytest.mark.asyncio +async def test_aiohttp_response_decompression_negative(send_request): + # here our behavior differs from the old aiohttp transport behavior, but is the same as httpx behavior + # current aiohttp transport will not fail until users access the response's body + # new aiohttp transport behavior (same as httpx), will fail as part of reading the response + # when we make the call + with pytest.raises(zlib.error): + await send_request(HttpRequest("GET", "/decompression/gzip/fail")) + + # @pytest.mark.asyncio # async def test_multipart_encode_non_seekable_filelike(send_request): # """ @@ -294,4 +307,17 @@ async def test_text_and_encoding(send_request): # "/multipart/non-seekable-filelike", # files=files, # ) -# await send_request(request) \ No newline at end of file +# await send_request(request) + +def test_initialize_response_abc(): + with pytest.raises(TypeError) as ex: + AsyncHttpResponse() + assert "Can't instantiate abstract class" in str(ex) + +@pytest.mark.asyncio +async def test_readonly(send_request): + """Make sure everything that is readonly is readonly""" + response = await send_request(HttpRequest("GET", "/health")) + + assert isinstance(response, RestAioHttpTransportResponse) + readonly_checks(response) diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py b/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py index ba9981df6697..ffbb0af0f30d 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py @@ -5,8 +5,9 @@ # ------------------------------------------------------------------------- from azure.core.pipeline.transport import TrioRequestsTransport from azure.core.rest import HttpRequest +from azure.core.pipeline.transport._requests_trio import RestTrioRequestsTransportResponse from rest_client_async import AsyncTestRestClient - +from utils import readonly_checks import pytest @@ -39,3 +40,15 @@ async def test_send_data(port): response = await client.send_request(request) assert response.json()['data'] == "azerty" + +@pytest.mark.trio +async def test_readonly(port): + """Make sure everything that is readonly is readonly""" + async with TrioRequestsTransport() as transport: + request = HttpRequest('GET', 'http://localhost:{}/health'.format(port)) + client = AsyncTestRestClient(port, transport=transport) + response = await client.send_request(HttpRequest("GET", "/health")) + response.raise_for_status() + + assert isinstance(response, RestTrioRequestsTransportResponse) + readonly_checks(response) diff --git a/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py b/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py index faa0944e6399..1ff9c8bd9275 100644 --- a/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py @@ -23,15 +23,14 @@ RetryMode, ) from azure.core.pipeline import AsyncPipeline, PipelineResponse -from azure.core.pipeline.transport import ( - HttpRequest, - HttpResponse, - AsyncHttpTransport, -) +from azure.core.pipeline.transport import AsyncHttpTransport import tempfile import os import time import asyncio +from utils import create_http_response, HTTP_REQUESTS, HTTP_RESPONSES, pipeline_transport_and_rest_product + +RETRY_AFTER_INPUTS = ['0', '800', '1000', '1200'] def test_retry_code_class_variables(): retry_policy = AsyncRetryPolicy() @@ -59,11 +58,11 @@ def test_retry_types(): backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_retry_after(retry_after_input): +@pytest.mark.parametrize("retry_after_input,http_request,http_response", pipeline_transport_and_rest_product(RETRY_AFTER_INPUTS, HTTP_REQUESTS, HTTP_RESPONSES)) +def test_retry_after(retry_after_input, http_request, http_response): retry_policy = AsyncRetryPolicy() - request = HttpRequest("GET", "http://localhost") - response = HttpResponse(request, None) + request = http_request("GET", "http://localhost") + response = create_http_response(http_response, request, None) response.headers["retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) @@ -77,11 +76,11 @@ def test_retry_after(retry_after_input): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_x_ms_retry_after(retry_after_input): +@pytest.mark.parametrize("retry_after_input,http_request,http_response", pipeline_transport_and_rest_product(RETRY_AFTER_INPUTS, HTTP_REQUESTS, HTTP_RESPONSES)) +def test_x_ms_retry_after(retry_after_input, http_request, http_response): retry_policy = AsyncRetryPolicy() - request = HttpRequest("GET", "http://localhost") - response = HttpResponse(request, None) + request = http_request("GET", "http://localhost") + response = create_http_response(http_response, request, None) response.headers["x-ms-retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) @@ -96,7 +95,8 @@ def test_x_ms_retry_after(retry_after_input): assert retry_after == float(retry_after_input) @pytest.mark.asyncio -async def test_retry_on_429(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +async def test_retry_on_429(http_request, http_response): class MockTransport(AsyncHttpTransport): def __init__(self): self._count = 0 @@ -109,11 +109,11 @@ async def open(self): async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse self._count += 1 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 429 return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost') http_retry = AsyncRetryPolicy(retry_total = 1) transport = MockTransport() pipeline = AsyncPipeline(transport, [http_retry]) @@ -121,7 +121,8 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe assert transport._count == 2 @pytest.mark.asyncio -async def test_no_retry_on_201(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +async def test_no_retry_on_201(http_request, http_response): class MockTransport(AsyncHttpTransport): def __init__(self): self._count = 0 @@ -134,13 +135,13 @@ async def open(self): async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse self._count += 1 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 201 headers = {"Retry-After": "1"} response.headers = headers return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost') http_retry = AsyncRetryPolicy(retry_total = 1) transport = MockTransport() pipeline = AsyncPipeline(transport, [http_retry]) @@ -148,7 +149,8 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe assert transport._count == 1 @pytest.mark.asyncio -async def test_retry_seekable_stream(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +async def test_retry_seekable_stream(http_request, http_response): class MockTransport(AsyncHttpTransport): def __init__(self): self._first = True @@ -166,19 +168,20 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe raise AzureError('fail on first') position = request.body.tell() assert position == 0 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 400 return response data = BytesIO(b"Lots of dataaaa") - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost') http_request.set_streamed_data_body(data) http_retry = AsyncRetryPolicy(retry_total = 1) pipeline = AsyncPipeline(MockTransport(), [http_retry]) await pipeline.run(http_request) @pytest.mark.asyncio -async def test_retry_seekable_file(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +async def test_retry_seekable_file(http_request, http_response): class MockTransport(AsyncHttpTransport): def __init__(self): self._first = True @@ -202,14 +205,14 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe if name and body and hasattr(body, 'read'): position = body.tell() assert not position - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 400 return response file = tempfile.NamedTemporaryFile(delete=False) file.write(b'Lots of dataaaa') file.close() - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost') headers = {'Content-Type': "multipart/form-data"} http_request.headers = headers with open(file.name, 'rb') as f: @@ -225,7 +228,8 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe @pytest.mark.asyncio -async def test_retry_timeout(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_retry_timeout(http_request): timeout = 1 def send(request, **kwargs): @@ -241,17 +245,18 @@ def send(request, **kwargs): pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(timeout=timeout)]) with pytest.raises(ServiceResponseTimeoutError): - await pipeline.run(HttpRequest("GET", "http://localhost/")) + await pipeline.run(http_request("GET", "http://localhost")) @pytest.mark.asyncio -async def test_timeout_defaults(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +async def test_timeout_defaults(http_request, http_response): """When "timeout" is not set, the policy should not override the transport's timeout configuration""" async def send(request, **kwargs): for arg in ("connection_timeout", "read_timeout"): assert arg not in kwargs, "policy should defer to transport configuration when not given a timeout" - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 200 return response @@ -262,18 +267,17 @@ async def send(request, **kwargs): ) pipeline = AsyncPipeline(transport, [AsyncRetryPolicy()]) - await pipeline.run(HttpRequest("GET", "http://localhost/")) + await pipeline.run(http_request("GET", "http://localhost")) assert transport.send.call_count == 1, "policy should not retry: its first send succeeded" +TRANSPORT_AND_EXPECTED_TIMEOUT_ERRORS = [(ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)] @pytest.mark.asyncio -@pytest.mark.parametrize( - "transport_error,expected_timeout_error", - ((ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)), -) -async def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error): +@pytest.mark.parametrize("transport_and_expected_timeout_errors,http_request", pipeline_transport_and_rest_product(TRANSPORT_AND_EXPECTED_TIMEOUT_ERRORS, HTTP_REQUESTS)) +async def test_does_not_sleep_after_timeout(transport_and_expected_timeout_errors, http_request): # With default settings policy will sleep twice before exhausting its retries: 1.6s, 3.2s. # It should not sleep the second time when given timeout=1 + transport_error, expected_timeout_error = transport_and_expected_timeout_errors timeout = 1 transport = Mock( @@ -284,6 +288,6 @@ async def test_does_not_sleep_after_timeout(transport_error, expected_timeout_er pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(timeout=timeout)]) with pytest.raises(expected_timeout_error): - await pipeline.run(HttpRequest("GET", "http://localhost/")) + await pipeline.run(http_request("GET", "http://localhost")) assert transport.sleep.call_count == 1 diff --git a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py index d90b5b15b4c9..bea14200abd4 100644 --- a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py @@ -4,19 +4,18 @@ # ------------------------------------ import requests from azure.core.pipeline.transport import ( - HttpRequest, - AsyncHttpResponse, AsyncHttpTransport, - AsyncioRequestsTransportResponse, AioHttpTransport, ) from azure.core.pipeline import AsyncPipeline, PipelineResponse from azure.core.pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator from unittest import mock +from utils import HTTP_REQUESTS, HTTP_RESPONSES, ASYNCIO_REQUESTS_TRANSPORT_RESPONSES, create_http_response, pipeline_transport_and_rest_product import pytest @pytest.mark.asyncio -async def test_connection_error_response(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +async def test_connection_error_response(http_request, http_response): class MockSession(object): def __init__(self): self.auto_decompress = True @@ -38,8 +37,8 @@ async def open(self): pass async def send(self, request, **kwargs): - request = HttpRequest('GET', 'http://localhost/') - response = AsyncHttpResponse(request, None) + request = http_request('GET', 'http://localhost') + response = create_http_response(http_response, request, None) response.status_code = 200 return response @@ -65,9 +64,9 @@ class AsyncMock(mock.MagicMock): async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost') pipeline = AsyncPipeline(MockTransport()) - http_response = AsyncHttpResponse(http_request, None) + http_response = create_http_response(http_response, http_request, None) http_response.internal_response = MockInternalResponse() stream = AioHttpStreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch('asyncio.sleep', new_callable=AsyncMock): @@ -75,7 +74,8 @@ async def __call__(self, *args, **kwargs): await stream.__anext__() @pytest.mark.asyncio -async def test_response_streaming_error_behavior(): +@pytest.mark.parametrize("http_response", ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +async def test_response_streaming_error_behavior(http_response): # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723 block_size = 103 total_response_size = 500 @@ -111,7 +111,8 @@ def close(self): req_response.raw = FakeStreamWithConnectionError() - response = AsyncioRequestsTransportResponse( + response = create_http_response( + http_response, req_request, req_response, block_size, diff --git a/sdk/core/azure-core/tests/async_tests/test_streaming_async.py b/sdk/core/azure-core/tests/async_tests/test_streaming_async.py index 6b4c9ac3b912..8ca4413be936 100644 --- a/sdk/core/azure-core/tests/async_tests/test_streaming_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_streaming_async.py @@ -27,16 +27,18 @@ import pytest from azure.core import AsyncPipelineClient from azure.core.exceptions import DecodeError +from utils import HTTP_REQUESTS @pytest.mark.asyncio -async def test_decompress_plain_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -47,14 +49,15 @@ async def test_decompress_plain_no_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_compress_plain_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -65,14 +68,15 @@ async def test_compress_plain_no_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_decompress_compressed_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -86,14 +90,15 @@ async def test_decompress_compressed_no_header(): pass @pytest.mark.asyncio -async def test_compress_compressed_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -107,7 +112,8 @@ async def test_compress_compressed_no_header(): pass @pytest.mark.asyncio -async def test_decompress_plain_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_plain_header(http_request): # expect error import zlib account_name = "coretests" @@ -115,7 +121,7 @@ async def test_decompress_plain_header(): url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -128,14 +134,15 @@ async def test_decompress_plain_header(): pass @pytest.mark.asyncio -async def test_compress_plain_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_plain_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -146,14 +153,15 @@ async def test_compress_plain_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_decompress_compressed_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_compressed_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -164,14 +172,15 @@ async def test_decompress_compressed_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_compress_compressed_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_compressed_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) diff --git a/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py b/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py index 62b606218a0e..0a29f816fba5 100644 --- a/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py @@ -15,12 +15,12 @@ import pytest from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.policies import HTTPPolicy -from azure.core.pipeline.transport import HttpTransport, HttpRequest +from azure.core.pipeline.transport import HttpTransport from azure.core.settings import settings from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from tracing_common import FakeSpan - +from utils import HTTP_REQUESTS @pytest.fixture(scope="module") def fake_span(): @@ -29,9 +29,9 @@ def fake_span(): class MockClient: @distributed_trace - def __init__(self, policies=None, assert_current_span=False): + def __init__(self, http_request, policies=None, assert_current_span=False): time.sleep(0.001) - self.request = HttpRequest("GET", "http://localhost") + self.request = http_request("GET", "http://localhost") if policies is None: policies = [] policies.append(mock.Mock(spec=HTTPPolicy, send=self.verify_request)) @@ -88,9 +88,10 @@ async def raising_exception(self): class TestAsyncDecorator(object): @pytest.mark.asyncio - async def test_decorator_tracing_attr(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_decorator_tracing_attr(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.tracing_attr() assert len(parent.children) == 2 @@ -100,9 +101,10 @@ async def test_decorator_tracing_attr(self): @pytest.mark.asyncio - async def test_decorator_has_different_name(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_decorator_has_different_name(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.check_name_is_different() assert len(parent.children) == 2 assert parent.children[0].name == "MockClient.__init__" @@ -110,9 +112,10 @@ async def test_decorator_has_different_name(self): @pytest.mark.asyncio - async def test_used(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_used(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient(policies=[]) + client = MockClient(http_request, policies=[]) await client.get_foo(parent_span=parent) await client.get_foo() @@ -126,9 +129,10 @@ async def test_used(self): @pytest.mark.asyncio - async def test_span_merge_span(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_span_merge_span(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.merge_span_method() await client.no_merge_span_method() @@ -142,9 +146,10 @@ async def test_span_merge_span(self): @pytest.mark.asyncio - async def test_span_complicated(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_span_complicated(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.make_request(2) with parent.span("child") as child: time.sleep(0.001) @@ -163,11 +168,12 @@ async def test_span_complicated(self): assert not parent.children[3].children @pytest.mark.asyncio - async def test_span_with_exception(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_span_with_exception(self, http_request): """Assert that if an exception is raised, the next sibling method is actually a sibling span. """ with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) try: await client.raising_exception() except: diff --git a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py index fa65fc3353c5..be98b685c758 100644 --- a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py @@ -23,8 +23,8 @@ # THE SOFTWARE. # #-------------------------------------------------------------------------- +from itertools import product from azure.core.pipeline.transport import ( - HttpRequest, AioHttpTransport, AioHttpTransportResponse, AsyncioRequestsTransport, @@ -34,12 +34,15 @@ import trio import pytest +from unittest import mock +from utils import HTTP_REQUESTS, AIOHTTP_TRANSPORT_RESPONSES, create_http_response, is_rest_http_response @pytest.mark.asyncio -async def test_basic_aiohttp(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_aiohttp(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with AioHttpTransport() as sender: response = await sender.send(request) assert response.body() is not None @@ -48,18 +51,20 @@ async def test_basic_aiohttp(port): assert isinstance(response.status_code, int) @pytest.mark.asyncio -async def test_aiohttp_auto_headers(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_aiohttp_auto_headers(port, http_request): - request = HttpRequest("POST", "http://localhost:{}/basic/string".format(port)) + request = http_request("POST", "http://localhost:{}/basic/string".format(port)) async with AioHttpTransport() as sender: response = await sender.send(request) auto_headers = response.internal_response.request_info.headers assert 'Content-Type' not in auto_headers @pytest.mark.asyncio -async def test_basic_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with AsyncioRequestsTransport() as sender: response = await sender.send(request) assert response.body() is not None @@ -67,19 +72,21 @@ async def test_basic_async_requests(port): assert isinstance(response.status_code, int) @pytest.mark.asyncio -async def test_conf_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_conf_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with AsyncioRequestsTransport() as sender: response = await sender.send(request) assert response.body() is not None assert isinstance(response.status_code, int) -def test_conf_async_trio_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_conf_async_trio_requests(port, http_request): async def do(): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with TrioRequestsTransport() as sender: return await sender.send(request) assert response.body() is not None @@ -88,7 +95,7 @@ async def do(): assert isinstance(response.status_code, int) -def _create_aiohttp_response(body_bytes, headers=None): +def _create_aiohttp_response(http_response, body_bytes, headers=None): class MockAiohttpClientResponse(aiohttp.ClientResponse): def __init__(self, body_bytes, headers=None): self._body = body_bytes @@ -99,21 +106,27 @@ def __init__(self, body_bytes, headers=None): req_response = MockAiohttpClientResponse(body_bytes, headers) - response = AioHttpTransportResponse( + response = create_http_response( + http_response, None, # Don't need a request here req_response ) - response._body = body_bytes + if is_rest_http_response(http_response): + response._content = body_bytes + else: + response._body = body_bytes return response @pytest.mark.asyncio -async def test_aiohttp_response_text(): +@pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) +async def test_aiohttp_response_text(http_response): for encoding in ["utf-8", "utf-8-sig", None]: res = _create_aiohttp_response( + http_response, b'\xef\xbb\xbf56', {'Content-Type': 'text/plain'} ) @@ -121,7 +134,9 @@ async def test_aiohttp_response_text(): @pytest.mark.asyncio async def test_aiohttp_response_decompression(): + # not parametrizing this test, added a test with the same name for rest testing res = _create_aiohttp_response( + AioHttpTransportResponse, b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x04\x00\x8d\x8d\xb1n\xc30\x0cD" b"\xff\x85s\x14HVlY\xda\x8av.\n4\x1d\x9a\x8d\xa1\xe5D\x80m\x01\x12=" b"\x14A\xfe\xbd\x92\x81d\xceB\x1c\xef\xf8\x8e7\x08\x038\xf0\xa67Fj+" @@ -142,9 +157,11 @@ async def test_aiohttp_response_decompression(): assert res.body() == expect, "Decompression didn't work" @pytest.mark.asyncio -async def test_aiohttp_response_decompression_negtive(): +async def test_aiohttp_response_decompression_negative(): import zlib + # not parametrizing this test, added a test with the same name for rest testing res = _create_aiohttp_response( + AioHttpTransportResponse, b"\xff\x85s\x14HVlY\xda\x8av.\n4\x1d\x9a\x8d\xa1\xe5D\x80m\x01\x12=" b"\x14A\xfe\xbd\x92\x81d\xceB\x1c\xef\xf8\x8e7\x08\x038\xf0\xa67Fj+" b"\x946\x9d8\x0c4\x08{\x96(\x94mzkh\x1cM/a\x07\x94<\xb2\x1f>\xca8\x86" @@ -158,11 +175,16 @@ async def test_aiohttp_response_decompression_negtive(): with pytest.raises(zlib.error): body = res.body() -def test_repr(): +@pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) +def test_repr(http_response): res = _create_aiohttp_response( + http_response, b'\xef\xbb\xbf56', {} ) res.content_type = "text/plain" - - assert repr(res) == "" + if is_rest_http_response(http_response): + cls_name = "AsyncHttpResponse" + else: + cls_name = "AioHttpTransportResponse" + assert repr(res) == f"<{cls_name}: 200 OK, Content-Type: text/plain>" diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index de029e8ea352..1d944285cda4 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -15,8 +15,7 @@ AzureKeyCredentialPolicy, AzureSasCredentialPolicy, ) -from azure.core.pipeline.transport import HttpRequest - +from itertools import product import pytest try: @@ -24,9 +23,10 @@ except ImportError: # python < 3.3 from mock import Mock +from utils import HTTP_REQUESTS - -def test_bearer_policy_adds_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_adds_header(http_request): """The bearer token policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) expected_token = AccessToken("expected_token", 2524608000) @@ -39,19 +39,20 @@ def verify_authorization_header(request): policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert fake_credential.get_token.call_count == 1 - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) # Didn't need a new token assert fake_credential.get_token.call_count == 1 -def test_bearer_policy_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_send(http_request): """The bearer token policy should invoke the next policy's send method and return the result""" - expected_request = HttpRequest("GET", "https://spam.eggs") + expected_request = http_request("GET", "https://spam.eggs") expected_response = Mock() def verify_request(request): @@ -65,15 +66,16 @@ def verify_request(request): assert response is expected_response -def test_bearer_policy_token_caching(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_token_caching(http_request): good_for_one_hour = AccessToken("token", time.time() + 3600) credential = Mock(get_token=Mock(return_value=good_for_one_hour)) pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 1 # policy has no token at first request -> it should call get_token - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache expired_token = AccessToken("token", time.time()) @@ -81,14 +83,15 @@ def test_bearer_policy_token_caching(): credential.get_token.return_value = expired_token pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 1 - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 2 # token expired -> policy should call get_token -def test_bearer_policy_optionally_enforces_https(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_optionally_enforces_https(http_request): """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" def assert_option_popped(request, **kwargs): @@ -102,20 +105,21 @@ def assert_option_popped(request, **kwargs): # by default and when enforce_https=True, the policy should raise when given an insecure request with pytest.raises(ServiceRequestError): - pipeline.run(HttpRequest("GET", "http://not.secure")) + pipeline.run(http_request("GET", "http://not.secure")) with pytest.raises(ServiceRequestError): - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + pipeline.run(http_request("GET", "http://not.secure"), enforce_https=True) # when enforce_https=False, an insecure request should pass - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) # https requests should always pass - pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) - pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) - pipeline.run(HttpRequest("GET", "https://secure")) + pipeline.run(http_request("GET", "https://secure"), enforce_https=False) + pipeline.run(http_request("GET", "https://secure"), enforce_https=True) + pipeline.run(http_request("GET", "https://secure")) -def test_bearer_policy_preserves_enforce_https_opt_out(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_preserves_enforce_https_opt_out(http_request): """The policy should use request context to preserve an opt out from https enforcement""" class ContextValidator(SansIOHTTPPolicy): @@ -127,10 +131,11 @@ def on_request(self, request): policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) -def test_bearer_policy_default_context(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_default_context(http_request): """The policy should call get_token with the scopes given at construction, and no keyword arguments, by default""" expected_scope = "scope" token = AccessToken("", 0) @@ -138,12 +143,13 @@ def test_bearer_policy_default_context(): policy = BearerTokenCredentialPolicy(credential, expected_scope) pipeline = Pipeline(transport=Mock(), policies=[policy]) - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) credential.get_token.assert_called_once_with(expected_scope) -def test_bearer_policy_context_unmodified_by_default(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_context_unmodified_by_default(http_request): """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" class ContextValidator(SansIOHTTPPolicy): @@ -154,10 +160,11 @@ def on_request(self, request): policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "https://secure")) + pipeline.run(http_request("GET", "https://secure")) -def test_bearer_policy_calls_on_challenge(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_calls_on_challenge(http_request): """BearerTokenCredentialPolicy should call its on_challenge method when it receives an authentication challenge""" class TestPolicy(BearerTokenCredentialPolicy): @@ -173,12 +180,13 @@ def on_challenge(self, request, challenge): transport = Mock(send=Mock(return_value=response)) pipeline = Pipeline(transport=transport, policies=policies) - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) assert TestPolicy.called -def test_bearer_policy_cannot_complete_challenge(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_cannot_complete_challenge(http_request): """BearerTokenCredentialPolicy should return the 401 response when it can't complete its challenge""" expected_scope = "scope" @@ -189,14 +197,15 @@ def test_bearer_policy_cannot_complete_challenge(): policies = [BearerTokenCredentialPolicy(credential, expected_scope)] pipeline = Pipeline(transport=transport, policies=policies) - response = pipeline.run(HttpRequest("GET", "https://localhost")) + response = pipeline.run(http_request("GET", "https://localhost")) assert response.http_response is expected_response assert transport.send.call_count == 1 credential.get_token.assert_called_once_with(expected_scope) -def test_bearer_policy_calls_sansio_methods(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_calls_sansio_methods(http_request): """BearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOHTTPPolicyRunner""" class TestPolicy(BearerTokenCredentialPolicy): @@ -216,7 +225,7 @@ def send(self, request): transport = Mock(send=Mock(return_value=Mock(status_code=200))) pipeline = Pipeline(transport=transport, policies=[policy]) - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) policy.on_request.assert_called_once_with(policy.request) policy.on_response.assert_called_once_with(policy.request, policy.response) @@ -230,7 +239,7 @@ class TestException(Exception): policy = TestPolicy(credential, "scope") pipeline = Pipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) policy.on_exception.assert_called_once_with(policy.request) # ...or the second @@ -246,7 +255,7 @@ def raise_the_second_time(*args, **kwargs): transport = Mock(send=Mock(wraps=raise_the_second_time)) pipeline = Pipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) assert transport.send.call_count == 2 policy.on_challenge.assert_called_once() policy.on_exception.assert_called_once_with(policy.request) @@ -273,7 +282,8 @@ def test_key_vault_regression(): assert policy._token.token == token -def test_azure_key_credential_policy(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_azure_key_credential_policy(http_request): """Tests to see if we can create an AzureKeyCredentialPolicy""" key_header = "api_key" @@ -287,8 +297,7 @@ def verify_authorization_header(request): credential_policy = AzureKeyCredentialPolicy(credential=credential, name=key_header) pipeline = Pipeline(transport=transport, policies=[credential_policy]) - pipeline.run(HttpRequest("GET", "https://test_key_credential")) - + pipeline.run(http_request("GET", "https://test_key_credential")) def test_azure_key_credential_policy_raises(): """Tests AzureKeyCredential and AzureKeyCredentialPolicy raises with non-string input parameters.""" @@ -301,7 +310,6 @@ def test_azure_key_credential_policy_raises(): with pytest.raises(TypeError): credential_policy = AzureKeyCredentialPolicy(credential=credential, name=key_header) - def test_azure_key_credential_updates(): """Tests AzureKeyCredential updates""" api_key = "original" @@ -313,7 +321,7 @@ def test_azure_key_credential_updates(): credential.update(api_key) assert credential.key == api_key -@pytest.mark.parametrize("sas,url,expected_url", [ +sas_url_expected_url = [ ("sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), ("?sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), ("sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"), @@ -322,9 +330,11 @@ def test_azure_key_credential_updates(): ("?sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"), ("sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), ("?sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), -]) -def test_azure_sas_credential_policy(sas, url, expected_url): +] +@pytest.mark.parametrize("sas_url_expected_url,http_request", product(sas_url_expected_url, HTTP_REQUESTS)) +def test_azure_sas_credential_policy(sas_url_expected_url, http_request): """Tests to see if we can create an AzureSasCredentialPolicy""" + sas, url, expected_url = sas_url_expected_url def verify_authorization(request): assert request.url == expected_url @@ -334,7 +344,7 @@ def verify_authorization(request): credential_policy = AzureSasCredentialPolicy(credential=credential) pipeline = Pipeline(transport=transport, policies=[credential_policy]) - pipeline.run(HttpRequest("GET", url)) + pipeline.run(http_request("GET", url)) def test_azure_sas_credential_updates(): """Tests AzureSasCredential updates""" diff --git a/sdk/core/azure-core/tests/test_base_polling.py b/sdk/core/azure-core/tests/test_base_polling.py index f5c1db343302..8307c3e7cb1b 100644 --- a/sdk/core/azure-core/tests/test_base_polling.py +++ b/sdk/core/azure-core/tests/test_base_polling.py @@ -44,11 +44,19 @@ from azure.core.exceptions import DecodeError, HttpResponseError from azure.core import PipelineClient from azure.core.pipeline import PipelineResponse, Pipeline, PipelineContext -from azure.core.pipeline.transport import RequestsTransportResponse, HttpTransport +from azure.core.pipeline.transport import HttpTransport from azure.core.polling.base_polling import LROBasePolling from azure.core.pipeline.policies._utils import _FixedOffset +from utils import ( + HTTP_REQUESTS, + REQUESTS_TRANSPORT_RESPONSES, + create_http_response, + pipeline_transport_and_rest_product, + is_rest_http_request, +) + class SimpleResource: """An implementation of Python 3 SimpleNamespace. @@ -83,7 +91,9 @@ class BadEndpointError(Exception): CLIENT = PipelineClient("http://example.org") def mock_run(client_self, request, **kwargs): - return TestBasePolling.mock_update(request.url, request.headers) + return TestBasePolling.mock_update(client_self.http_request_type, client_self.http_response_type, request.url, request.headers) +CLIENT.http_request_type = None +CLIENT.http_response_type = None CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -126,27 +136,30 @@ def cb(pipeline_response): @pytest.fixture def polling_response(): - polling = LROBasePolling() - headers = {} + def _callback(http_response): + polling = LROBasePolling() + headers = {} - response = Response() - response.headers = headers - response.status_code = 200 + response = Response() + response.headers = headers + response.status_code = 200 - polling._pipeline_response = PipelineResponse( - None, - RequestsTransportResponse( + polling._pipeline_response = PipelineResponse( None, - response, - ), - PipelineContext(None) - ) - polling._initial_response = polling._pipeline_response - return polling, headers - + create_http_response( + http_response, + None, + response, + ), + PipelineContext(None) + ) + polling._initial_response = polling._pipeline_response + return polling, headers + return _callback -def test_base_polling_continuation_token(client, polling_response): - polling, _ = polling_response +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_base_polling_continuation_token(client, polling_response, http_response): + polling, _ = polling_response(http_response) continuation_token = polling.get_continuation_token() assert isinstance(continuation_token, six.string_types) @@ -159,17 +172,18 @@ def test_base_polling_continuation_token(client, polling_response): new_polling = LROBasePolling() new_polling.initialize(*polling_args) - -def test_delay_extraction_int(polling_response): - polling, headers = polling_response +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_delay_extraction_int(polling_response, http_response): + polling, headers = polling_response(http_response) headers['Retry-After'] = "10" assert polling._extract_delay() == 10 @pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="https://stackoverflow.com/questions/11146725/isinstance-and-mocking") -def test_delay_extraction_httpdate(polling_response): - polling, headers = polling_response +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_delay_extraction_httpdate(polling_response, http_response): + polling, headers = polling_response(http_response) # Test that I need to retry exactly one hour after, by mocking "now" headers['Retry-After'] = "Mon, 20 Nov 1995 19:12:08 -0500" @@ -183,13 +197,15 @@ def test_delay_extraction_httpdate(polling_response): assert polling._extract_delay() == 60*60 # one hour in seconds assert str(mock_datetime.now.call_args[0][0]) == "" - -def test_post(pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES)) +def test_post(pipeline_client_builder, deserialization_cb, http_request, http_response): # Test POST LRO with both Location and Operation-Location # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, { @@ -204,12 +220,16 @@ def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -235,12 +255,16 @@ def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body=None ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -258,13 +282,15 @@ def send(request, **kwargs): result = poll.result() assert result is None - -def test_post_resource_location(pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES)) +def test_post_resource_location(pipeline_client_builder, deserialization_cb, http_request, http_response): # ResourceLocation # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, { @@ -278,12 +304,16 @@ def send(request, **kwargs): if request.url == 'http://example.org/resource_location': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, + http_response, 'GET', 200, body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} @@ -307,7 +337,7 @@ class TestBasePolling(object): convert = re.compile('([a-z0-9])([A-Z])') @staticmethod - def mock_send(method, status, headers=None, body=RESPONSE_BODY): + def mock_send(http_request, http_response, method, status, headers=None, body=RESPONSE_BODY): if headers is None: headers = {} response = Response() @@ -324,19 +354,28 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" - request = CLIENT._request( - response.request.method, - response.request.url, - None, # params - response.request.headers, - body, - None, # form_content - None # stream_content - ) + if is_rest_http_request(http_request): + request = http_request( + response.request.method, + response.request.url, + headers=response.request.headers, + content=body, + ) + else: + request = CLIENT._request( + response.request.method, + response.request.url, + None, # params + response.request.headers, + body, + None, # form_content + None # stream_content + ) return PipelineResponse( request, - RequestsTransportResponse( + create_http_response( + http_response, request, response, ), @@ -344,7 +383,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): ) @staticmethod - def mock_update(url, headers=None): + def mock_update(http_request, http_response, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) @@ -376,19 +415,27 @@ def mock_update(url, headers=None): else: raise Exception('URL does not match') - request = CLIENT._request( - response.request.method, - response.request.url, - None, # params - {}, # request has no headers - None, # Request has no body - None, # form_content - None # stream_content - ) + if is_rest_http_request(http_request): + request = http_request( + response.request.method, + response.request.url, + headers={}, + ) + else: + request = CLIENT._request( + response.request.method, + response.request.url, + None, # params + {}, # request has no headers + None, # Request has no body + None, # form_content + None # stream_content + ) return PipelineResponse( request, - RequestsTransportResponse( + create_http_response( + http_response, request, response, ), @@ -425,11 +472,14 @@ def mock_deserialization_no_body(pipeline_response): """ return None - def test_long_running_put(self): + @pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_put(self, http_request, http_response): #TODO: Test custom header field # Test throw on non LRO related status code - response = TestBasePolling.mock_send('PUT', 1000, {}) + response = TestBasePolling.mock_send(http_request, http_response,'PUT', 1000, {}) + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response with pytest.raises(HttpResponseError): LROPoller(CLIENT, response, TestBasePolling.mock_outputs, @@ -441,6 +491,8 @@ def test_long_running_put(self): 'name': TEST_NAME } response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {}, response_body ) @@ -455,6 +507,8 @@ def no_update_allowed(url, headers=None): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'operation-location': ASYNC_URL}) poll = LROPoller(CLIENT, response, @@ -465,6 +519,8 @@ def no_update_allowed(url, headers=None): # Test polling location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, @@ -476,6 +532,8 @@ def no_update_allowed(url, headers=None): # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}, response_body) poll = LROPoller(CLIENT, response, @@ -486,6 +544,8 @@ def no_update_allowed(url, headers=None): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -495,6 +555,8 @@ def no_update_allowed(url, headers=None): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PUT', 201, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -502,10 +564,14 @@ def no_update_allowed(url, headers=None): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - def test_long_running_patch(self): - + @pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_patch(self, http_request, http_response): + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -517,6 +583,8 @@ def test_long_running_patch(self): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -528,6 +596,8 @@ def test_long_running_patch(self): # Test polling from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 200, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -539,6 +609,8 @@ def test_long_running_patch(self): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 200, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -550,6 +622,8 @@ def test_long_running_patch(self): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -559,6 +633,8 @@ def test_long_running_patch(self): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'PATCH', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -566,9 +642,14 @@ def test_long_running_patch(self): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - def test_long_running_delete(self): + @pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_delete(self, http_request, http_response): + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'DELETE', 202, {'operation-location': ASYNC_URL}, body="" @@ -579,11 +660,16 @@ def test_long_running_delete(self): poll.wait() assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None - def test_long_running_post_legacy(self): + @pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_post_legacy(self, http_request, http_response): + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Former oooooold tests to refactor one day to something more readble # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 201, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -595,6 +681,8 @@ def test_long_running_post_legacy(self): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -606,6 +694,8 @@ def test_long_running_post_legacy(self): # Test polling from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -617,6 +707,8 @@ def test_long_running_post_legacy(self): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -626,6 +718,8 @@ def test_long_running_post_legacy(self): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -633,13 +727,18 @@ def test_long_running_post_legacy(self): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - def test_long_running_negative(self): + @pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_negative(self, http_request, http_response): + CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response global LOCATION_BODY global POLLING_STATUS # Test LRO PUT throws for invalid json LOCATION_BODY = '{' response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller( @@ -653,6 +752,8 @@ def test_long_running_negative(self): LOCATION_BODY = '{\'"}' response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, @@ -664,6 +765,8 @@ def test_long_running_negative(self): LOCATION_BODY = '{' POLLING_STATUS = 203 response = TestBasePolling.mock_send( + http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, diff --git a/sdk/core/azure-core/tests/test_basic_transport.py b/sdk/core/azure-core/tests/test_basic_transport.py index 7c57f53dfeee..aa707a5814b9 100644 --- a/sdk/core/azure-core/tests/test_basic_transport.py +++ b/sdk/core/azure-core/tests/test_basic_transport.py @@ -12,28 +12,58 @@ except ImportError: import mock -from azure.core.pipeline.transport import HttpRequest, HttpResponse, RequestsTransport -from azure.core.pipeline.transport._base import HttpClientTransportResponse, HttpTransport, _deserialize_response, _urljoin +from azure.core.pipeline.transport import ( + HttpResponse as PipelineTransportHttpResponse, RequestsTransport +) +from azure.core.pipeline.transport._base import ( + HttpClientTransportResponse as PipelineTransportHttpClientTransportResponse, + RestHttpClientTransportResponse, + HttpTransport, + _deserialize_response, + _urljoin, + RestHttpResponseImpl, +) from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import Pipeline from azure.core.exceptions import HttpResponseError import logging import pytest +from utils import HTTP_REQUESTS, pipeline_transport_and_rest_product, create_http_response, is_rest_http_response - -class MockResponse(HttpResponse): +class PipelineTransportMockResponse(PipelineTransportHttpResponse): def __init__(self, request, body, content_type): - super(MockResponse, self).__init__(request, None) + super(PipelineTransportMockResponse, self).__init__(request, None) self._body = body self.content_type = content_type def body(self): return self._body +class RestMockResponse(RestHttpResponseImpl): + def __init__(self, request, body, content_type): + super(RestMockResponse, self).__init__( + request=request, + internal_response=None, + content_type=content_type, + block_size=None, + status_code=200, + reason="OK", + headers={}, + stream_download_generator=None, + ) + self._content = body + + @property + def content(self): + return self._content + +MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] + @pytest.mark.skipif(sys.version_info < (3, 6), reason="Multipart serialization not supported on 2.7 + dict order not deterministic on 3.5") -def test_http_request_serialization(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_request_serialization(http_request): # Method + Url - request = HttpRequest("DELETE", "/container0/blob0") + request = http_request("DELETE", "/container0/blob0") serialized = request.serialize() expected = ( @@ -44,7 +74,7 @@ def test_http_request_serialization(): assert serialized == expected # Method + Url + Headers - request = HttpRequest( + request = http_request( "DELETE", "/container0/blob0", # Use OrderedDict to get consistent test result on 3.5 where order is not guaranteed @@ -67,7 +97,7 @@ def test_http_request_serialization(): # Method + Url + Headers + Body - request = HttpRequest( + request = http_request( "DELETE", "/container0/blob0", headers={ @@ -87,23 +117,28 @@ def test_http_request_serialization(): assert serialized == expected -def test_url_join(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_url_join(http_request): assert _urljoin('devstoreaccount1', '') == 'devstoreaccount1/' assert _urljoin('devstoreaccount1', 'testdir/') == 'devstoreaccount1/testdir/' assert _urljoin('devstoreaccount1/', '') == 'devstoreaccount1/' assert _urljoin('devstoreaccount1/', 'testdir/') == 'devstoreaccount1/testdir/' -def test_http_client_response(port): +CLIENT_TRANSPORT_RESPONSES = [PipelineTransportHttpClientTransportResponse, RestHttpClientTransportResponse] +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, CLIENT_TRANSPORT_RESPONSES)) +def test_http_client_response(port, http_request, http_response): # Create a core request - request = HttpRequest("GET", "http://localhost:{}".format(port)) + request = http_request("GET", "http://localhost:{}".format(port)) # Fake a transport based on http.client conn = HTTPConnection("localhost", port) conn.request("GET", "/get") r1 = conn.getresponse() - response = HttpClientTransportResponse(request, r1) + response = create_http_response(http_response, request, r1) + if is_rest_http_response(response): + response.read() # Don't assume too much in those assert, since we reach a real server assert response.internal_response is r1 @@ -115,10 +150,11 @@ def test_http_client_response(port): assert "Content-Type" in response.headers -def test_response_deserialization(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_deserialization(http_request): # Method + Url - request = HttpRequest("DELETE", "/container0/blob0") + request = http_request("DELETE", "/container0/blob0") body = ( b'HTTP/1.1 202 Accepted\r\n' b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' @@ -135,7 +171,7 @@ def test_response_deserialization(): } # Method + Url + Headers + Body - request = HttpRequest( + request = http_request( "DELETE", "/container0/blob0", headers={ @@ -161,9 +197,10 @@ def test_response_deserialization(): } assert response.text() == "I am groot" -def test_response_deserialization_utf8_bom(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_deserialization_utf8_bom(http_request): - request = HttpRequest("DELETE", "/container0/blob0") + request = http_request("DELETE", "/container0/blob0") body = ( b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' b'x-ms-error-code: InvalidInput\r\n' @@ -181,7 +218,8 @@ def test_response_deserialization_utf8_bom(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -189,10 +227,10 @@ def test_multipart_send(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -227,17 +265,18 @@ def test_multipart_send(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_context(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_context(http_request): transport = mock.MagicMock(spec=HttpTransport) header_policy = HeadersPolicy({ 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -275,7 +314,8 @@ def test_multipart_send_with_context(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_one_changeset(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_one_changeset(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -284,18 +324,18 @@ def test_multipart_send_with_one_changeset(): }) requests = [ - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ] - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( *requests, policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", @@ -333,7 +373,8 @@ def test_multipart_send_with_one_changeset(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_multiple_changesets(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_multiple_changesets(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -341,22 +382,22 @@ def test_multipart_send_with_multiple_changesets(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset1 = HttpRequest("", "") + changeset1 = http_request("", "") changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - changeset2 = HttpRequest("", "") + changeset2 = http_request("", "") changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3"), + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3"), policies=[header_policy], boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset1, changeset2, @@ -419,7 +460,8 @@ def test_multipart_send_with_multiple_changesets(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_combination_changeset_first(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -427,17 +469,17 @@ def test_multipart_send_with_combination_changeset_first(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), policies=[header_policy], boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -482,7 +524,8 @@ def test_multipart_send_with_combination_changeset_first(): ) @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_combination_changeset_last(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -490,16 +533,16 @@ def test_multipart_send_with_combination_changeset_last(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, policies=[header_policy], boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" @@ -545,7 +588,8 @@ def test_multipart_send_with_combination_changeset_last(): ) @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_combination_changeset_middle(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -553,17 +597,17 @@ def test_multipart_send_with_combination_changeset_middle(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container1/blob1"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), policies=[header_policy], boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -608,17 +652,18 @@ def test_multipart_send_with_combination_changeset_middle(): ) -def test_multipart_receive(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_multipart_receive(http_request, mock_response): class ResponsePolicy(object): def on_response(self, request, response): # type: (PipelineRequest, PipelineResponse) -> None response.http_response.headers['x-ms-fun'] = 'true' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -652,7 +697,7 @@ def on_response(self, request, response): "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -670,27 +715,30 @@ def on_response(self, request, response): assert res1.status_code == 404 assert res1.headers['x-ms-fun'] == 'true' -def test_raise_for_status_bad_response(): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_bad_response(mock_response): + response = mock_response(request=None, body=None, content_type=None) response.status_code = 400 with pytest.raises(HttpResponseError): response.raise_for_status() -def test_raise_for_status_good_response(): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_good_response(mock_response): + response = mock_response(request=None, body=None, content_type=None) response.status_code = 200 response.raise_for_status() -def test_multipart_receive_with_one_changeset(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_multipart_receive_with_one_changeset(http_request, mock_response): - changeset = HttpRequest(None, None) + changeset = http_request(None, None) changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset) body_as_bytes = ( @@ -722,7 +770,7 @@ def test_multipart_receive_with_one_changeset(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -737,20 +785,21 @@ def test_multipart_receive_with_one_changeset(): assert res0.status_code == 202 -def test_multipart_receive_with_multiple_changesets(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_multipart_receive_with_multiple_changesets(http_request, mock_response): - changeset1 = HttpRequest(None, None) + changeset1 = http_request(None, None) changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - changeset2 = HttpRequest(None, None) + changeset2 = http_request(None, None) changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3") + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset1, changeset2) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -806,7 +855,7 @@ def test_multipart_receive_with_multiple_changesets(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -822,16 +871,17 @@ def test_multipart_receive_with_multiple_changesets(): assert parts[3].status_code == 409 -def test_multipart_receive_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_multipart_receive_with_combination_changeset_first(http_request, mock_response): - changeset = HttpRequest(None, None) + changeset = http_request(None, None) changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(changeset, HttpRequest("DELETE", "/container2/blob2")) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(changeset, http_request("DELETE", "/container2/blob2")) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' @@ -871,7 +921,7 @@ def test_multipart_receive_with_combination_changeset_first(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -886,16 +936,17 @@ def test_multipart_receive_with_combination_changeset_first(): assert parts[2].status_code == 404 -def test_multipart_receive_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_multipart_receive_with_combination_changeset_middle(http_request, mock_response): - changeset = HttpRequest(None, None) - changeset.set_multipart_mixed(HttpRequest("DELETE", "/container1/blob1")) + changeset = http_request(None, None) + changeset.set_multipart_mixed(http_request("DELETE", "/container1/blob1")) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container2/blob2") ) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -936,7 +987,7 @@ def test_multipart_receive_with_combination_changeset_middle(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -951,16 +1002,17 @@ def test_multipart_receive_with_combination_changeset_middle(): assert parts[2].status_code == 404 -def test_multipart_receive_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_multipart_receive_with_combination_changeset_last(http_request, mock_response): - changeset = HttpRequest(None, None) + changeset = http_request(None, None) changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(HttpRequest("DELETE", "/container0/blob0"), changeset) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(http_request("DELETE", "/container0/blob0"), changeset) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -1001,7 +1053,7 @@ def test_multipart_receive_with_combination_changeset_last(): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -1016,11 +1068,12 @@ def test_multipart_receive_with_combination_changeset_last(): assert parts[2].status_code == 404 -def test_multipart_receive_with_bom(): +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_multipart_receive_with_bom(http_request, mock_response): - req0 = HttpRequest("DELETE", "/container0/blob0") + req0 = http_request("DELETE", "/container0/blob0") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) body_as_bytes = ( b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\n" @@ -1038,7 +1091,7 @@ def test_multipart_receive_with_bom(): b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -1052,12 +1105,13 @@ def test_multipart_receive_with_bom(): assert res0.body().startswith(b'\xef\xbb\xbf') -def test_recursive_multipart_receive(): - req0 = HttpRequest("DELETE", "/container0/blob0") - internal_req0 = HttpRequest("DELETE", "/container0/blob0") +@pytest.mark.parametrize("http_request,mock_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, MOCK_RESPONSES)) +def test_recursive_multipart_receive(http_request, mock_response): + req0 = http_request("DELETE", "/container0/blob0") + internal_req0 = http_request("DELETE", "/container0/blob0") req0.set_multipart_mixed(internal_req0) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) internal_body_as_str = ( "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" @@ -1083,7 +1137,7 @@ def test_recursive_multipart_receive(): "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6--" ).format(internal_body_as_str) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" @@ -1102,15 +1156,16 @@ def test_recursive_multipart_receive(): assert internal_response0.status_code == 400 -def test_close_unopened_transport(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_close_unopened_transport(http_request): transport = RequestsTransport() transport.close() - -def test_timeout(caplog, port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_timeout(caplog, port, http_request): transport = RequestsTransport() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) with caplog.at_level(logging.WARNING, logger="azure.core.pipeline.transport"): with Pipeline(transport) as pipeline: @@ -1118,11 +1173,11 @@ def test_timeout(caplog, port): assert "Tuple timeout setting is deprecated" not in caplog.text - -def test_tuple_timeout(caplog, port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_tuple_timeout(caplog, port, http_request): transport = RequestsTransport() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) with caplog.at_level(logging.WARNING, logger="azure.core.pipeline.transport"): with Pipeline(transport) as pipeline: @@ -1130,11 +1185,11 @@ def test_tuple_timeout(caplog, port): assert "Tuple timeout setting is deprecated" in caplog.text - -def test_conflict_timeout(caplog, port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_conflict_timeout(caplog, port, http_request): transport = RequestsTransport() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) with pytest.raises(ValueError): with Pipeline(transport) as pipeline: @@ -1147,4 +1202,4 @@ def test_aiohttp_loop(): from azure.core.pipeline.transport import AioHttpTransport loop = asyncio._get_running_loop() with pytest.raises(ValueError): - transport = AioHttpTransport(loop=loop) + transport = AioHttpTransport(loop=loop) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/test_custom_hook_policy.py b/sdk/core/azure-core/tests/test_custom_hook_policy.py index e553ca5a9811..2bc4e5eb7960 100644 --- a/sdk/core/azure-core/tests/test_custom_hook_policy.py +++ b/sdk/core/azure-core/tests/test_custom_hook_policy.py @@ -11,8 +11,10 @@ from azure.core.pipeline.policies import CustomHookPolicy, UserAgentPolicy from azure.core.pipeline.transport import HttpTransport import pytest +from utils import HTTP_REQUESTS -def test_response_hook_policy_in_init(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_hook_policy_in_init(http_request): def test_callback(response): raise ValueError() @@ -23,12 +25,13 @@ def test_callback(response): UserAgentPolicy("myuseragent"), custom_hook_policy ] - client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + client = PipelineClient(base_url=url, policies=policies, transport=transport) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request) -def test_response_hook_policy_in_request(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_hook_policy_in_request(http_request): def test_callback(response): raise ValueError() @@ -40,11 +43,12 @@ def test_callback(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request, raw_response_hook=test_callback) -def test_response_hook_policy_in_both(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_hook_policy_in_both(http_request): def test_callback(response): raise ValueError() @@ -59,11 +63,12 @@ def test_callback_request(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(TypeError): client._pipeline.run(request, raw_response_hook=test_callback_request) -def test_request_hook_policy_in_init(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_hook_policy_in_init(http_request): def test_callback(response): raise ValueError() @@ -75,11 +80,12 @@ def test_callback(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request) -def test_request_hook_policy_in_request(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_hook_policy_in_request(http_request): def test_callback(response): raise ValueError() @@ -91,11 +97,12 @@ def test_callback(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request, raw_request_hook=test_callback) -def test_request_hook_policy_in_both(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_hook_policy_in_both(http_request): def test_callback(response): raise ValueError() @@ -110,6 +117,6 @@ def test_callback_request(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(TypeError): client._pipeline.run(request, raw_request_hook=test_callback_request) diff --git a/sdk/core/azure-core/tests/test_error_map.py b/sdk/core/azure-core/tests/test_error_map.py index 6fa355b490ab..93c85e541899 100644 --- a/sdk/core/azure-core/tests/test_error_map.py +++ b/sdk/core/azure-core/tests/test_error_map.py @@ -30,41 +30,42 @@ map_error, ErrorMap, ) -from azure.core.pipeline.transport import ( - HttpRequest, - HttpResponse, -) +from utils import create_http_response, pipeline_transport_and_rest_product, HTTP_RESPONSES, HTTP_REQUESTS -def test_error_map(): - request = HttpRequest("GET", "") - response = HttpResponse(request, None) +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_error_map(http_request, http_response): + request = http_request("GET", "") + response = create_http_response(http_response, request, None) error_map = { 404: ResourceNotFoundError } with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) -def test_error_map_no_default(): - request = HttpRequest("GET", "") - response = HttpResponse(request, None) +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_error_map_no_default(http_request, http_response): + request = http_request("GET", "") + response = create_http_response(http_response, request, None) error_map = ErrorMap({ 404: ResourceNotFoundError }) with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) -def test_error_map_with_default(): - request = HttpRequest("GET", "") - response = HttpResponse(request, None) +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_error_map_with_default(http_request, http_response): + request = http_request("GET", "") + response = create_http_response(http_response, request, None) error_map = ErrorMap({ 404: ResourceNotFoundError }, default_error=ResourceExistsError) with pytest.raises(ResourceExistsError): map_error(401, response, error_map) -def test_only_default(): - request = HttpRequest("GET", "") - response = HttpResponse(request, None) +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_only_default(http_request, http_response): + request = http_request("GET", "") + response = create_http_response(http_response, request, None) error_map = ErrorMap(default_error=ResourceExistsError) with pytest.raises(ResourceExistsError): map_error(401, response, error_map) diff --git a/sdk/core/azure-core/tests/test_exceptions.py b/sdk/core/azure-core/tests/test_exceptions.py index a5d494e559b7..391fb56cf646 100644 --- a/sdk/core/azure-core/tests/test_exceptions.py +++ b/sdk/core/azure-core/tests/test_exceptions.py @@ -24,6 +24,7 @@ # # -------------------------------------------------------------------------- import json +import pytest import requests try: from unittest.mock import Mock @@ -33,27 +34,46 @@ # module under test from azure.core.exceptions import HttpResponseError, ODataV4Error, ODataV4Format -from azure.core.pipeline.transport import RequestsTransportResponse -from azure.core.pipeline.transport._base import _HttpResponseBase - - -def _build_response(json_body): - class MockResponse(_HttpResponseBase): - def __init__(self): - super(MockResponse, self).__init__( - request=None, - internal_response = None, - ) - self.status_code = 400 - self.reason = "Bad Request" - self.content_type = "application/json" - self._body = json_body - - def body(self): - return self._body - - return MockResponse() - +from azure.core.pipeline.transport._base import ( + _HttpResponseBase as PipelineTransportHttpResponseBase, + _RestHttpResponseBaseImpl +) +from utils import is_rest_http_response, create_http_response, REQUESTS_TRANSPORT_RESPONSES + + + +class PipelineTransportMockResponse(PipelineTransportHttpResponseBase): + def __init__(self, json_body): + super(PipelineTransportMockResponse, self).__init__( + request=None, + internal_response = None, + ) + self.status_code = 400 + self.reason = "Bad Request" + self.content_type = "application/json" + self._body = json_body + + def body(self): + return self._body + +class RestMockResponse(_RestHttpResponseBaseImpl): + def __init__(self, json_body): + super().__init__( + request=None, + internal_response=None, + status_code=400, + reason="Bad Request", + content_type="application/json", + headers={}, + stream_download_generator=None, + ) + self._body = json_body + + @property + def content(self): + return self._body + +MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] class FakeErrorOne(object): @@ -105,7 +125,8 @@ def test_error_continuation_token(self): assert error.status_code is None assert error.continuation_token == 'foo' - def test_deserialized_httpresponse_error_code(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_deserialized_httpresponse_error_code(self, mock_response): """This is backward compat support of autorest azure-core (KV 4.0.0, Storage 12.0.0). Do NOT adapt this test unless you know what you're doing. @@ -116,7 +137,7 @@ def test_deserialized_httpresponse_error_code(self): "message": "A fake error", } } - response = _build_response(json.dumps(message).encode("utf-8")) + response = mock_response(json.dumps(message).encode("utf-8")) error = FakeHttpResponse(response, FakeErrorOne()) assert "(FakeErrorOne) A fake error" in error.message assert "(FakeErrorOne) A fake error" in str(error.error) @@ -133,7 +154,8 @@ def test_deserialized_httpresponse_error_code(self): assert error.error.error.message == "A fake error" - def test_deserialized_httpresponse_error_message(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_deserialized_httpresponse_error_message(self, mock_response): """This is backward compat support for weird responses, adn even if it's likely just the autorest testserver, should be fine parsing. @@ -143,7 +165,7 @@ def test_deserialized_httpresponse_error_message(self): "code": "FakeErrorTwo", "message": "A different fake error", } - response = _build_response(json.dumps(message).encode("utf-8")) + response = mock_response(json.dumps(message).encode("utf-8")) error = FakeHttpResponse(response, FakeErrorTwo()) assert "(FakeErrorTwo) A different fake error" in error.message assert "(FakeErrorTwo) A different fake error" in str(error.error) @@ -155,9 +177,10 @@ def test_deserialized_httpresponse_error_message(self): assert isinstance(error.model, FakeErrorTwo) assert isinstance(error.error, ODataV4Format) - def test_httpresponse_error_with_response(self, port): + @pytest.mark.parametrize("requests_transport_response", REQUESTS_TRANSPORT_RESPONSES) + def test_httpresponse_error_with_response(self, port, requests_transport_response): response = requests.get("http://localhost:{}/basic/string".format(port)) - http_response = RequestsTransportResponse(None, response) + http_response = create_http_response(requests_transport_response, None, response) error = HttpResponseError(response=http_response) assert error.message == "Operation returned an invalid status 'OK'" @@ -166,7 +189,8 @@ def test_httpresponse_error_with_response(self, port): assert isinstance(error.status_code, int) assert error.error is None - def test_odata_v4_exception(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_odata_v4_exception(self, mock_response): message = { "error": { "code": "501", @@ -183,7 +207,7 @@ def test_odata_v4_exception(self): } } } - exp = ODataV4Error(_build_response(json.dumps(message).encode("utf-8"))) + exp = ODataV4Error(mock_response(json.dumps(message).encode("utf-8"))) assert exp.code == "501" assert exp.message == "Unsupported functionality" @@ -194,14 +218,15 @@ def test_odata_v4_exception(self): assert "context" in exp.innererror message = {} - exp = ODataV4Error(_build_response(json.dumps(message).encode("utf-8"))) + exp = ODataV4Error(mock_response(json.dumps(message).encode("utf-8"))) assert exp.message == "Operation returned an invalid status 'Bad Request'" - exp = ODataV4Error(_build_response(b"")) + exp = ODataV4Error(mock_response(b"")) assert exp.message == "Operation returned an invalid status 'Bad Request'" assert str(exp) == "Operation returned an invalid status 'Bad Request'" - def test_odata_v4_minimal(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_odata_v4_minimal(self, mock_response): """Minimal valid OData v4 is code/message and nothing else. """ message = { @@ -210,14 +235,15 @@ def test_odata_v4_minimal(self): "message": "Unsupported functionality", } } - exp = ODataV4Error(_build_response(json.dumps(message).encode("utf-8"))) + exp = ODataV4Error(mock_response(json.dumps(message).encode("utf-8"))) assert exp.code == "501" assert exp.message == "Unsupported functionality" assert exp.target is None assert exp.details == [] assert exp.innererror == {} - def test_broken_odata_details(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_broken_odata_details(self, mock_response): """Do not block creating a nice exception if "details" only is broken """ message = { @@ -244,10 +270,11 @@ def test_broken_odata_details(self): "innererror": None, } } - exp = HttpResponseError(response=_build_response(json.dumps(message).encode("utf-8"))) + exp = HttpResponseError(response=mock_response(json.dumps(message).encode("utf-8"))) assert exp.error.code == "Conflict" - def test_null_odata_details(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_null_odata_details(self, mock_response): message = { "error": { "code": "501", @@ -257,5 +284,5 @@ def test_null_odata_details(self): "innererror": None, } } - exp = HttpResponseError(response=_build_response(json.dumps(message).encode("utf-8"))) + exp = HttpResponseError(response=mock_response(json.dumps(message).encode("utf-8"))) assert exp.error.code == "501" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/test_http_logging_policy.py b/sdk/core/azure-core/tests/test_http_logging_policy.py index c69b8b5f0e7c..9423a4a31805 100644 --- a/sdk/core/azure-core/tests/test_http_logging_policy.py +++ b/sdk/core/azure-core/tests/test_http_logging_policy.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ """Tests for the HttpLoggingPolicy.""" - +import pytest import logging import types try: @@ -15,16 +15,14 @@ PipelineRequest, PipelineContext ) -from azure.core.pipeline.transport import ( - HttpRequest, - HttpResponse, -) from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) +from utils import create_http_response, HTTP_RESPONSES, pipeline_transport_and_rest_product, HTTP_REQUESTS -def test_http_logger(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -42,8 +40,8 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + universal_request = http_request('GET', 'http://localhost/') + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -138,7 +136,8 @@ def emit(self, record): -def test_http_logger_operation_level(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger_operation_level(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -157,8 +156,8 @@ def emit(self, record): policy = HttpLoggingPolicy() kwargs={'logger': logger} - universal_request = HttpRequest('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + universal_request = http_request('GET', 'http://localhost/') + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -209,7 +208,8 @@ def emit(self, record): mock_handler.reset() -def test_http_logger_with_body(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger_with_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -227,9 +227,9 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') universal_request.body = "testbody" - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -249,7 +249,8 @@ def emit(self, record): mock_handler.reset() -def test_http_logger_with_generator_body(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_http_logger_with_generator_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -267,11 +268,11 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') mock = Mock() mock.__class__ = types.GeneratorType universal_request.body = mock - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) diff --git a/sdk/core/azure-core/tests/test_pipeline.py b/sdk/core/azure-core/tests/test_pipeline.py index 6d260c07e9fe..8d378eb91bd5 100644 --- a/sdk/core/azure-core/tests/test_pipeline.py +++ b/sdk/core/azure-core/tests/test_pipeline.py @@ -50,12 +50,11 @@ ) from azure.core.pipeline.transport._base import PipelineClientBase from azure.core.pipeline.transport import ( - HttpRequest, HttpTransport, RequestsTransport, ) - from azure.core.exceptions import AzureError +from utils import is_rest_http_request, HTTP_REQUESTS def test_default_http_logging_policy(): config = Configuration() @@ -77,8 +76,8 @@ def test_pass_in_http_logging_policy(): http_logging_policy = pipeline._impl_policies[-1]._policy assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) - -def test_sans_io_exception(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_sans_io_exception(http_request): class BrokenSender(HttpTransport): def send(self, request, **config): raise ValueError("Broken") @@ -95,7 +94,7 @@ def __exit__(self, exc_type, exc_value, traceback): pipeline = Pipeline(BrokenSender(), [SansIOHTTPPolicy()]) - req = HttpRequest("GET", "/") + req = http_request("GET", "/") with pytest.raises(ValueError): pipeline.run(req) @@ -108,9 +107,10 @@ def on_exception(self, requests, **kwargs): with pytest.raises(NotImplementedError): pipeline.run(req) -def test_requests_socket_timeout(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_requests_socket_timeout(http_request): conf = Configuration() - request = HttpRequest("GET", "https://bing.com") + request = http_request("GET", "https://bing.com") policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -179,26 +179,29 @@ def test_format_incorrect_endpoint(): client.format_url("foo/bar") assert str(exp.value) == "The value provided for the url part Endpoint was incorrect, and resulted in an invalid url" -def test_request_json(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_json(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") data = "Lots of dataaaa" request.set_json_body(data) assert request.data == json.dumps(data) assert request.headers.get("Content-Length") == "17" -def test_request_data(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_data(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") data = "Lots of dataaaa" request.set_bytes_body(data) assert request.data == data assert request.headers.get("Content-Length") == "15" -def test_request_stream(): - request = HttpRequest("GET", "/") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_stream(http_request): + request = http_request("GET", "/") data = b"Lots of dataaaa" request.set_streamed_data_body(data) @@ -216,45 +219,51 @@ def data_gen(): assert request.data == data -def test_request_xml(): - request = HttpRequest("GET", "/") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_xml(http_request): + request = http_request("GET", "/") data = ET.Element("root") request.set_xml_body(data) assert request.data == b"\n" -def test_request_url_with_params(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" request.format_parameters({"g": "h"}) assert request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params_as_list(http_request): -def test_request_url_with_params_as_list(): - - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" request.format_parameters({"g": ["h","i"]}) assert request.url in ["a/b/c?g=h&g=i&t=y", "a/b/c?t=y&g=h&g=i"] -def test_request_url_with_params_with_none_in_list(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params_with_none_in_list(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" with pytest.raises(ValueError): request.format_parameters({"g": ["h",None]}) -def test_request_url_with_params_with_none(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params_with_none(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" with pytest.raises(ValueError): request.format_parameters({"g": None}) -def test_repr(): - request = HttpRequest("GET", "hello.com") + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_repr(http_request): + request = http_request("GET", "hello.com") assert repr(request) == "" def test_add_custom_policy(): @@ -355,10 +364,11 @@ def send(*args): with pytest.raises(ValueError): client = PipelineClient(base_url="test", policies=policies, per_retry_policies=[foo_policy]) -def test_basic_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_basic_requests(http_request): conf = Configuration() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "https://bing.com") policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -369,9 +379,10 @@ def test_basic_requests(port): assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) -def test_basic_options_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_basic_options_requests(port, http_request): - request = HttpRequest("OPTIONS", "http://localhost:{}/basic/string".format(port)) + request = http_request("OPTIONS", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -382,10 +393,11 @@ def test_basic_options_requests(port): assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) -def test_basic_requests_separate_session(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_basic_requests_separate_session(http_request): session = requests.Session() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "https://bing.com") policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -400,21 +412,28 @@ def test_basic_requests_separate_session(port): assert transport.session transport.session.close() -def test_request_text(port): - client = PipelineClientBase("http://localhost:{}".format(port)) - request = client.get( - "/", - content="foo" - ) +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_text(http_request): + client = PipelineClientBase('http://example.org') + if is_rest_http_request(http_request): + request = http_request("GET", "/", json="foo") + else: + request = client.get( + "/", + content="foo" + ) # In absence of information, everything is JSON (double quote added) assert request.data == json.dumps("foo") - request = client.post( - "/", - headers={'content-type': 'text/whatever'}, - content="foo" - ) + if is_rest_http_request(http_request): + request = http_request("POST", "/", headers={'content-type': 'text/whatever'}, content="foo") + else: + request = client.post( + "/", + headers={'content-type': 'text/whatever'}, + content="foo" + ) # We want a direct string assert request.data == "foo" diff --git a/sdk/core/azure-core/tests/test_request_id_policy.py b/sdk/core/azure-core/tests/test_request_id_policy.py index 7da467b467c2..a9e08238f3e5 100644 --- a/sdk/core/azure-core/tests/test_request_id_policy.py +++ b/sdk/core/azure-core/tests/test_request_id_policy.py @@ -4,8 +4,10 @@ # ------------------------------------ """Tests for the request id policy.""" from azure.core.pipeline.policies import RequestIdPolicy -from azure.core.pipeline.transport import HttpRequest +from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest +from azure.core.rest import HttpRequest as RestHttpRequest from azure.core.pipeline import PipelineRequest, PipelineContext +from utils import HTTP_REQUESTS try: from unittest import mock except ImportError: @@ -17,10 +19,10 @@ request_id_init_values = ("foo", None, "_unset") request_id_set_values = ("bar", None, "_unset") request_id_req_values = ("baz", None, "_unset") -full_combination = list(product(auto_request_id_values, request_id_init_values, request_id_set_values, request_id_req_values)) +full_combination = list(product(auto_request_id_values, request_id_init_values, request_id_set_values, request_id_req_values, HTTP_REQUESTS)) -@pytest.mark.parametrize("auto_request_id, request_id_init, request_id_set, request_id_req", full_combination) -def test_request_id_policy(auto_request_id, request_id_init, request_id_set, request_id_req): +@pytest.mark.parametrize("auto_request_id, request_id_init, request_id_set, request_id_req, http_request", full_combination) +def test_request_id_policy(auto_request_id, request_id_init, request_id_set, request_id_req, http_request): """Test policy with no other policy and happy path""" kwargs = {} if auto_request_id is not None: @@ -30,7 +32,7 @@ def test_request_id_policy(auto_request_id, request_id_init, request_id_set, req request_id_policy = RequestIdPolicy(**kwargs) if request_id_set != "_unset": request_id_policy.set_request_id(request_id_set) - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') pipeline_request = PipelineRequest(request, PipelineContext(None)) if request_id_req != "_unset": pipeline_request.context.options['request_id'] = request_id_req @@ -55,10 +57,11 @@ def test_request_id_policy(auto_request_id, request_id_init, request_id_set, req else: assert not "x-ms-client-request-id" in request.headers -def test_request_id_already_exists(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_id_already_exists(http_request): """Test policy with no other policy and happy path""" request_id_policy = RequestIdPolicy() - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') request.headers["x-ms-client-request-id"] = "VALUE" pipeline_request = PipelineRequest(request, PipelineContext(None)) request_id_policy.on_request(pipeline_request) diff --git a/sdk/core/azure-core/tests/test_requests_universal.py b/sdk/core/azure-core/tests/test_requests_universal.py index b965d4fd3d69..a24d6a582729 100644 --- a/sdk/core/azure-core/tests/test_requests_universal.py +++ b/sdk/core/azure-core/tests/test_requests_universal.py @@ -25,8 +25,10 @@ # -------------------------------------------------------------------------- import concurrent.futures import requests.utils +import pytest -from azure.core.pipeline.transport import HttpRequest, RequestsTransport, RequestsTransportResponse +from azure.core.pipeline.transport import RequestsTransport +from utils import create_http_response, is_rest_http_response, HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES def test_threading_basic_requests(): @@ -44,14 +46,15 @@ def thread_body(local_sender): future = executor.submit(thread_body, sender) assert future.result() -def test_requests_auto_headers(port): - request = HttpRequest("POST", "http://localhost:{}/basic/string".format(port)) +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_requests_auto_headers(port, http_request): + request = http_request("POST", "http://localhost:{}/basic/string".format(port)) with RequestsTransport() as sender: response = sender.send(request) auto_headers = response.internal_response.request.headers assert 'Content-Type' not in auto_headers -def _create_requests_response(body_bytes, headers=None): +def _create_requests_response(http_response, body_bytes, headers=None): # https://github.com/psf/requests/blob/67a7b2e8336951d527e223429672354989384197/requests/adapters.py#L255 req_response = requests.Response() req_response._content = body_bytes @@ -63,27 +66,34 @@ def _create_requests_response(body_bytes, headers=None): req_response.headers.update(headers) req_response.encoding = requests.utils.get_encoding_from_headers(req_response.headers) - response = RequestsTransportResponse( + response = create_http_response( + http_response, None, # Don't need a request here req_response ) return response - -def test_requests_response_text(): +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_requests_response_text(http_response): for encoding in ["utf-8", "utf-8-sig", None]: res = _create_requests_response( + http_response, b'\xef\xbb\xbf56', {'Content-Type': 'text/plain'} ) assert res.text(encoding) == '56', "Encoding {} didn't work".format(encoding) -def test_repr(): +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_repr(http_response): res = _create_requests_response( + http_response, b'\xef\xbb\xbf56', {'Content-Type': 'text/plain'} ) - assert repr(res) == "" + if is_rest_http_response(http_response): + assert repr(res) == "" + else: + assert repr(res) == "" diff --git a/sdk/core/azure-core/tests/test_rest_http_response.py b/sdk/core/azure-core/tests/test_rest_http_response.py index f3abec23a30a..7d2859ec79c3 100644 --- a/sdk/core/azure-core/tests/test_rest_http_response.py +++ b/sdk/core/azure-core/tests/test_rest_http_response.py @@ -11,9 +11,11 @@ import io import sys import pytest -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse +from azure.core.pipeline.transport._requests_basic import RestRequestsTransportResponse from azure.core.exceptions import HttpResponseError import xml.etree.ElementTree as ET +from utils import readonly_checks @pytest.fixture def send_request(client): @@ -153,16 +155,6 @@ def test_response_no_charset_with_iso_8859_1_content(send_request): assert response.text() == u"Accented: �sterreich" assert response.encoding is None -def test_response_set_explicit_encoding(send_request): - # Deliberately incorrect charset - response = send_request( - request=HttpRequest("GET", "/encoding/latin-1-with-utf-8"), - ) - assert response.headers["Content-Type"] == "text/plain; charset=utf-8" - response.encoding = "latin-1" - assert response.text() == u"Latin 1: ÿ" - assert response.encoding == "latin-1" - def test_json(send_request): response = send_request( request=HttpRequest("GET", "/basic/json"), @@ -311,3 +303,15 @@ def test_text_and_encoding(send_request): # assert latin-1 changes text decoding without changing encoding property assert response.text("latin-1") == u'ð\x9f\x91©' == response.content.decode("latin-1") assert response.encoding == "utf-16" + +def test_initialize_response_abc(): + with pytest.raises(TypeError) as ex: + HttpResponse() + assert "Can't instantiate abstract class" in str(ex) + +def test_readonly(send_request): + """Make sure everything that is readonly is readonly""" + response = send_request(HttpRequest("GET", "/health")) + + assert isinstance(response, RestRequestsTransportResponse) + readonly_checks(response) diff --git a/sdk/core/azure-core/tests/test_retry_policy.py b/sdk/core/azure-core/tests/test_retry_policy.py index ce1a590fba26..3201be78b287 100644 --- a/sdk/core/azure-core/tests/test_retry_policy.py +++ b/sdk/core/azure-core/tests/test_retry_policy.py @@ -21,20 +21,20 @@ RetryMode, ) from azure.core.pipeline import Pipeline, PipelineResponse -from azure.core.pipeline.transport import ( - HttpRequest, - HttpResponse, - HttpTransport, -) +from azure.core.pipeline.transport import HttpTransport import tempfile import os import time +from itertools import product +from utils import create_http_response, pipeline_transport_and_rest_product, HTTP_RESPONSES, HTTP_REQUESTS try: from unittest.mock import Mock except ImportError: from mock import Mock +retry_after_input = ('0', '800', '1000', '1200') +full_combination = list(product(retry_after_input, pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES))) def test_retry_code_class_variables(): retry_policy = RetryPolicy() @@ -62,11 +62,13 @@ def test_retry_types(): backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_retry_after(retry_after_input): + +@pytest.mark.parametrize("retry_after_input,request_and_response", full_combination) +def test_retry_after(retry_after_input, request_and_response): + http_request, http_response = request_and_response retry_policy = RetryPolicy() - request = HttpRequest("GET", "http://localhost") - response = HttpResponse(request, None) + request = http_request("GET", "http://localhost") + response = create_http_response(http_response, request, None) response.headers["retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) @@ -80,11 +82,12 @@ def test_retry_after(retry_after_input): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_x_ms_retry_after(retry_after_input): +@pytest.mark.parametrize("retry_after_input,request_and_response", full_combination) +def test_x_ms_retry_after(retry_after_input, request_and_response): + http_request, http_response = request_and_response retry_policy = RetryPolicy() - request = HttpRequest("GET", "http://localhost") - response = HttpResponse(request, None) + request = http_request("GET", "http://localhost") + response = create_http_response(http_response, request, None) response.headers["x-ms-retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) @@ -98,7 +101,8 @@ def test_x_ms_retry_after(retry_after_input): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -def test_retry_on_429(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_retry_on_429(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -111,18 +115,19 @@ def open(self): def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse self._count += 1 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 429 return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_retry = RetryPolicy(retry_total = 1) transport = MockTransport() pipeline = Pipeline(transport, [http_retry]) pipeline.run(http_request) assert transport._count == 2 -def test_no_retry_on_201(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_no_retry_on_201(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -135,20 +140,21 @@ def open(self): def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse self._count += 1 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 201 headers = {"Retry-After": "1"} response.headers = headers return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_retry = RetryPolicy(retry_total = 1) transport = MockTransport() pipeline = Pipeline(transport, [http_retry]) pipeline.run(http_request) assert transport._count == 1 -def test_retry_seekable_stream(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_retry_seekable_stream(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._first = True @@ -166,18 +172,19 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe raise AzureError('fail on first') position = request.body.tell() assert position == 0 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 400 return response data = BytesIO(b"Lots of dataaaa") - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_request.set_streamed_data_body(data) http_retry = RetryPolicy(retry_total = 1) pipeline = Pipeline(MockTransport(), [http_retry]) pipeline.run(http_request) -def test_retry_seekable_file(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_retry_seekable_file(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._first = True @@ -201,14 +208,14 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe if name and body and hasattr(body, 'read'): position = body.tell() assert not position - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 400 return response file = tempfile.NamedTemporaryFile(delete=False) file.write(b'Lots of dataaaa') file.close() - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') headers = {'Content-Type': "multipart/form-data"} http_request.headers = headers with open(file.name, 'rb') as f: @@ -222,8 +229,8 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe pipeline.run(http_request) os.unlink(f.name) - -def test_retry_timeout(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_retry_timeout(http_request): timeout = 1 def send(request, **kwargs): @@ -239,16 +246,17 @@ def send(request, **kwargs): pipeline = Pipeline(transport, [RetryPolicy(timeout=timeout)]) with pytest.raises(ServiceResponseTimeoutError): - response = pipeline.run(HttpRequest("GET", "http://localhost/")) + response = pipeline.run(http_request("GET", "http://localhost/")) -def test_timeout_defaults(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_timeout_defaults(http_request, http_response): """When "timeout" is not set, the policy should not override the transport's timeout configuration""" def send(request, **kwargs): for arg in ("connection_timeout", "read_timeout"): assert arg not in kwargs, "policy should defer to transport configuration when not given a timeout" - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 200 return response @@ -259,19 +267,22 @@ def send(request, **kwargs): ) pipeline = Pipeline(transport, [RetryPolicy()]) - pipeline.run(HttpRequest("GET", "http://localhost/")) + pipeline.run(http_request("GET", "http://localhost/")) assert transport.send.call_count == 1, "policy should not retry: its first send succeeded" +transport_expected_timeout_error = [ + (ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError) +] @pytest.mark.parametrize( - "transport_error,expected_timeout_error", - ((ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)), + "transport_expected_timeout_error,http_request", + product(transport_expected_timeout_error, HTTP_REQUESTS), ) -def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error): +def test_does_not_sleep_after_timeout(transport_expected_timeout_error, http_request): # With default settings policy will sleep twice before exhausting its retries: 1.6s, 3.2s. # It should not sleep the second time when given timeout=1 timeout = 1 - + transport_error, expected_timeout_error = transport_expected_timeout_error transport = Mock( spec=HttpTransport, send=Mock(side_effect=transport_error("oops")), @@ -280,6 +291,6 @@ def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error): pipeline = Pipeline(transport, [RetryPolicy(timeout=timeout)]) with pytest.raises(expected_timeout_error): - pipeline.run(HttpRequest("GET", "http://localhost/")) + pipeline.run(http_request("GET", "http://localhost/")) assert transport.sleep.call_count == 1 diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py index 1d8b3c7172c4..ffdbe55124cc 100644 --- a/sdk/core/azure-core/tests/test_stream_generator.py +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -4,11 +4,8 @@ # ------------------------------------ import requests from azure.core.pipeline.transport import ( - HttpRequest, - HttpResponse, HttpTransport, RequestsTransport, - RequestsTransportResponse, ) from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator @@ -17,8 +14,9 @@ except ImportError: import mock import pytest - -def test_connection_error_response(): +from utils import create_http_response, HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES, HTTP_REQUESTS, pipeline_transport_and_rest_product +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_connection_error_response(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -31,8 +29,8 @@ def open(self): pass def send(self, request, **kwargs): - request = HttpRequest('GET', 'http://localhost/') - response = HttpResponse(request, None) + request = http_request('GET', 'http://localhost/') + response = create_http_response(http_response, request, None) response.status_code = 200 return response @@ -43,7 +41,7 @@ def __next__(self): if self._count == 0: self._count += 1 raise requests.exceptions.ConnectionError - + def stream(self, chunk_size, decode_content=False): if self._count == 0: self._count += 1 @@ -58,16 +56,17 @@ def __init__(self): def close(self): pass - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') pipeline = Pipeline(MockTransport()) - http_response = HttpResponse(http_request, None) + http_response = create_http_response(http_response, http_request, None) http_response.internal_response = MockInternalResponse() stream = StreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch('time.sleep', return_value=None): with pytest.raises(requests.exceptions.ConnectionError): stream.__next__() -def test_response_streaming_error_behavior(): +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_response_streaming_error_behavior(http_response): # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723 block_size = 103 total_response_size = 500 @@ -104,7 +103,8 @@ def close(self): s = FakeStreamWithConnectionError() req_response.raw = FakeStreamWithConnectionError() - response = RequestsTransportResponse( + response = create_http_response( + http_response, req_request, req_response, block_size, diff --git a/sdk/core/azure-core/tests/test_streaming.py b/sdk/core/azure-core/tests/test_streaming.py index 2a5e6a4d0bb8..dd72ae3e9744 100644 --- a/sdk/core/azure-core/tests/test_streaming.py +++ b/sdk/core/azure-core/tests/test_streaming.py @@ -23,16 +23,20 @@ # THE SOFTWARE. # # -------------------------------------------------------------------------- +import pytest from azure.core import PipelineClient from azure.core.exceptions import DecodeError +from utils import HTTP_REQUESTS -def test_decompress_plain_no_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -40,13 +44,15 @@ def test_decompress_plain_no_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_compress_plain_no_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -54,13 +60,15 @@ def test_compress_plain_no_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_decompress_compressed_no_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -71,13 +79,15 @@ def test_decompress_compressed_no_header(): except UnicodeDecodeError: pass -def test_compress_compressed_no_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -88,14 +98,16 @@ def test_compress_compressed_no_header(): except UnicodeDecodeError: pass -def test_decompress_plain_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_plain_header(http_request): # expect error import requests account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -105,13 +117,15 @@ def test_decompress_plain_header(): except (requests.exceptions.ContentDecodingError, DecodeError): pass -def test_compress_plain_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_plain_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -119,13 +133,15 @@ def test_compress_plain_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_decompress_compressed_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_compressed_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -133,13 +149,15 @@ def test_decompress_compressed_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_compress_compressed_header(): + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_compressed_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) diff --git a/sdk/core/azure-core/tests/test_tracing_decorator.py b/sdk/core/azure-core/tests/test_tracing_decorator.py index be9a820747e3..3aa57f2062ab 100644 --- a/sdk/core/azure-core/tests/test_tracing_decorator.py +++ b/sdk/core/azure-core/tests/test_tracing_decorator.py @@ -14,23 +14,22 @@ import pytest from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.policies import HTTPPolicy -from azure.core.pipeline.transport import HttpTransport, HttpRequest +from azure.core.pipeline.transport import HttpTransport from azure.core.settings import settings from azure.core.tracing import common from azure.core.tracing.decorator import distributed_trace from tracing_common import FakeSpan - +from utils import HTTP_REQUESTS @pytest.fixture(scope="module") def fake_span(): settings.tracing_implementation.set_value(FakeSpan) - class MockClient: @distributed_trace - def __init__(self, policies=None, assert_current_span=False): + def __init__(self, http_request, policies=None, assert_current_span=False): time.sleep(0.001) - self.request = HttpRequest("GET", "http://localhost") + self.request = http_request("GET", "http://localhost") if policies is None: policies = [] policies.append(mock.Mock(spec=HTTPPolicy, send=self.verify_request)) @@ -86,9 +85,9 @@ def raising_exception(self): def random_function(): pass - -def test_get_function_and_class_name(): - client = MockClient() +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_get_function_and_class_name(http_request): + client = MockClient(http_request) assert common.get_function_and_class_name(client.get_foo, client) == "MockClient.get_foo" assert common.get_function_and_class_name(random_function) == "random_function" @@ -96,9 +95,10 @@ def test_get_function_and_class_name(): @pytest.mark.usefixtures("fake_span") class TestDecorator(object): - def test_decorator_tracing_attr(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_decorator_tracing_attr(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.tracing_attr() assert len(parent.children) == 2 @@ -106,18 +106,20 @@ def test_decorator_tracing_attr(self): assert parent.children[1].name == "MockClient.tracing_attr" assert parent.children[1].attributes == {'foo': 'bar'} - def test_decorator_has_different_name(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_decorator_has_different_name(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.check_name_is_different() assert len(parent.children) == 2 assert parent.children[0].name == "MockClient.__init__" assert parent.children[1].name == "different name" - def test_used(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_used(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient(policies=[]) + client = MockClient(http_request, policies=[]) client.get_foo(parent_span=parent) client.get_foo() @@ -129,9 +131,10 @@ def test_used(self): assert parent.children[2].name == "MockClient.get_foo" assert not parent.children[2].children - def test_span_merge_span(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_span_merge_span(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.merge_span_method() client.no_merge_span_method() @@ -143,9 +146,10 @@ def test_span_merge_span(self): assert parent.children[2].name == "MockClient.no_merge_span_method" assert parent.children[2].children[0].name == "MockClient.get_foo" - def test_span_complicated(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_span_complicated(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.make_request(2) with parent.span("child") as child: time.sleep(0.001) @@ -163,11 +167,12 @@ def test_span_complicated(self): assert parent.children[3].name == "MockClient.make_request" assert not parent.children[3].children - def test_span_with_exception(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_span_with_exception(self, http_request): """Assert that if an exception is raised, the next sibling method is actually a sibling span. """ with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) try: client.raising_exception() except: diff --git a/sdk/core/azure-core/tests/test_tracing_policy.py b/sdk/core/azure-core/tests/test_tracing_policy.py index 2a0fc03a78e7..86e8259cd83b 100644 --- a/sdk/core/azure-core/tests/test_tracing_policy.py +++ b/sdk/core/azure-core/tests/test_tracing_policy.py @@ -7,30 +7,31 @@ from azure.core.pipeline import PipelineResponse, PipelineRequest, PipelineContext from azure.core.pipeline.policies import DistributedTracingPolicy, UserAgentPolicy -from azure.core.pipeline.transport import HttpRequest, HttpResponse from azure.core.settings import settings from tracing_common import FakeSpan import time +import pytest +from utils import create_http_response, HTTP_RESPONSES, HTTP_REQUESTS, pipeline_transport_and_rest_product try: from unittest import mock except ImportError: import mock - -def test_distributed_tracing_policy_solo(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_distributed_tracing_policy_solo(http_request, http_response): """Test policy with no other policy and happy path""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: policy = DistributedTracingPolicy() - request = HttpRequest("GET", "http://localhost/temp?query=query") + request = http_request("GET", "http://localhost/temp?query=query") request.headers["x-ms-client-request-id"] = "some client request id" pipeline_request = PipelineRequest(request, PipelineContext(None)) policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 response.headers["x-ms-request-id"] = "some request id" @@ -68,7 +69,8 @@ def test_distributed_tracing_policy_solo(): assert network_span.attributes.get("http.status_code") == 504 -def test_distributed_tracing_policy_attributes(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_distributed_tracing_policy_attributes(http_request, http_response): """Test policy with no other policy and happy path""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: @@ -76,12 +78,12 @@ def test_distributed_tracing_policy_attributes(): 'myattr': 'myvalue' }) - request = HttpRequest("GET", "http://localhost/temp?query=query") + request = http_request("GET", "http://localhost/temp?query=query") pipeline_request = PipelineRequest(request, PipelineContext(None)) policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 @@ -91,14 +93,14 @@ def test_distributed_tracing_policy_attributes(): network_span = root_span.children[0] assert network_span.attributes.get("myattr") == "myvalue" - -def test_distributed_tracing_policy_badurl(caplog): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_distributed_tracing_policy_badurl(caplog, http_request, http_response): """Test policy with a bad url that will throw, and be sure policy ignores it""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: policy = DistributedTracingPolicy() - request = HttpRequest("GET", "http://[[[") + request = http_request("GET", "http://[[[") request.headers["x-ms-client-request-id"] = "some client request id" pipeline_request = PipelineRequest(request, PipelineContext(None)) @@ -106,7 +108,7 @@ def test_distributed_tracing_policy_badurl(caplog): policy.on_request(pipeline_request) assert "Unable to start network span" in caplog.text - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 response.headers["x-ms-request-id"] = "some request id" @@ -124,15 +126,15 @@ def test_distributed_tracing_policy_badurl(caplog): assert len(root_span.children) == 0 - -def test_distributed_tracing_policy_with_user_agent(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_distributed_tracing_policy_with_user_agent(http_request, http_response): """Test policy working with user agent.""" settings.tracing_implementation.set_value(FakeSpan) with mock.patch.dict('os.environ', {"AZURE_HTTP_USER_AGENT": "mytools"}): with FakeSpan(name="parent") as root_span: policy = DistributedTracingPolicy() - request = HttpRequest("GET", "http://localhost") + request = http_request("GET", "http://localhost") request.headers["x-ms-client-request-id"] = "some client request id" pipeline_request = PipelineRequest(request, PipelineContext(None)) @@ -141,7 +143,7 @@ def test_distributed_tracing_policy_with_user_agent(): user_agent.on_request(pipeline_request) policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 response.headers["x-ms-request-id"] = "some request id" @@ -182,12 +184,12 @@ def test_distributed_tracing_policy_with_user_agent(): # Exception should propagate status for Opencensus assert network_span.status == 'Transport trouble' - -def test_span_namer(): +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) +def test_span_namer(http_request, http_response): settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: - request = HttpRequest("GET", "http://localhost/temp?query=query") + request = http_request("GET", "http://localhost/temp?query=query") pipeline_request = PipelineRequest(request, PipelineContext(None)) def fixed_namer(http_request): @@ -198,7 +200,7 @@ def fixed_namer(http_request): policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 @@ -212,7 +214,7 @@ def operation_namer(http_request): policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 diff --git a/sdk/core/azure-core/tests/test_universal_pipeline.py b/sdk/core/azure-core/tests/test_universal_pipeline.py index ea5676374458..de5fea15516d 100644 --- a/sdk/core/azure-core/tests/test_universal_pipeline.py +++ b/sdk/core/azure-core/tests/test_universal_pipeline.py @@ -30,7 +30,7 @@ from unittest import mock except ImportError: import mock - +from itertools import product import requests import pytest @@ -41,11 +41,6 @@ PipelineRequest, PipelineContext ) -from azure.core.pipeline.transport import ( - HttpRequest, - HttpResponse, - RequestsTransportResponse, -) from azure.core.pipeline.policies import ( NetworkTraceLoggingPolicy, @@ -54,6 +49,17 @@ RetryPolicy, HTTPPolicy, ) +from azure.core.pipeline.transport import HttpResponse as PipelineTransportHttpResponse +from azure.core.pipeline.transport._base import RestHttpResponseImpl +from utils import ( + create_http_request, + create_http_response, + HTTP_REQUESTS, + HTTP_RESPONSES, + REQUESTS_TRANSPORT_RESPONSES, + pipeline_transport_and_rest_product, + is_rest_http_response, +) def test_pipeline_context(): kwargs={ @@ -86,38 +92,41 @@ def test_pipeline_context(): assert len(revived_context) == 1 -def test_request_history(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_history(http_request): class Non_deep_copiable(object): def __deepcopy__(self, memodict={}): raise ValueError() body = Non_deep_copiable() - request = HttpRequest('GET', 'http://localhost/', {'user-agent': 'test_request_history'}) + request = create_http_request(http_request, 'GET', 'http://localhost/', {'user-agent': 'test_request_history'}) request.body = body request_history = RequestHistory(request) assert request_history.http_request.headers == request.headers assert request_history.http_request.url == request.url assert request_history.http_request.method == request.method -def test_request_history_type_error(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_history_type_error(http_request): class Non_deep_copiable(object): def __deepcopy__(self, memodict={}): raise TypeError() body = Non_deep_copiable() - request = HttpRequest('GET', 'http://localhost/', {'user-agent': 'test_request_history'}) + request = create_http_request(http_request, 'GET', 'http://localhost/', {'user-agent': 'test_request_history'}) request.body = body request_history = RequestHistory(request) assert request_history.http_request.headers == request.headers assert request_history.http_request.url == request.url assert request_history.http_request.method == request.method +@pytest.mark.parametrize("http_request,http_response", pipeline_transport_and_rest_product(HTTP_REQUESTS, HTTP_RESPONSES)) @mock.patch('azure.core.pipeline.policies._universal._LOGGER') -def test_no_log(mock_http_logger): - universal_request = HttpRequest('GET', 'http://localhost/') +def test_no_log(mock_http_logger, http_request, http_response): + universal_request = http_request('GET', 'http://localhost/') request = PipelineRequest(universal_request, PipelineContext(None)) http_logger = NetworkTraceLoggingPolicy() - response = PipelineResponse(request, HttpResponse(universal_request, None), request.context) + response = PipelineResponse(request, create_http_response(http_response, universal_request, None), request.context) # By default, no log handler for HTTP http_logger.on_request(request) @@ -178,7 +187,8 @@ def test_no_log(mock_http_logger): second_count = mock_http_logger.debug.call_count assert second_count == first_count * 2 -def test_retry_without_http_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_retry_without_http_response(http_request): class NaughtyPolicy(HTTPPolicy): def send(*args): raise AzureError('boo') @@ -186,25 +196,51 @@ def send(*args): policies = [RetryPolicy(), NaughtyPolicy()] pipeline = Pipeline(policies=policies, transport=None) with pytest.raises(AzureError): - pipeline.run(HttpRequest('GET', url='https://foo.bar')) - -def test_raw_deserializer(): + pipeline.run(http_request('GET', url='https://foo.bar')) +@pytest.mark.parametrize( + "http_request,http_response,requests_transport_response", + pipeline_transport_and_rest_product( + HTTP_REQUESTS, HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES + ) +) +def test_raw_deserializer(http_request, http_response, requests_transport_response): raw_deserializer = ContentDecodePolicy() context = PipelineContext(None, stream=False) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') request = PipelineRequest(universal_request, context) def build_response(body, content_type=None): - class MockResponse(HttpResponse): + class PipelineTransportMockResponse(PipelineTransportHttpResponse): def __init__(self, body, content_type): - super(MockResponse, self).__init__(None, None) + super(PipelineTransportMockResponse, self).__init__(request=None, internal_response=None) self._body = body self.content_type = content_type def body(self): return self._body - return PipelineResponse(request, MockResponse(body, content_type), context) + class RestMockResponse(RestHttpResponseImpl): + def __init__(self, body, content_type): + super(RestMockResponse, self).__init__( + request=None, + internal_response=None, + status_code=200, + reason="OK", + content_type=content_type, + headers={}, + stream_download_generator=None, + ) + self._content = body + + @property + def content(self): + return self._content + + if is_rest_http_response(http_response): + mock_response = PipelineTransportMockResponse + else: + mock_response = RestMockResponse + return PipelineResponse(request, mock_response(body, content_type), context) response = build_response(b"", content_type="application/xml") raw_deserializer.on_response(request, response) @@ -287,7 +323,7 @@ def body(self): req_response.headers["content-type"] = "application/json" req_response._content = b'{"success": true}' req_response._content_consumed = True - response = PipelineResponse(None, RequestsTransportResponse(None, req_response), PipelineContext(None, stream=False)) + response = PipelineResponse(None, create_http_response(requests_transport_response, None, req_response), PipelineContext(None, stream=False)) raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] diff --git a/sdk/core/azure-core/tests/test_user_agent_policy.py b/sdk/core/azure-core/tests/test_user_agent_policy.py index 4f8b01c93b7e..efaa5d2a5b63 100644 --- a/sdk/core/azure-core/tests/test_user_agent_policy.py +++ b/sdk/core/azure-core/tests/test_user_agent_policy.py @@ -3,15 +3,17 @@ # Licensed under the MIT License. # ------------------------------------ """Tests for the user agent policy.""" +import pytest from azure.core.pipeline.policies import UserAgentPolicy -from azure.core.pipeline.transport import HttpRequest from azure.core.pipeline import PipelineRequest, PipelineContext +from utils import HTTP_REQUESTS try: from unittest import mock except ImportError: import mock -def test_user_agent_policy(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_user_agent_policy(http_request): user_agent = UserAgentPolicy(base_user_agent='foo') assert user_agent._user_agent == 'foo' @@ -21,7 +23,7 @@ def test_user_agent_policy(): user_agent = UserAgentPolicy(base_user_agent='foo', user_agent='bar', user_agent_use_env=False) assert user_agent._user_agent == 'bar foo' - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') pipeline_request = PipelineRequest(request, PipelineContext(None)) pipeline_request.context.options['user_agent'] = 'xyz' @@ -29,12 +31,13 @@ def test_user_agent_policy(): assert request.headers['User-Agent'] == 'xyz bar foo' -def test_user_agent_environ(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_user_agent_environ(http_request): with mock.patch.dict('os.environ', {'AZURE_HTTP_USER_AGENT': "mytools"}): policy = UserAgentPolicy(None) assert policy.user_agent.endswith("mytools") - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') policy.on_request(PipelineRequest(request, PipelineContext(None))) assert request.headers["user-agent"].endswith("mytools") diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py index 63560847a01f..eee7765e9055 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py @@ -8,12 +8,13 @@ from flask import Flask, Response from .test_routes import ( basic_api, + decompression_api, encoding_api, errors_api, streams_api, urlencoded_api, multipart_api, - xml_api + xml_api, ) app = Flask(__name__) @@ -24,6 +25,7 @@ app.register_blueprint(urlencoded_api, url_prefix="/urlencoded") app.register_blueprint(multipart_api, url_prefix="/multipart") app.register_blueprint(xml_api, url_prefix="/xml") +app.register_blueprint(decompression_api, url_prefix="/decompression") @app.route('/health', methods=['GET']) def latin_1_charset_utf8(): diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/__init__.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/__init__.py index 82f4e7ac4566..28547256b92e 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/__init__.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/__init__.py @@ -12,9 +12,11 @@ from .streams import streams_api from .urlencoded import urlencoded_api from .xml_route import xml_api +from .decompression import decompression_api __all__ = [ "basic_api", + "decompression_api", "encoding_api", "errors_api", "multipart_api", diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/decompression.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/decompression.py new file mode 100644 index 000000000000..ed448ccca638 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/decompression.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from flask import ( + Response, + Blueprint, + request +) + +decompression_api = Blueprint('decompression_api', __name__) + +@decompression_api.route('/gzip/pass', methods=['GET']) +def gzip_pass(): + return Response( + b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x04\x00\x8d\x8d\xb1n\xc30\x0cD" + b"\xff\x85s\x14HVlY\xda\x8av.\n4\x1d\x9a\x8d\xa1\xe5D\x80m\x01\x12=" + b"\x14A\xfe\xbd\x92\x81d\xceB\x1c\xef\xf8\x8e7\x08\x038\xf0\xa67Fj+" + b"\x946\x9d8\x0c4\x08{\x96(\x94mzkh\x1cM/a\x07\x94<\xb2\x1f>\xca8\x86" + b"\xd9\xff0\x15\xb6\x91\x8d\x12\xb2\x15\xd2\x1c\x95q\xbau\xba\xdbk" + b"\xd5(\xd9\xb5\xa7\xc2L\x98\xf9\x8d8\xc4\xe5U\xccV,3\xf2\x9a\xcb\xddg" + b"\xe4o\xc6T\xdeVw\x9dgL\x7f\xe0n\xc0\x91q\x02'w0b\x98JZe^\x89|\xce\x9b" + b"\x0e\xcbW\x8a\x97\xf4X\x97\xc8\xbf\xfeYU\x1d\xc2\x85\xfc\xf4@\xb7\xbe" + b"\xf7+&$\xf6\xa9\x8a\xcb\x96\xdc\xef\xff\xaa\xa1\x1c\xf9$\x01\x00\x00", + status=200, + headers={'Content-Type': 'text/plain', 'Content-Encoding':"gzip"}, + ) + +@decompression_api.route('/gzip/fail', methods=['GET']) +def gzip_fail(): + return Response( + b"\xff\x85s\x14HVlY\xda\x8av.\n4\x1d\x9a\x8d\xa1\xe5D\x80m\x01\x12=" + b"\x14A\xfe\xbd\x92\x81d\xceB\x1c\xef\xf8\x8e7\x08\x038\xf0\xa67Fj+" + b"\x946\x9d8\x0c4\x08{\x96(\x94mzkh\x1cM/a\x07\x94<\xb2\x1f>\xca8\x86" + b"\xd9\xff0\x15\xb6\x91\x8d\x12\xb2\x15\xd2\x1c\x95q\xbau\xba\xdbk" + b"\xd5(\xd9\xb5\xa7\xc2L\x98\xf9\x8d8\xc4\xe5U\xccV,3\xf2\x9a\xcb\xddg" + b"\xe4o\xc6T\xdeVw\x9dgL\x7f\xe0n\xc0\x91q\x02'w0b\x98JZe^\x89|\xce\x9b" + b"\x0e\xcbW\x8a\x97\xf4X\x97\xc8\xbf\xfeYU\x1d\xc2\x85\xfc\xf4@\xb7\xbe" + b"\xf7+&$\xf6\xa9\x8a\xcb\x96\xdc\xef\xff\xaa\xa1\x1c\xf9$\x01\x00\x00", + status=200, + headers={'Content-Type': 'text/plain', 'Content-Encoding':"gzip"} + ) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/headers.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/headers.py new file mode 100644 index 000000000000..ffb5232587b0 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/headers.py @@ -0,0 +1,60 @@ +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +from flask import ( + Response, + Blueprint, + request +) + +headers_api = Blueprint('headers_api', __name__) + +@headers_api.route("/case-insensitive", methods=['GET']) +def case_insensitive(): + return Response( + status=200, + headers={ + "lowercase-header": "lowercase", + "ALLCAPS-HEADER": "ALLCAPS", + "CamelCase-Header": "camelCase", + } + ) + +@headers_api.route("/empty", methods=['GET']) +def empty(): + return Response( + status=200, + headers={} + ) + +@headers_api.route("/duplicate/numbers", methods=['GET']) +def duplicate_numbers(): + return Response( + status=200, + headers=[("a", "123"), ("a", "456"), ("b", "789")] + ) + +@headers_api.route("/duplicate/case-insensitive", methods=['GET']) +def duplicate_case_insensitive(): + return Response( + status=200, + headers=[("Duplicate-Header", "one"), ("Duplicate-Header", "two"), ("duplicate-header", "three")] + ) + +@headers_api.route("/duplicate/commas", methods=['GET']) +def duplicate_commas(): + return Response( + status=200, + headers=[("Set-Cookie", "a, b"), ("Set-Cookie", "c")] + ) + +@headers_api.route("/ordered", methods=['GET']) +def ordered(): + return Response( + status=200, + headers={"a": "a", "b": "b", "c": "c"}, + ) diff --git a/sdk/core/azure-core/tests/utils.py b/sdk/core/azure-core/tests/utils.py new file mode 100644 index 000000000000..9ae2824a98b9 --- /dev/null +++ b/sdk/core/azure-core/tests/utils.py @@ -0,0 +1,173 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import Type +import pytest +import abc +############################## LISTS USED TO PARAMETERIZE TESTS ############################## +from azure.core.pipeline.transport import( + HttpRequest as PipelineTransportHttpRequest, + HttpResponse as PipelineTransportHttpResponse, +) +from azure.core.pipeline.transport._base import RestHttpResponseImpl + +from azure.core.rest import ( + HttpRequest as RestHttpRequest, +) + +HTTP_REQUESTS = [PipelineTransportHttpRequest, RestHttpRequest] + +HTTP_RESPONSES = [PipelineTransportHttpResponse, RestHttpResponseImpl] + +try: + from azure.core.pipeline.transport import AsyncHttpResponse as PipelineTransportAsyncHttpResponse + from azure.core.pipeline.transport._base_async import RestAsyncHttpResponseImpl + ASYNC_HTTP_RESPONSES = [PipelineTransportAsyncHttpResponse, RestAsyncHttpResponseImpl] +except (SyntaxError, ImportError): + ASYNC_HTTP_RESPONSES = [] + +try: + from azure.core.pipeline.transport._requests_basic import ( + RequestsTransportResponse as PipelineTransportRequestsTransportResponse, + RestRequestsTransportResponse, + ) + REQUESTS_TRANSPORT_RESPONSES = [PipelineTransportRequestsTransportResponse, RestRequestsTransportResponse] +except (SyntaxError, ImportError): + REQUESTS_TRANSPORT_RESPONSES = [] + +try: + from azure.core.pipeline.transport._aiohttp import ( + AioHttpTransportResponse as PipelineTransportAioHttpTransportResponse, + RestAioHttpTransportResponse + ) + AIOHTTP_TRANSPORT_RESPONSES = [PipelineTransportAioHttpTransportResponse, RestAioHttpTransportResponse] +except (SyntaxError, ImportError): + AIOHTTP_TRANSPORT_RESPONSES = [] + + +try: + from azure.core.pipeline.transport._requests_asyncio import ( + AsyncioRequestsTransportResponse as PipelineTransportAsyncioRequestsTransportResponse, + RestAsyncioRequestsTransportResponse, + ) + ASYNCIO_REQUESTS_TRANSPORT_RESPONSES = [PipelineTransportAsyncioRequestsTransportResponse, RestAsyncioRequestsTransportResponse] +except (SyntaxError, ImportError): + ASYNCIO_REQUESTS_TRANSPORT_RESPONSES = [] + +def pipeline_transport_and_rest_product(*args): + # add pipeline transport requests / responses + my_response = [ + tuple(arg[0] for arg in args) + ] + # add rest requests / responses + my_response.extend([ + tuple(arg[1] for arg in args) + ]) + return my_response + +############################## HELPER FUNCTIONS ############################## + +def create_http_request(http_request, *args, **kwargs): + if hasattr(http_request, "content"): + method = args[0] + url = args[1] + try: + headers = args[2] + except IndexError: + headers = None + try: + files = args[3] + except IndexError: + files = None + try: + data = args[4] + except IndexError: + data = None + return http_request( + method=method, + url=url, + headers=headers, + files=files, + data=data, + **kwargs + ) + return http_request(*args, **kwargs) + +def create_http_response(http_response, *args, **kwargs): + if hasattr(http_response, "content"): + if len(args) > 2: + block_size = args[2] + else: + block_size = None + try: + response = http_response( + request=args[0], + internal_response=args[1], + block_size=block_size, + **kwargs + ) + except KeyError: + kwargs.update({ + "status_code": 200, + "reason": "OK", + "content_type": "application/json", + "headers": {}, + "stream_download_generator": None, + }) + response = http_response( + request=args[0], + internal_response=args[1], + block_size=block_size, + **kwargs + ) + return response + return http_response(*args, **kwargs) + +def is_rest_http_request(http_request): + return hasattr(http_request, "content") + +def is_rest_http_response(http_response): + return hasattr(http_response, "content") + +def readonly_checks(response): + assert isinstance(response.request, RestHttpRequest) + with pytest.raises(AttributeError): + response.request = None + + assert isinstance(response.status_code, int) + with pytest.raises(AttributeError): + response.status_code = 200 + + assert response.headers + with pytest.raises(AttributeError): + response.headers = {"hello": "world"} + + assert response.reason == "OK" + with pytest.raises(AttributeError): + response.reason = "Not OK" + + assert response.content_type == 'text/html; charset=utf-8' + with pytest.raises(AttributeError): + response.content_type = "bad content type" + + assert response.is_closed + with pytest.raises(AttributeError): + response.is_closed = False + + assert response.is_stream_consumed + with pytest.raises(AttributeError): + response.is_stream_consumed = False + + # you can set encoding + assert response.encoding == "utf-8" + response.encoding = "blah" + assert response.encoding == "blah" + + assert isinstance(response.url, str) + with pytest.raises(AttributeError): + response.url = "http://fakeurl" + + assert response.content is not None + with pytest.raises(AttributeError): + response.content = b"bad"