From e49345c4272e13d52b83965c51fe7b73f37201bf Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 8 Aug 2023 21:13:03 +0100 Subject: [PATCH] [App] Client retries forever (#18065) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> (cherry picked from commit 176e4568142ae97eb042568ce2f781325415d1b5) --- src/lightning/app/plugin/plugin.py | 4 ++ src/lightning/app/runners/backends/cloud.py | 10 ++++- src/lightning/app/utilities/network.py | 42 ++++++++++++--------- tests/tests_app/plugin/test_plugin.py | 4 +- tests/tests_app/utilities/test_network.py | 18 ++++++++- 5 files changed, 56 insertions(+), 22 deletions(-) diff --git a/src/lightning/app/plugin/plugin.py b/src/lightning/app/plugin/plugin.py index 5abffbaeb7352..db66a4ad245ba 100644 --- a/src/lightning/app/plugin/plugin.py +++ b/src/lightning/app/plugin/plugin.py @@ -33,6 +33,8 @@ logger = Logger(__name__) +_PLUGIN_MAX_CLIENT_TRIES: int = 3 + class LightningPlugin: """A ``LightningPlugin`` is a single-file Python class that can be executed within a cloudspace to perform @@ -59,6 +61,7 @@ def run_job(self, name: str, app_entrypoint: str, env_vars: Dict[str, str] = {}) Returns: The relative URL of the created job. """ + from lightning.app.runners.backends.cloud import CloudBackend from lightning.app.runners.cloud import CloudRuntime logger.info(f"Processing job run request. 
name: {name}, app_entrypoint: {app_entrypoint}, env_vars: {env_vars}") @@ -79,6 +82,7 @@ def run_job(self, name: str, app_entrypoint: str, env_vars: Dict[str, str] = {}) env_vars=env_vars, secrets={}, run_app_comment_commands=True, + backend=CloudBackend(entrypoint_file, client_max_tries=_PLUGIN_MAX_CLIENT_TRIES), ) # Used to indicate Lightning has been dispatched os.environ["LIGHTNING_DISPATCHED"] = "1" diff --git a/src/lightning/app/runners/backends/cloud.py b/src/lightning/app/runners/backends/cloud.py index 186492aa10c4a..ebf45a3f27251 100644 --- a/src/lightning/app/runners/backends/cloud.py +++ b/src/lightning/app/runners/backends/cloud.py @@ -23,9 +23,15 @@ class CloudBackend(Backend): - def __init__(self, entrypoint_file, queue_id: Optional[str] = None, status_update_interval: Optional[int] = None): + def __init__( + self, + entrypoint_file, + queue_id: Optional[str] = None, + status_update_interval: Optional[int] = None, + client_max_tries: Optional[int] = None, + ): super().__init__(entrypoint_file, queues=QueuingSystem.MULTIPROCESS, queue_id=queue_id) - self.client = LightningClient() + self.client = LightningClient(max_tries=client_max_tries) def create_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None: raise NotImplementedError diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py index 6ff4871e7c6cd..3d80479c3fe53 100644 --- a/src/lightning/app/utilities/network.py +++ b/src/lightning/app/utilities/network.py @@ -123,7 +123,7 @@ def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> floa return min(_DEFAULT_BACKOFF_MAX, next_backoff_value) -def _retry_wrapper(self, func: Callable) -> Callable: +def _retry_wrapper(self, func: Callable, max_tries: Optional[int] = None) -> Callable: """Returns the function decorated by a wrapper that retries the call several times if a connection error occurs. 
@@ -133,31 +133,36 @@ def _retry_wrapper(self, func: Callable) -> Callable: @wraps(func) def wrapped(*args: Any, **kwargs: Any) -> Any: consecutive_errors = 0 - while _get_next_backoff_time(consecutive_errors) != _DEFAULT_BACKOFF_MAX: + + while True: try: return func(self, *args, **kwargs) - except lightning_cloud.openapi.rest.ApiException as ex: - # retry if the control plane fails with all errors except 4xx but not 408 - (Request Timeout) - if ex.status == 408 or ex.status == 409 or not str(ex.status).startswith("4"): + except (lightning_cloud.openapi.rest.ApiException, urllib3.exceptions.HTTPError) as ex: + # retry if the backend fails with all errors except 4xx but not 408 - (Request Timeout) + if ( + isinstance(ex, urllib3.exceptions.HTTPError) + or ex.status in (408, 409) + or not str(ex.status).startswith("4") + ): consecutive_errors += 1 backoff_time = _get_next_backoff_time(consecutive_errors) + + msg = ( + f"error: {str(ex)}" + if isinstance(ex, urllib3.exceptions.HTTPError) + else f"response: {ex.status}" + ) logger.debug( - f"The {func.__name__} request failed to reach the server, got a response {ex.status}." + f"The {func.__name__} request failed to reach the server, {msg}." f" Retrying after {backoff_time} seconds." ) + + if max_tries is not None and consecutive_errors == max_tries: + raise Exception(f"The {func.__name__} request failed to reach the server, {msg}.") + time.sleep(backoff_time) else: raise ex - except urllib3.exceptions.HTTPError as ex: - consecutive_errors += 1 - backoff_time = _get_next_backoff_time(consecutive_errors) - logger.debug( - f"The {func.__name__} request failed to reach the server, got a an error {str(ex)}." - f" Retrying after {backoff_time} seconds." 
- ) - time.sleep(backoff_time) - - raise Exception(f"The default maximum backoff {_DEFAULT_BACKOFF_MAX} seconds has been reached.") return wrapped @@ -169,15 +174,16 @@ class LightningClient(GridRestClient): Args: retry: Whether API calls should follow a retry mechanism with exponential backoff. + max_tries: Maximum number of attempts. ``None`` (the default) retries forever. """ - def __init__(self, retry: bool = True) -> None: + def __init__(self, retry: bool = True, max_tries: Optional[int] = None) -> None: super().__init__(api_client=create_swagger_client()) if retry: for base_class in GridRestClient.__mro__: for name, attribute in base_class.__dict__.items(): if callable(attribute) and attribute.__name__ != "__init__": - setattr(self, name, _retry_wrapper(self, attribute)) + setattr(self, name, _retry_wrapper(self, attribute, max_tries=max_tries)) class CustomRetryAdapter(HTTPAdapter): diff --git a/tests/tests_app/plugin/test_plugin.py b/tests/tests_app/plugin/test_plugin.py index 6eb16e15b44d7..eaa75a60dc562 100644 --- a/tests/tests_app/plugin/test_plugin.py +++ b/tests/tests_app/plugin/test_plugin.py @@ -192,9 +192,10 @@ def run(self, name, entrypoint): (_plugin_with_job_run_navigate, [{"content": "/testing", "type": "NAVIGATE_TO"}]), ], ) +@mock.patch("lightning.app.runners.backends.cloud.CloudBackend") @mock.patch("lightning.app.runners.cloud.CloudRuntime") @mock.patch("lightning.app.plugin.plugin.requests") -def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server, plugin_source, actions): +def test_run_job(mock_requests, mock_cloud_runtime, mock_cloud_backend, mock_plugin_server, plugin_source, actions): """Tests that running a job from a plugin calls the correct `CloudRuntime` methods with the correct arguments.""" content = as_tar_bytes("plugin.py", plugin_source) @@ -228,6 +229,7 @@ def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server, plugin_s env_vars={}, secrets={}, run_app_comment_commands=True, + backend=mock.ANY, ) 
mock_cloud_runtime().cloudspace_dispatch.assert_called_once_with( diff --git a/tests/tests_app/utilities/test_network.py b/tests/tests_app/utilities/test_network.py index 8c25e8305f2f2..30e5e12c7765e 100644 --- a/tests/tests_app/utilities/test_network.py +++ b/tests/tests_app/utilities/test_network.py @@ -1,9 +1,11 @@ +import re from unittest import mock import pytest +from urllib3.exceptions import HTTPError from lightning.app.core import constants -from lightning.app.utilities.network import find_free_network_port, LightningClient +from lightning.app.utilities.network import _retry_wrapper, find_free_network_port, LightningClient def test_find_free_network_port(): @@ -54,3 +56,17 @@ def test_lightning_client_retry_enabled(): client = LightningClient(retry=True) assert hasattr(client.auth_service_get_user_with_http_info, "__wrapped__") + + +@mock.patch("time.sleep") +def test_retry_wrapper_max_tries(_): + mock_client = mock.MagicMock() + mock_client.test.__name__ = "test" + mock_client.test.side_effect = HTTPError("failed") + + wrapped_mock_client = _retry_wrapper(mock_client, mock_client.test, max_tries=3) + + with pytest.raises(Exception, match=re.escape("The test request failed to reach the server, error: failed")): + wrapped_mock_client() + + assert mock_client.test.call_count == 3