diff --git a/requirements/app/base.txt b/requirements/app/base.txt index e9e1d284e1434..80607f79a91a2 100644 --- a/requirements/app/base.txt +++ b/requirements/app/base.txt @@ -13,6 +13,7 @@ inquirer >=2.10.0, <=3.1.3 psutil <5.9.5 click <=8.1.3 python-multipart>=0.0.5, <=0.0.6 +backoff >=2.2.1, <2.3.0 fastapi >=0.92.0, <0.100.0 starlette # https://fastapi.tiangolo.com/deployment/versions/#about-starlette diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py index 5b27b601fb5e5..18e02dd989b2e 100644 --- a/src/lightning/app/core/queues.py +++ b/src/lightning/app/core/queues.py @@ -23,6 +23,7 @@ from typing import Any, Optional, Tuple from urllib.parse import urljoin +import backoff import requests from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout @@ -431,6 +432,7 @@ def _get(self) -> Any: # we consider the queue is empty to avoid failing the app. raise queue.Empty + @backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError)) def put(self, item: Any) -> None: if not self.app_id: raise ValueError(f"The Lightning App ID couldn't be extracted from the queue name: {self.name}") diff --git a/tests/tests_app/core/test_queues.py b/tests/tests_app/core/test_queues.py index d00432b734aad..8dd6d7d3a0b32 100644 --- a/tests/tests_app/core/test_queues.py +++ b/tests/tests_app/core/test_queues.py @@ -217,13 +217,21 @@ def test_http_queue_get(self, monkeypatch): def test_unreachable_queue(monkeypatch): monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token") + test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue") - resp = mock.MagicMock() - resp.status_code = 204 + resp1 = mock.MagicMock() + resp1.status_code = 204 + + resp2 = mock.MagicMock() + resp2.status_code = 201 test_queue.client = mock.MagicMock() - test_queue.client.post.return_value = resp + test_queue.client.post = mock.Mock(side_effect=[resp1, resp1, resp2]) with pytest.raises(queue.Empty): test_queue._get() + + # Test backoff on queue.put + test_queue.put("foo") + assert test_queue.client.post.call_count == 3