Skip to content

Commit

Permalink
feat: Implement asynchronous AuthorizedSession class (#1580)
Browse files Browse the repository at this point in the history
* feat: Implement Asynchronous AuthorizedSession class

* add comment for implementing locks within refresh

* move timeout guard to sessions

* add unit tests and code cleanup

* implement async exponential backoff iterator

* cleanup

* add testing for http methods and cleanup

* update number of retries to 3

* refactor test cases

* fix linter and mypy issues

* fix pytest code coverage
  • Loading branch information
ohmayr authored Aug 15, 2024
1 parent 5f46b60 commit 8833ad6
Show file tree
Hide file tree
Showing 7 changed files with 715 additions and 150 deletions.
77 changes: 61 additions & 16 deletions google/auth/_exponential_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import random
import time

Expand All @@ -38,9 +39,8 @@
"""


class ExponentialBackoff:
"""An exponential backoff iterator. This can be used in a for loop to
perform requests with exponential backoff.
class _BaseExponentialBackoff:
"""An exponential backoff iterator base class.
Args:
total_attempts Optional[int]:
Expand Down Expand Up @@ -84,9 +84,40 @@ def __init__(
self._multiplier = multiplier
self._backoff_count = 0

def __iter__(self):
@property
def total_attempts(self):
"""The total amount of backoff attempts that will be made."""
return self._total_attempts

@property
def backoff_count(self):
"""The current amount of backoff attempts that have been made."""
return self._backoff_count

def _reset(self):
self._backoff_count = 0
self._current_wait_in_seconds = self._initial_wait_seconds

def _calculate_jitter(self):
jitter_variance = self._current_wait_in_seconds * self._randomization_factor
jitter = random.uniform(
self._current_wait_in_seconds - jitter_variance,
self._current_wait_in_seconds + jitter_variance,
)

return jitter


class ExponentialBackoff(_BaseExponentialBackoff):
"""An exponential backoff iterator. This can be used in a for loop to
perform requests with exponential backoff.
"""

def __init__(self, *args, **kwargs):
super(ExponentialBackoff, self).__init__(*args, **kwargs)

def __iter__(self):
self._reset()
return self

def __next__(self):
Expand All @@ -97,23 +128,37 @@ def __next__(self):
if self._backoff_count <= 1:
return self._backoff_count

jitter_variance = self._current_wait_in_seconds * self._randomization_factor
jitter = random.uniform(
self._current_wait_in_seconds - jitter_variance,
self._current_wait_in_seconds + jitter_variance,
)
jitter = self._calculate_jitter()

time.sleep(jitter)

self._current_wait_in_seconds *= self._multiplier
return self._backoff_count

@property
def total_attempts(self):
"""The total amount of backoff attempts that will be made."""
return self._total_attempts

@property
def backoff_count(self):
"""The current amount of backoff attempts that have been made."""
class AsyncExponentialBackoff(_BaseExponentialBackoff):
"""An async exponential backoff iterator. This can be used in a for loop to
perform async requests with exponential backoff.
"""

def __init__(self, *args, **kwargs):
super(AsyncExponentialBackoff, self).__init__(*args, **kwargs)

def __aiter__(self):
self._reset()
return self

async def __anext__(self):
if self._backoff_count >= self._total_attempts:
raise StopAsyncIteration
self._backoff_count += 1

if self._backoff_count <= 1:
return self._backoff_count

jitter = self._calculate_jitter()

await asyncio.sleep(jitter)

self._current_wait_in_seconds *= self._multiplier
return self._backoff_count
29 changes: 23 additions & 6 deletions google/auth/aio/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,24 @@
"""

import abc
from typing import AsyncGenerator, Dict, Mapping, Optional
from typing import AsyncGenerator, Mapping, Optional

import google.auth.transport


_DEFAULT_TIMEOUT_SECONDS = 180

DEFAULT_RETRYABLE_STATUS_CODES = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES
"""Sequence[int]: HTTP status codes indicating a request can be retried.
"""

DEFAULT_REFRESH_STATUS_CODES = google.auth.transport.DEFAULT_REFRESH_STATUS_CODES
"""Sequence[int]: Which HTTP status code indicate that credentials should be
refreshed.
"""

DEFAULT_MAX_REFRESH_ATTEMPTS = 3
"""int: How many times to refresh the credentials and retry a request."""


class Response(metaclass=abc.ABCMeta):
Expand All @@ -35,7 +52,7 @@ class Response(metaclass=abc.ABCMeta):
@abc.abstractmethod
def status_code(self) -> int:
"""
The HTTP response status code..
The HTTP response status code.
Returns:
int: The HTTP response status code.
Expand All @@ -45,11 +62,11 @@ def status_code(self) -> int:

@property
@abc.abstractmethod
def headers(self) -> Dict[str, str]:
def headers(self) -> Mapping[str, str]:
"""The HTTP response headers.
Returns:
Dict[str, str]: The HTTP response headers.
Mapping[str, str]: The HTTP response headers.
"""
raise NotImplementedError("headers must be implemented.")

Expand Down Expand Up @@ -95,7 +112,7 @@ async def __call__(
self,
url: str,
method: str,
body: bytes,
body: Optional[bytes],
headers: Optional[Mapping[str, str]],
timeout: float,
**kwargs
Expand All @@ -106,7 +123,7 @@ async def __call__(
url (str): The URI to be requested.
method (str): The HTTP method to use for the request. Defaults
to 'GET'.
body (bytes): The payload / body in HTTP request.
body (Optional[bytes]): The payload / body in HTTP request.
headers (Mapping[str, str]): Request headers.
timeout (float): The number of seconds to wait for a
response from the server. If not specified or if None, the
Expand Down
68 changes: 9 additions & 59 deletions google/auth/aio/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
"""

import asyncio
from contextlib import asynccontextmanager
import time
from typing import AsyncGenerator, Dict, Mapping, Optional
from typing import AsyncGenerator, Mapping, Optional

try:
import aiohttp
import aiohttp # type: ignore
except ImportError as caught_exc: # pragma: NO COVER
raise ImportError(
"The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport."
Expand All @@ -30,54 +28,6 @@
from google.auth import _helpers
from google.auth import exceptions
from google.auth.aio import transport
from google.auth.exceptions import TimeoutError


_DEFAULT_TIMEOUT_SECONDS = 180


@asynccontextmanager
async def timeout_guard(timeout):
"""
timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code.
Args:
timeout (float): The time in seconds before the context manager times out.
Raises:
google.auth.exceptions.TimeoutError: If the code within the context exceeds the provided timeout.
Usage:
async with timeout_guard(10) as with_timeout:
await with_timeout(async_function())
"""
start = time.monotonic()
total_timeout = timeout

def _remaining_time():
elapsed = time.monotonic() - start
remaining = total_timeout - elapsed
if remaining <= 0:
raise TimeoutError(
f"Context manager exceeded the configured timeout of {total_timeout}s."
)
return remaining

async def with_timeout(coro):
try:
remaining = _remaining_time()
response = await asyncio.wait_for(coro, remaining)
return response
except (asyncio.TimeoutError, TimeoutError) as e:
raise TimeoutError(
f"The operation {coro} exceeded the configured timeout of {total_timeout}s."
) from e

try:
yield with_timeout

finally:
_remaining_time()


class Response(transport.Response):
Expand All @@ -89,7 +39,7 @@ class Response(transport.Response):
Attributes:
status_code (int): The HTTP status code of the response.
headers (Dict[str, str]): A case-insensitive multidict proxy wiht HTTP headers of response.
headers (Mapping[str, str]): The HTTP headers of the response.
"""

def __init__(self, response: aiohttp.ClientResponse):
Expand All @@ -102,7 +52,7 @@ def status_code(self) -> int:

@property
@_helpers.copy_docstring(transport.Response)
def headers(self) -> Dict[str, str]:
def headers(self) -> Mapping[str, str]:
return {key: value for key, value in self._response.headers.items()}

@_helpers.copy_docstring(transport.Response)
Expand Down Expand Up @@ -158,7 +108,7 @@ async def __call__(
method: str = "GET",
body: Optional[bytes] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: float = _DEFAULT_TIMEOUT_SECONDS,
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
**kwargs,
) -> transport.Response:
"""
Expand Down Expand Up @@ -199,14 +149,14 @@ async def __call__(
return Response(response)

except aiohttp.ClientError as caught_exc:
new_exc = exceptions.TransportError(f"Failed to send request to {url}.")
raise new_exc from caught_exc
client_exc = exceptions.TransportError(f"Failed to send request to {url}.")
raise client_exc from caught_exc

except asyncio.TimeoutError as caught_exc:
new_exc = exceptions.TimeoutError(
timeout_exc = exceptions.TimeoutError(
f"Request timed out after {timeout} seconds."
)
raise new_exc from caught_exc
raise timeout_exc from caught_exc

async def close(self) -> None:
"""
Expand Down
Loading

0 comments on commit 8833ad6

Please sign in to comment.