From 3ab9e1d5aca236986aa10f95e3bbd7ffb9be0e25 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Fri, 18 Oct 2024 07:38:08 +0530 Subject: [PATCH 1/2] include timeout error to be retryable --- .../providers/databricks/hooks/databricks_base.py | 4 ++-- providers/tests/databricks/hooks/test_databricks.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/src/airflow/providers/databricks/hooks/databricks_base.py index 08a6eb8d40cbf..74498ecb233a4 100644 --- a/providers/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/src/airflow/providers/databricks/hooks/databricks_base.py @@ -34,7 +34,7 @@ import aiohttp import requests -from aiohttp.client_exceptions import ClientConnectorError +from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError from requests import PreparedRequest, exceptions as requests_exceptions from requests.auth import AuthBase, HTTPBasicAuth from requests.exceptions import JSONDecodeError @@ -679,7 +679,7 @@ def _retryable_error(exception: BaseException) -> bool: if exception.status >= 500 or exception.status == 429: return True - if isinstance(exception, ClientConnectorError): + if isinstance(exception, (ClientConnectorError, ServerTimeoutError)): return True return False diff --git a/providers/tests/databricks/hooks/test_databricks.py b/providers/tests/databricks/hooks/test_databricks.py index bec238f70e39c..323f5730c33e6 100644 --- a/providers/tests/databricks/hooks/test_databricks.py +++ b/providers/tests/databricks/hooks/test_databricks.py @@ -1551,6 +1551,16 @@ async def test_do_api_call_retries_with_client_connector_error(self, mock_get): await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {}) assert mock_errors.call_count == DEFAULT_RETRY_NUMBER + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") + async def test_do_api_call_retries_with_client_timeout_error(self, mock_get): + mock_get.side_effect = aiohttp.ServerTimeoutError() + with mock.patch.object(self.hook.log, "error") as mock_errors: + async with self.hook: + with pytest.raises(AirflowException): + await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {}) + assert mock_errors.call_count == DEFAULT_RETRY_NUMBER + @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") async def test_do_api_call_retries_with_retryable_error(self, mock_get): From 4d31871be67bf9892b6e70539bf4ce66e85bfffb Mon Sep 17 00:00:00 2001 From: kalyanr Date: Fri, 18 Oct 2024 12:18:29 +0530 Subject: [PATCH 2/2] use TimeoutError --- .../airflow/providers/databricks/hooks/databricks_base.py | 5 +++-- providers/tests/databricks/hooks/test_databricks.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/providers/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/src/airflow/providers/databricks/hooks/databricks_base.py index 74498ecb233a4..1a4ccb6e980b3 100644 --- a/providers/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/src/airflow/providers/databricks/hooks/databricks_base.py @@ -28,13 +28,14 @@ import copy import platform import time +from asyncio.exceptions import TimeoutError from functools import cached_property from typing import TYPE_CHECKING, Any from urllib.parse import urlsplit import aiohttp import requests -from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError +from aiohttp.client_exceptions import ClientConnectorError from requests import PreparedRequest, exceptions as requests_exceptions from requests.auth import AuthBase, HTTPBasicAuth from requests.exceptions import JSONDecodeError @@ -679,7 +680,7 @@ def _retryable_error(exception: BaseException) -> bool: if exception.status >= 500 or exception.status == 429: return True - if isinstance(exception, (ClientConnectorError, ServerTimeoutError)): + if isinstance(exception, (ClientConnectorError, TimeoutError)): return True return False diff --git a/providers/tests/databricks/hooks/test_databricks.py b/providers/tests/databricks/hooks/test_databricks.py index 323f5730c33e6..94a8cfb9c4c56 100644 --- a/providers/tests/databricks/hooks/test_databricks.py +++ b/providers/tests/databricks/hooks/test_databricks.py @@ -21,6 +21,7 @@ import json import ssl import time +from asyncio.exceptions import TimeoutError from unittest import mock from unittest.mock import AsyncMock @@ -1554,7 +1555,7 @@ async def test_do_api_call_retries_with_client_connector_error(self, mock_get): @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") async def test_do_api_call_retries_with_client_timeout_error(self, mock_get): - mock_get.side_effect = aiohttp.ServerTimeoutError() + mock_get.side_effect = TimeoutError() with mock.patch.object(self.hook.log, "error") as mock_errors: async with self.hook: with pytest.raises(AirflowException):