Skip to content

Commit

Permalink
Add evaluate response callback
Browse files Browse the repository at this point in the history
  • Loading branch information
inyutin committed Aug 6, 2022
1 parent 4818efc commit 5f51f6a
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 7 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class RetryOptionsBase:
statuses: Optional[Iterable[int]] = None, # On which statuses we should retry
exceptions: Optional[Iterable[Type[Exception]]] = None, # On which exceptions we should retry
retry_all_server_errors: bool = True, # If should retry all 500 errors or not
# a callback that will run on response to decide if retry
evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None,
):
...

Expand All @@ -149,6 +151,9 @@ You can define your own timeouts logic or use:
However this response can be None, server didn't make a response or you have set up ```raise_for_status=True```
Look here for an example: https://github.com/inyutin/aiohttp_retry/issues/59

Additionally, you can specify ```evaluate_response_callback```. It receive a ```ClientResponse``` and decide to retry or not by returning a bool.
It can be useful, if server API sometimes response with malformed data.

#### Request Trace Context
`RetryClient` add *current attempt number* to `request_trace_ctx` (see examples,
for more info see [aiohttp doc](https://docs.aiohttp.org/en/stable/client_advanced.html#aiohttp-client-tracing)).
Expand Down
22 changes: 19 additions & 3 deletions aiohttp_retry/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: pass
@abstractmethod
def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: pass

@abstractmethod
def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: pass


# url itself or list of urls for changing between retries
_RAW_URL_TYPE = Union[StrOrURL, YARL_URL]
Expand Down Expand Up @@ -81,10 +84,23 @@ async def _do_request(self) -> ClientResponse:
if self._is_status_code_ok(response.status) or current_attempt == self._retry_options.attempts:
if self._raise_for_status:
response.raise_for_status()
self._response = response
return response

self._logger.debug(f"Retrying after response code: {response.status}")
if self._retry_options.evaluate_response_callback is not None:
try:
is_response_correct = await self._retry_options.evaluate_response_callback(response)
except Exception:
self._logger.exception('while evaluating response an exception occurred')
is_response_correct = False
else:
is_response_correct = True

if is_response_correct:
self._response = response
return response
else:
self._logger.debug(f"Retrying after evaluate response callback check")
else:
self._logger.debug(f"Retrying after response code: {response.status}")
retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=response)
except Exception as e:
if current_attempt >= self._retry_options.attempts:
Expand Down
37 changes: 33 additions & 4 deletions aiohttp_retry/retry_options.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import abc
import random
from typing import Any, Callable, Iterable, List, Optional, Set, Type
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Set, Type
from warnings import warn

from aiohttp import ClientResponse

EvaluateResponseCallbackType = Callable[[ClientResponse], Awaitable[bool]]


class RetryOptionsBase:
def __init__(
Expand All @@ -13,6 +15,8 @@ def __init__(
statuses: Optional[Iterable[int]] = None, # On which statuses we should retry
exceptions: Optional[Iterable[Type[Exception]]] = None, # On which exceptions we should retry
retry_all_server_errors: bool = True, # If should retry all 500 errors or not
# a callback that will run on response to decide if retry
evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None,
):
self.attempts: int = attempts
if statuses is None:
Expand All @@ -24,6 +28,7 @@ def __init__(
self.exceptions: Iterable[Type[Exception]] = exceptions

self.retry_all_server_errors = retry_all_server_errors
self.evaluate_response_callback = evaluate_response_callback

@abc.abstractmethod
def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float:
Expand All @@ -40,8 +45,15 @@ def __init__(
statuses: Optional[Set[int]] = None, # On which statuses we should retry
exceptions: Optional[Set[Type[Exception]]] = None, # On which exceptions we should retry
retry_all_server_errors: bool = True,
evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None,
):
super().__init__(attempts, statuses, exceptions, retry_all_server_errors)
super().__init__(
attempts=attempts,
statuses=statuses,
exceptions=exceptions,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)

self._start_timeout: float = start_timeout
self._max_timeout: float = max_timeout
Expand All @@ -68,8 +80,16 @@ def __init__(
max_timeout: float = 3.0, # Maximum possible timeout between tries
random_func: Callable[[], float] = random.random, # Random number generator
retry_all_server_errors: bool = True,
evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None,
):
super().__init__(attempts, statuses, exceptions, retry_all_server_errors)
super().__init__(
attempts=attempts,
statuses=statuses,
exceptions=exceptions,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)

self.attempts: int = attempts
self.min_timeout: float = min_timeout
self.max_timeout: float = max_timeout
Expand Down Expand Up @@ -105,8 +125,15 @@ def __init__(
exceptions: Optional[Iterable[Type[Exception]]] = None,
max_timeout: float = 3.0, # Maximum possible timeout between tries
retry_all_server_errors: bool = True,
evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None,
):
super().__init__(attempts, statuses, exceptions, retry_all_server_errors)
super().__init__(
attempts=attempts,
statuses=statuses,
exceptions=exceptions,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)

self.max_timeout = max_timeout
self.multiplier = multiplier
Expand Down Expand Up @@ -134,6 +161,7 @@ def __init__(
exceptions: Optional[Set[Type[Exception]]] = None, # On which exceptions we should retry
random_interval_size: float = 2.0, # size of interval for random component
retry_all_server_errors: bool = True,
evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None,
):
super().__init__(
attempts=attempts,
Expand All @@ -143,6 +171,7 @@ def __init__(
statuses=statuses,
exceptions=exceptions,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)

self._start_timeout: float = start_timeout
Expand Down
8 changes: 8 additions & 0 deletions tests/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(self):
app.router.add_get('/internal_error', self.internal_error_handler)
app.router.add_get('/not_found_error', self.not_found_error_handler)
app.router.add_get('/sometimes_error', self.sometimes_error)
app.router.add_get('/sometimes_json', self.sometimes_json)

app.router.add_options('/options_handler', self.ping_handler)
app.router.add_head('/head_handler', self.ping_handler)
Expand Down Expand Up @@ -39,6 +40,13 @@ async def sometimes_error(self, _: web.Request) -> web.Response:

raise web.HTTPInternalServerError()

async def sometimes_json(self, _: web.Request) -> web.Response:
self.counter += 1
if self.counter == 3:
return web.json_response(data={'status': 'Ok!'}, status=200)

return web.Response(text='Ok!', status=200)

@property
def web_app(self) -> web.Application:
return self._web_app
20 changes: 20 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from aiohttp import (
ClientResponse,
ClientResponseError,
ClientSession,
TraceConfig,
Expand Down Expand Up @@ -344,3 +345,22 @@ async def test_implicit_client(aiohttp_client):
assert response.status == 200

await retry_client.close()


async def test_evaluate_response_callback(aiohttp_client):
async def evaluate_response(response: ClientResponse) -> bool:
try:
await response.json()
except:
return False
return True

retry_options = ExponentialRetry(attempts=5, evaluate_response_callback=evaluate_response)
retry_client, test_app = await get_retry_client_and_test_app_for_test(aiohttp_client, retry_options=retry_options)

async with retry_client.get('/sometimes_json') as response:
body = await response.json()
assert response.status == 200
assert body == {'status': 'Ok!'}

assert test_app.counter == 3

0 comments on commit 5f51f6a

Please sign in to comment.