diff --git a/README.md b/README.md index 8834edc..219bdbf 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,8 @@ class RetryOptionsBase: statuses: Optional[Iterable[int]] = None, # On which statuses we should retry exceptions: Optional[Iterable[Type[Exception]]] = None, # On which exceptions we should retry retry_all_server_errors: bool = True, # If should retry all 500 errors or not + # a callback that will run on response to decide if retry + evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None, ): ... @@ -149,6 +151,9 @@ You can define your own timeouts logic or use: However this response can be None, server didn't make a response or you have set up ```raise_for_status=True``` Look here for an example: https://github.com/inyutin/aiohttp_retry/issues/59 +Additionally, you can specify ```evaluate_response_callback```. It receive a ```ClientResponse``` and decide to retry or not by returning a bool. +It can be useful, if server API sometimes response with malformed data. + #### Request Trace Context `RetryClient` add *current attempt number* to `request_trace_ctx` (see examples, for more info see [aiohttp doc](https://docs.aiohttp.org/en/stable/client_advanced.html#aiohttp-client-tracing)). diff --git a/aiohttp_retry/client.py b/aiohttp_retry/client.py index c3edb29..3be9e98 100644 --- a/aiohttp_retry/client.py +++ b/aiohttp_retry/client.py @@ -27,6 +27,9 @@ def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: pass @abstractmethod def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: pass + @abstractmethod + def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: pass + # url itself or list of urls for changing between retries _RAW_URL_TYPE = Union[StrOrURL, YARL_URL] @@ -81,10 +84,23 @@ async def _do_request(self) -> ClientResponse: if self._is_status_code_ok(response.status) or current_attempt == self._retry_options.attempts: if self._raise_for_status: response.raise_for_status() - self._response = response - return response - self._logger.debug(f"Retrying after response code: {response.status}") + if self._retry_options.evaluate_response_callback is not None: + try: + is_response_correct = await self._retry_options.evaluate_response_callback(response) + except Exception: + self._logger.exception('while evaluating response an exception occurred') + is_response_correct = False + else: + is_response_correct = True + + if is_response_correct: + self._response = response + return response + else: + self._logger.debug(f"Retrying after evaluate response callback check") + else: + self._logger.debug(f"Retrying after response code: {response.status}") retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=response) except Exception as e: if current_attempt >= self._retry_options.attempts: diff --git a/aiohttp_retry/retry_options.py b/aiohttp_retry/retry_options.py index eefdf68..d5a2506 100644 --- a/aiohttp_retry/retry_options.py +++ b/aiohttp_retry/retry_options.py @@ -1,10 +1,12 @@ import abc import random -from typing import Any, Callable, Iterable, List, Optional, Set, Type +from typing import Any, Awaitable, Callable, Iterable, List, Optional, Set, Type from warnings import warn from aiohttp import ClientResponse +EvaluateResponseCallbackType = Callable[[ClientResponse], Awaitable[bool]] + class RetryOptionsBase: def __init__( @@ -13,6 +15,8 @@ def __init__( statuses: Optional[Iterable[int]] = None, # On which statuses we should retry exceptions: Optional[Iterable[Type[Exception]]] = None, # On which exceptions we should retry retry_all_server_errors: bool = True, # If should retry all 500 errors or not + # a callback that will run on response to decide if retry + evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None, ): self.attempts: int = attempts if statuses is None: @@ -24,6 +28,7 @@ def __init__( self.exceptions: Iterable[Type[Exception]] = exceptions self.retry_all_server_errors = retry_all_server_errors + self.evaluate_response_callback = evaluate_response_callback @abc.abstractmethod def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float: @@ -40,8 +45,15 @@ def __init__( statuses: Optional[Set[int]] = None, # On which statuses we should retry exceptions: Optional[Set[Type[Exception]]] = None, # On which exceptions we should retry retry_all_server_errors: bool = True, + evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None, ): - super().__init__(attempts, statuses, exceptions, retry_all_server_errors) + super().__init__( + attempts=attempts, + statuses=statuses, + exceptions=exceptions, + retry_all_server_errors=retry_all_server_errors, + evaluate_response_callback=evaluate_response_callback, + ) self._start_timeout: float = start_timeout self._max_timeout: float = max_timeout @@ -68,8 +80,16 @@ def __init__( max_timeout: float = 3.0, # Maximum possible timeout between tries random_func: Callable[[], float] = random.random, # Random number generator retry_all_server_errors: bool = True, + evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None, ): - super().__init__(attempts, statuses, exceptions, retry_all_server_errors) + super().__init__( + attempts=attempts, + statuses=statuses, + exceptions=exceptions, + retry_all_server_errors=retry_all_server_errors, + evaluate_response_callback=evaluate_response_callback, + ) + self.attempts: int = attempts self.min_timeout: float = min_timeout self.max_timeout: float = max_timeout @@ -105,8 +125,15 @@ def __init__( exceptions: Optional[Iterable[Type[Exception]]] = None, max_timeout: float = 3.0, # Maximum possible timeout between tries retry_all_server_errors: bool = True, + evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None, ): - super().__init__(attempts, statuses, exceptions, retry_all_server_errors) + super().__init__( + attempts=attempts, + statuses=statuses, + exceptions=exceptions, + retry_all_server_errors=retry_all_server_errors, + evaluate_response_callback=evaluate_response_callback, + ) self.max_timeout = max_timeout self.multiplier = multiplier @@ -134,6 +161,7 @@ def __init__( exceptions: Optional[Set[Type[Exception]]] = None, # On which exceptions we should retry random_interval_size: float = 2.0, # size of interval for random component retry_all_server_errors: bool = True, + evaluate_response_callback: Optional[EvaluateResponseCallbackType] = None, ): super().__init__( attempts=attempts, @@ -143,6 +171,7 @@ def __init__( statuses=statuses, exceptions=exceptions, retry_all_server_errors=retry_all_server_errors, + evaluate_response_callback=evaluate_response_callback, ) self._start_timeout: float = start_timeout diff --git a/tests/app.py b/tests/app.py index 3569a09..8653d43 100644 --- a/tests/app.py +++ b/tests/app.py @@ -10,6 +10,7 @@ def __init__(self): app.router.add_get('/internal_error', self.internal_error_handler) app.router.add_get('/not_found_error', self.not_found_error_handler) app.router.add_get('/sometimes_error', self.sometimes_error) + app.router.add_get('/sometimes_json', self.sometimes_json) app.router.add_options('/options_handler', self.ping_handler) app.router.add_head('/head_handler', self.ping_handler) @@ -39,6 +40,13 @@ async def sometimes_error(self, _: web.Request) -> web.Response: raise web.HTTPInternalServerError() + async def sometimes_json(self, _: web.Request) -> web.Response: + self.counter += 1 + if self.counter == 3: + return web.json_response(data={'status': 'Ok!'}, status=200) + + return web.Response(text='Ok!', status=200) + @property def web_app(self) -> web.Application: return self._web_app diff --git a/tests/test_client.py b/tests/test_client.py index c55249e..75b6849 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,6 +3,7 @@ import pytest from aiohttp import ( + ClientResponse, ClientResponseError, ClientSession, TraceConfig, @@ -344,3 +345,22 @@ async def test_implicit_client(aiohttp_client): assert response.status == 200 await retry_client.close() + + +async def test_evaluate_response_callback(aiohttp_client): + async def evaluate_response(response: ClientResponse) -> bool: + try: + await response.json() + except: + return False + return True + + retry_options = ExponentialRetry(attempts=5, evaluate_response_callback=evaluate_response) + retry_client, test_app = await get_retry_client_and_test_app_for_test(aiohttp_client, retry_options=retry_options) + + async with retry_client.get('/sometimes_json') as response: + body = await response.json() + assert response.status == 200 + assert body == {'status': 'Ok!'} + + assert test_app.counter == 3