diff --git a/google/auth/_exponential_backoff.py b/google/auth/_exponential_backoff.py index 04f9f9764..89853448f 100644 --- a/google/auth/_exponential_backoff.py +++ b/google/auth/_exponential_backoff.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import random import time @@ -38,9 +39,8 @@ """ -class ExponentialBackoff: - """An exponential backoff iterator. This can be used in a for loop to - perform requests with exponential backoff. +class _BaseExponentialBackoff: + """An exponential backoff iterator base class. Args: total_attempts Optional[int]: @@ -84,9 +84,40 @@ def __init__( self._multiplier = multiplier self._backoff_count = 0 - def __iter__(self): + @property + def total_attempts(self): + """The total amount of backoff attempts that will be made.""" + return self._total_attempts + + @property + def backoff_count(self): + """The current amount of backoff attempts that have been made.""" + return self._backoff_count + + def _reset(self): self._backoff_count = 0 self._current_wait_in_seconds = self._initial_wait_seconds + + def _calculate_jitter(self): + jitter_variance = self._current_wait_in_seconds * self._randomization_factor + jitter = random.uniform( + self._current_wait_in_seconds - jitter_variance, + self._current_wait_in_seconds + jitter_variance, + ) + + return jitter + + +class ExponentialBackoff(_BaseExponentialBackoff): + """An exponential backoff iterator. This can be used in a for loop to + perform requests with exponential backoff. + """ + + def __init__(self, *args, **kwargs): + super(ExponentialBackoff, self).__init__(*args, **kwargs) + + def __iter__(self): + self._reset() return self def __next__(self): @@ -97,23 +128,37 @@ def __next__(self): if self._backoff_count <= 1: return self._backoff_count - jitter_variance = self._current_wait_in_seconds * self._randomization_factor - jitter = random.uniform( - self._current_wait_in_seconds - jitter_variance, - self._current_wait_in_seconds + jitter_variance, - ) + jitter = self._calculate_jitter() time.sleep(jitter) self._current_wait_in_seconds *= self._multiplier return self._backoff_count - @property - def total_attempts(self): - """The total amount of backoff attempts that will be made.""" - return self._total_attempts - @property - def backoff_count(self): - """The current amount of backoff attempts that have been made.""" +class AsyncExponentialBackoff(_BaseExponentialBackoff): + """An async exponential backoff iterator. This can be used in a for loop to + perform async requests with exponential backoff. + """ + + def __init__(self, *args, **kwargs): + super(AsyncExponentialBackoff, self).__init__(*args, **kwargs) + + def __aiter__(self): + self._reset() + return self + + async def __anext__(self): + if self._backoff_count >= self._total_attempts: + raise StopAsyncIteration + self._backoff_count += 1 + + if self._backoff_count <= 1: + return self._backoff_count + + jitter = self._calculate_jitter() + + await asyncio.sleep(jitter) + + self._current_wait_in_seconds *= self._multiplier return self._backoff_count diff --git a/google/auth/aio/transport/__init__.py b/google/auth/aio/transport/__init__.py index fb3143de3..7a037ef4f 100644 --- a/google/auth/aio/transport/__init__.py +++ b/google/auth/aio/transport/__init__.py @@ -25,7 +25,24 @@ """ import abc -from typing import AsyncGenerator, Dict, Mapping, Optional +from typing import AsyncGenerator, Mapping, Optional + +import google.auth.transport + + +_DEFAULT_TIMEOUT_SECONDS = 180 + +DEFAULT_RETRYABLE_STATUS_CODES = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES +"""Sequence[int]: HTTP status codes indicating a request can be retried. +""" + +DEFAULT_REFRESH_STATUS_CODES = google.auth.transport.DEFAULT_REFRESH_STATUS_CODES +"""Sequence[int]: Which HTTP status code indicate that credentials should be +refreshed. +""" + +DEFAULT_MAX_REFRESH_ATTEMPTS = 3 +"""int: How many times to refresh the credentials and retry a request.""" class Response(metaclass=abc.ABCMeta): @@ -35,7 +52,7 @@ class Response(metaclass=abc.ABCMeta): @abc.abstractmethod def status_code(self) -> int: """ - The HTTP response status code.. + The HTTP response status code. Returns: int: The HTTP response status code. @@ -45,11 +62,11 @@ def status_code(self) -> int: @property @abc.abstractmethod - def headers(self) -> Dict[str, str]: + def headers(self) -> Mapping[str, str]: """The HTTP response headers. Returns: - Dict[str, str]: The HTTP response headers. + Mapping[str, str]: The HTTP response headers. """ raise NotImplementedError("headers must be implemented.") @@ -95,7 +112,7 @@ async def __call__( self, url: str, method: str, - body: bytes, + body: Optional[bytes], headers: Optional[Mapping[str, str]], timeout: float, **kwargs @@ -106,7 +123,7 @@ async def __call__( url (str): The URI to be requested. method (str): The HTTP method to use for the request. Defaults to 'GET'. - body (bytes): The payload / body in HTTP request. + body (Optional[bytes]): The payload / body in HTTP request. headers (Mapping[str, str]): Request headers. timeout (float): The number of seconds to wait for a response from the server. If not specified or if None, the diff --git a/google/auth/aio/transport/aiohttp.py b/google/auth/aio/transport/aiohttp.py index a1bbbd639..89d3766f9 100644 --- a/google/auth/aio/transport/aiohttp.py +++ b/google/auth/aio/transport/aiohttp.py @@ -16,12 +16,10 @@ """ import asyncio -from contextlib import asynccontextmanager -import time -from typing import AsyncGenerator, Dict, Mapping, Optional +from typing import AsyncGenerator, Mapping, Optional try: - import aiohttp + import aiohttp # type: ignore except ImportError as caught_exc: # pragma: NO COVER raise ImportError( "The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport." @@ -30,54 +28,6 @@ from google.auth import _helpers from google.auth import exceptions from google.auth.aio import transport -from google.auth.exceptions import TimeoutError - - -_DEFAULT_TIMEOUT_SECONDS = 180 - - -@asynccontextmanager -async def timeout_guard(timeout): - """ - timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code. - - Args: - timeout (float): The time in seconds before the context manager times out. - - Raises: - google.auth.exceptions.TimeoutError: If the code within the context exceeds the provided timeout. - - Usage: - async with timeout_guard(10) as with_timeout: - await with_timeout(async_function()) - """ - start = time.monotonic() - total_timeout = timeout - - def _remaining_time(): - elapsed = time.monotonic() - start - remaining = total_timeout - elapsed - if remaining <= 0: - raise TimeoutError( - f"Context manager exceeded the configured timeout of {total_timeout}s." - ) - return remaining - - async def with_timeout(coro): - try: - remaining = _remaining_time() - response = await asyncio.wait_for(coro, remaining) - return response - except (asyncio.TimeoutError, TimeoutError) as e: - raise TimeoutError( - f"The operation {coro} exceeded the configured timeout of {total_timeout}s." - ) from e - - try: - yield with_timeout - - finally: - _remaining_time() class Response(transport.Response): @@ -89,7 +39,7 @@ class Response(transport.Response): Attributes: status_code (int): The HTTP status code of the response. - headers (Dict[str, str]): A case-insensitive multidict proxy wiht HTTP headers of response. + headers (Mapping[str, str]): The HTTP headers of the response. """ def __init__(self, response: aiohttp.ClientResponse): @@ -102,7 +52,7 @@ def status_code(self) -> int: @property @_helpers.copy_docstring(transport.Response) - def headers(self) -> Dict[str, str]: + def headers(self) -> Mapping[str, str]: return {key: value for key, value in self._response.headers.items()} @_helpers.copy_docstring(transport.Response) @@ -158,7 +108,7 @@ async def __call__( method: str = "GET", body: Optional[bytes] = None, headers: Optional[Mapping[str, str]] = None, - timeout: float = _DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, **kwargs, ) -> transport.Response: """ @@ -199,14 +149,14 @@ async def __call__( return Response(response) except aiohttp.ClientError as caught_exc: - new_exc = exceptions.TransportError(f"Failed to send request to {url}.") - raise new_exc from caught_exc + client_exc = exceptions.TransportError(f"Failed to send request to {url}.") + raise client_exc from caught_exc except asyncio.TimeoutError as caught_exc: - new_exc = exceptions.TimeoutError( + timeout_exc = exceptions.TimeoutError( f"Request timed out after {timeout} seconds." ) - raise new_exc from caught_exc + raise timeout_exc from caught_exc async def close(self) -> None: """ diff --git a/google/auth/aio/transport/sessions.py b/google/auth/aio/transport/sessions.py new file mode 100644 index 000000000..60d33df3e --- /dev/null +++ b/google/auth/aio/transport/sessions.py @@ -0,0 +1,268 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from contextlib import asynccontextmanager +import functools +import time +from typing import Mapping, Optional + +from google.auth import _exponential_backoff, exceptions +from google.auth.aio import transport +from google.auth.aio.credentials import Credentials +from google.auth.exceptions import TimeoutError + +try: + from google.auth.aio.transport.aiohttp import Request as AiohttpRequest + + AIOHTTP_INSTALLED = True +except ImportError: # pragma: NO COVER + AIOHTTP_INSTALLED = False + + +@asynccontextmanager +async def timeout_guard(timeout): + """ + timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code. + + Args: + timeout (float): The time in seconds before the context manager times out. + + Raises: + google.auth.exceptions.TimeoutError: If the code within the context exceeds the provided timeout. + + Usage: + async with timeout_guard(10) as with_timeout: + await with_timeout(async_function()) + """ + start = time.monotonic() + total_timeout = timeout + + def _remaining_time(): + elapsed = time.monotonic() - start + remaining = total_timeout - elapsed + if remaining <= 0: + raise TimeoutError( + f"Context manager exceeded the configured timeout of {total_timeout}s." + ) + return remaining + + async def with_timeout(coro): + try: + remaining = _remaining_time() + response = await asyncio.wait_for(coro, remaining) + return response + except (asyncio.TimeoutError, TimeoutError) as e: + raise TimeoutError( + f"The operation {coro} exceeded the configured timeout of {total_timeout}s." + ) from e + + try: + yield with_timeout + + finally: + _remaining_time() + + +class AuthorizedSession: + """This is an asynchronous implementation of :class:`google.auth.requests.AuthorizedSession` class. + We utilize an instance of a class that implements :class:`google.auth.aio.transport.Request` configured + by the caller or otherwise default to `google.auth.aio.transport.aiohttp.Request` if the external aiohttp + package is installed. + + A Requests Session class with credentials. + + This class is used to perform asynchronous requests to API endpoints that require + authorization:: + + import aiohttp + from google.auth.aio.transport import sessions + + async with sessions.AuthorizedSession(credentials) as authed_session: + response = await authed_session.request( + 'GET', 'https://www.googleapis.com/storage/v1/b') + + The underlying :meth:`request` implementation handles adding the + credentials' headers to the request and refreshing credentials as needed. + + Args: + credentials (google.auth.aio.credentials.Credentials): + The credentials to add to the request. + auth_request (Optional[google.auth.aio.transport.Request]): + An instance of a class that implements + :class:`~google.auth.aio.transport.Request` used to make requests + and refresh credentials. If not passed, + an instance of :class:`~google.auth.aio.transport.aiohttp.Request` + is created. + + Raises: + - google.auth.exceptions.TransportError: If `auth_request` is `None` + and the external package `aiohttp` is not installed. + - google.auth.exceptions.InvalidType: If the provided credentials are + not of type `google.auth.aio.credentials.Credentials`. + """ + + def __init__( + self, credentials: Credentials, auth_request: Optional[transport.Request] = None + ): + if not isinstance(credentials, Credentials): + raise exceptions.InvalidType( + f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`" + ) + self._credentials = credentials + _auth_request = auth_request + if not _auth_request and AIOHTTP_INSTALLED: + _auth_request = AiohttpRequest() + if _auth_request is None: + raise exceptions.TransportError( + "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." + ) + self._auth_request = _auth_request + + async def request( + self, + method: str, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + """ + Args: + method (str): The http method used to make the request. + url (str): The URI to be requested. + data (Optional[bytes]): The payload or body in HTTP request. + headers (Optional[Mapping[str, str]]): Request headers. + timeout (float): + The amount of time in seconds to wait for the server response + with each individual request. + max_allowed_time (float): + If the method runs longer than this, a ``Timeout`` exception is + automatically raised. Unlike the ``timeout`` parameter, this + value applies to the total method execution time, even if + multiple requests are made under the hood. + + Mind that it is not guaranteed that the timeout error is raised + at ``max_allowed_time``. It might take longer, for example, if + an underlying request takes a lot of time, but the request + itself does not timeout, e.g. if a large file is being + transmitted. The timout error will be raised after such + request completes. + + Returns: + google.auth.aio.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TimeoutError: If the method does not complete within + the configured `max_allowed_time` or the request exceeds the configured + `timeout`. + """ + + retries = _exponential_backoff.AsyncExponentialBackoff( + total_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS + ) + async with timeout_guard(max_allowed_time) as with_timeout: + await with_timeout( + # Note: before_request will attempt to refresh credentials if expired. + self._credentials.before_request( + self._auth_request, method, url, headers + ) + ) + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for _ in retries: # pragma: no branch + response = await with_timeout( + self._auth_request(url, method, data, headers, timeout, **kwargs) + ) + if response.status_code not in transport.DEFAULT_RETRYABLE_STATUS_CODES: + break + return response + + @functools.wraps(request) + async def get( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "GET", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def post( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "POST", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def put( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "PUT", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def patch( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "PATCH", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def delete( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + async def close(self) -> None: + """ + Close the underlying auth request session. + """ + await self._auth_request.close() diff --git a/tests/test__exponential_backoff.py b/tests/test__exponential_backoff.py index 95422502b..b7b6877b2 100644 --- a/tests/test__exponential_backoff.py +++ b/tests/test__exponential_backoff.py @@ -54,3 +54,44 @@ def test_minimum_total_attempts(): with pytest.raises(exceptions.InvalidValue): _exponential_backoff.ExponentialBackoff(total_attempts=-1) _exponential_backoff.ExponentialBackoff(total_attempts=1) + + +@pytest.mark.asyncio +@mock.patch("asyncio.sleep", return_value=None) +async def test_exponential_backoff_async(mock_time_async): + eb = _exponential_backoff.AsyncExponentialBackoff() + curr_wait = eb._current_wait_in_seconds + iteration_count = 0 + + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for attempt in eb: # pragma: no branch + if attempt == 1: + assert mock_time_async.call_count == 0 + else: + backoff_interval = mock_time_async.call_args[0][0] + jitter = curr_wait * eb._randomization_factor + + assert (curr_wait - jitter) <= backoff_interval <= (curr_wait + jitter) + assert attempt == iteration_count + 1 + assert eb.backoff_count == iteration_count + 1 + assert eb._current_wait_in_seconds == eb._multiplier ** iteration_count + + curr_wait = eb._current_wait_in_seconds + iteration_count += 1 + + assert eb.total_attempts == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert eb.backoff_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert iteration_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert ( + mock_time_async.call_count + == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS - 1 + ) + + +def test_minimum_total_attempts_async(): + with pytest.raises(exceptions.InvalidValue): + _exponential_backoff.AsyncExponentialBackoff(total_attempts=0) + with pytest.raises(exceptions.InvalidValue): + _exponential_backoff.AsyncExponentialBackoff(total_attempts=-1) + _exponential_backoff.AsyncExponentialBackoff(total_attempts=1) diff --git a/tests/transport/aio/test_aiohttp.py b/tests/transport/aio/test_aiohttp.py index 4b1c65f93..00b92911c 100644 --- a/tests/transport/aio/test_aiohttp.py +++ b/tests/transport/aio/test_aiohttp.py @@ -13,30 +13,24 @@ # limitations under the License. import asyncio -from unittest.mock import AsyncMock, Mock, patch -from aioresponses import aioresponses +from aioresponses import aioresponses # type: ignore +from mock import AsyncMock, Mock import pytest # type: ignore -import pytest_asyncio +import pytest_asyncio # type: ignore from google.auth import exceptions import google.auth.aio.transport.aiohttp as auth_aiohttp -from google.auth.exceptions import TimeoutError try: - import aiohttp + import aiohttp # type: ignore except ImportError as caught_exc: # pragma: NO COVER raise ImportError( "The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport." ) from caught_exc -@pytest.fixture -async def simple_async_task(): - return True - - @pytest.fixture def mock_response(): response = Mock() @@ -91,65 +85,6 @@ async def test_response_content_stream(self, mock_response): assert b"".join(content) == b"Cavefish have no sight." -class TestTimeoutGuard(object): - default_timeout = 1 - - def make_timeout_guard(self, timeout): - return auth_aiohttp.timeout_guard(timeout) - - @pytest.mark.asyncio - async def test_timeout_with_simple_async_task_within_bounds( - self, simple_async_task - ): - task = False - with patch("time.monotonic", side_effect=[0, 0.25, 0.75]): - with patch("asyncio.wait_for", lambda coro, timeout: coro): - async with self.make_timeout_guard( - timeout=self.default_timeout - ) as with_timeout: - task = await with_timeout(simple_async_task) - - # Task succeeds. - assert task is True - - @pytest.mark.asyncio - async def test_timeout_with_simple_async_task_out_of_bounds( - self, simple_async_task - ): - task = False - with patch("time.monotonic", side_effect=[0, 1, 1]): - with patch("asyncio.wait_for", lambda coro, timeout: coro): - with pytest.raises(TimeoutError) as exc: - async with self.make_timeout_guard( - timeout=self.default_timeout - ) as with_timeout: - task = await with_timeout(simple_async_task) - - # Task does not succeed and the context manager times out i.e. no remaining time left. - assert task is False - assert exc.match( - f"Context manager exceeded the configured timeout of {self.default_timeout}s." - ) - - @pytest.mark.asyncio - async def test_timeout_with_async_task_timing_out_before_context( - self, simple_async_task - ): - task = False - with pytest.raises(TimeoutError) as exc: - async with self.make_timeout_guard( - timeout=self.default_timeout - ) as with_timeout: - with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError): - task = await with_timeout(simple_async_task) - - # Task does not complete i.e. the operation times out. - assert task is False - assert exc.match( - f"The operation {simple_async_task} exceeded the configured timeout of {self.default_timeout}s." - ) - - @pytest.mark.asyncio class TestRequest: @pytest_asyncio.fixture diff --git a/tests/transport/aio/test_sessions.py b/tests/transport/aio/test_sessions.py new file mode 100644 index 000000000..49567be9e --- /dev/null +++ b/tests/transport/aio/test_sessions.py @@ -0,0 +1,309 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import AsyncGenerator + +from aioresponses import aioresponses # type: ignore +from mock import Mock, patch +import pytest # type: ignore + +from google.auth.aio.credentials import AnonymousCredentials +from google.auth.aio.transport import ( + _DEFAULT_TIMEOUT_SECONDS, + DEFAULT_MAX_REFRESH_ATTEMPTS, + DEFAULT_RETRYABLE_STATUS_CODES, + Request, + Response, + sessions, +) +from google.auth.exceptions import InvalidType, TimeoutError, TransportError + + +@pytest.fixture +async def simple_async_task(): + return True + + +class MockRequest(Request): + def __init__(self, response=None, side_effect=None): + self._closed = False + self._response = response + self._side_effect = side_effect + self.call_count = 0 + + async def __call__( + self, + url, + method="GET", + body=None, + headers=None, + timeout=_DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ): + self.call_count += 1 + if self._side_effect: + raise self._side_effect + return self._response + + async def close(self): + self._closed = True + return None + + +class MockResponse(Response): + def __init__(self, status_code, headers=None, content=None): + self._status_code = status_code + self._headers = headers + self._content = content + self._close = False + + @property + def status_code(self): + return self._status_code + + @property + def headers(self): + return self._headers + + async def read(self) -> bytes: + content = await self.content(1024) + return b"".join([chunk async for chunk in content]) + + async def content(self, chunk_size=None) -> AsyncGenerator: + return self._content + + async def close(self) -> None: + self._close = True + + +class TestTimeoutGuard(object): + default_timeout = 1 + + def make_timeout_guard(self, timeout): + return sessions.timeout_guard(timeout) + + @pytest.mark.asyncio + async def test_timeout_with_simple_async_task_within_bounds( + self, simple_async_task + ): + task = False + with patch("time.monotonic", side_effect=[0, 0.25, 0.75]): + with patch("asyncio.wait_for", lambda coro, _: coro): + async with self.make_timeout_guard( + timeout=self.default_timeout + ) as with_timeout: + task = await with_timeout(simple_async_task) + + # Task succeeds. + assert task is True + + @pytest.mark.asyncio + async def test_timeout_with_simple_async_task_out_of_bounds( + self, simple_async_task + ): + task = False + with patch("time.monotonic", side_effect=[0, 1, 1]): + with pytest.raises(TimeoutError) as exc: + async with self.make_timeout_guard( + timeout=self.default_timeout + ) as with_timeout: + task = await with_timeout(simple_async_task) + + # Task does not succeed and the context manager times out i.e. no remaining time left. + assert task is False + assert exc.match( + f"Context manager exceeded the configured timeout of {self.default_timeout}s." + ) + + @pytest.mark.asyncio + async def test_timeout_with_async_task_timing_out_before_context( + self, simple_async_task + ): + task = False + with pytest.raises(TimeoutError) as exc: + async with self.make_timeout_guard( + timeout=self.default_timeout + ) as with_timeout: + with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError): + task = await with_timeout(simple_async_task) + + # Task does not complete i.e. the operation times out. + assert task is False + assert exc.match( + f"The operation {simple_async_task} exceeded the configured timeout of {self.default_timeout}s." + ) + + +class TestAuthorizedSession(object): + TEST_URL = "http://example.com/" + credentials = AnonymousCredentials() + + @pytest.fixture + async def mocked_content(self): + content = [b"Cavefish ", b"have ", b"no ", b"sight."] + for chunk in content: + yield chunk + + @pytest.mark.asyncio + async def test_constructor_with_default_auth_request(self): + with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", True): + authed_session = sessions.AuthorizedSession(self.credentials) + assert authed_session._credentials == self.credentials + await authed_session.close() + + @pytest.mark.asyncio + async def test_constructor_with_provided_auth_request(self): + auth_request = MockRequest() + authed_session = sessions.AuthorizedSession( + self.credentials, auth_request=auth_request + ) + + assert authed_session._auth_request is auth_request + await authed_session.close() + + @pytest.mark.asyncio + async def test_constructor_raises_no_auth_request_error(self): + with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", False): + with pytest.raises(TransportError) as exc: + sessions.AuthorizedSession(self.credentials) + + exc.match( + "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." + ) + + @pytest.mark.asyncio + async def test_constructor_raises_incorrect_credentials_error(self): + credentials = Mock() + with pytest.raises(InvalidType) as exc: + sessions.AuthorizedSession(credentials) + + exc.match( + f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`" + ) + + @pytest.mark.asyncio + async def test_request_default_auth_request_success(self): + with aioresponses() as m: + mocked_chunks = [b"Cavefish ", b"have ", b"no ", b"sight."] + mocked_response = b"".join(mocked_chunks) + m.get(self.TEST_URL, status=200, body=mocked_response) + authed_session = sessions.AuthorizedSession(self.credentials) + response = await authed_session.request("GET", self.TEST_URL) + assert response.status_code == 200 + assert response.headers == {"Content-Type": "application/json"} + assert await response.read() == b"Cavefish have no sight." + await response.close() + + await authed_session.close() + + @pytest.mark.asyncio + async def test_request_provided_auth_request_success(self, mocked_content): + mocked_response = MockResponse( + status_code=200, + headers={"Content-Type": "application/json"}, + content=mocked_content, + ) + auth_request = MockRequest(mocked_response) + authed_session = sessions.AuthorizedSession(self.credentials, auth_request) + response = await authed_session.request("GET", self.TEST_URL) + assert response.status_code == 200 + assert response.headers == {"Content-Type": "application/json"} + assert await response.read() == b"Cavefish have no sight." + await response.close() + assert response._close + + await authed_session.close() + + @pytest.mark.asyncio + async def test_request_raises_timeout_error(self): + auth_request = MockRequest(side_effect=asyncio.TimeoutError) + authed_session = sessions.AuthorizedSession(self.credentials, auth_request) + with pytest.raises(TimeoutError): + await authed_session.request("GET", self.TEST_URL) + + @pytest.mark.asyncio + async def test_request_raises_transport_error(self): + auth_request = MockRequest(side_effect=TransportError) + authed_session = sessions.AuthorizedSession(self.credentials, auth_request) + with pytest.raises(TransportError): + await authed_session.request("GET", self.TEST_URL) + + @pytest.mark.asyncio + async def test_request_max_allowed_time_exceeded_error(self): + auth_request = MockRequest(side_effect=TransportError) + authed_session = sessions.AuthorizedSession(self.credentials, auth_request) + with patch("time.monotonic", side_effect=[0, 1, 1]): + with pytest.raises(TimeoutError): + await authed_session.request("GET", self.TEST_URL, max_allowed_time=1) + + @pytest.mark.parametrize("retry_status", DEFAULT_RETRYABLE_STATUS_CODES) + @pytest.mark.asyncio + async def test_request_max_retries(self, retry_status): + mocked_response = MockResponse(status_code=retry_status) + auth_request = MockRequest(mocked_response) + with patch("asyncio.sleep", return_value=None): + authed_session = sessions.AuthorizedSession(self.credentials, auth_request) + await authed_session.request("GET", self.TEST_URL) + assert auth_request.call_count == DEFAULT_MAX_REFRESH_ATTEMPTS + + @pytest.mark.asyncio + async def test_http_get_method_success(self): + expected_payload = b"content is retrieved." + authed_session = sessions.AuthorizedSession(self.credentials) + with aioresponses() as m: + m.get(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.get(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_post_method_success(self): + expected_payload = b"content is posted." + authed_session = sessions.AuthorizedSession(self.credentials) + with aioresponses() as m: + m.post(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.post(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_put_method_success(self): + expected_payload = b"content is retrieved." + authed_session = sessions.AuthorizedSession(self.credentials) + with aioresponses() as m: + m.put(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.put(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_patch_method_success(self): + expected_payload = b"content is retrieved." + authed_session = sessions.AuthorizedSession(self.credentials) + with aioresponses() as m: + m.patch(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.patch(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_delete_method_success(self): + expected_payload = b"content is deleted." + authed_session = sessions.AuthorizedSession(self.credentials) + with aioresponses() as m: + m.delete(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.delete(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close()