Skip to content

Commit

Permalink
Add ClientConnectorError to be a retryable error in databricks provid…
Browse files Browse the repository at this point in the history
…er (apache#43091)

closes apache#43080
  • Loading branch information
rawwar authored and PaulKobow7536 committed Oct 24, 2024
1 parent 6ae788c commit 6984b31
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import aiohttp
import requests
from aiohttp.client_exceptions import ClientConnectorError
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import JSONDecodeError
Expand Down Expand Up @@ -678,6 +679,9 @@ def _retryable_error(exception: BaseException) -> bool:
if exception.status >= 500 or exception.status == 429:
return True

if isinstance(exception, ClientConnectorError):
return True

return False


Expand Down
17 changes: 17 additions & 0 deletions providers/tests/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

import itertools
import json
import ssl
import time
from unittest import mock
from unittest.mock import AsyncMock

import aiohttp
import aiohttp.client_exceptions
import azure.identity
import azure.identity.aio
import pytest
Expand Down Expand Up @@ -1534,6 +1536,21 @@ async def test_init_async_session(self):
assert isinstance(self.hook._session, aiohttp.ClientSession)
assert self.hook._session is None

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_do_api_call_retries_with_client_connector_error(self, mock_get):
mock_get.side_effect = aiohttp.ClientConnectorError(
connection_key=None,
os_error=ssl.SSLError(
"SSL handshake is taking longer than 60.0 seconds: aborting the connection"
),
)
with mock.patch.object(self.hook.log, "error") as mock_errors:
async with self.hook:
with pytest.raises(AirflowException):
await self.hook._a_do_api_call(GET_RUN_ENDPOINT, {})
assert mock_errors.call_count == DEFAULT_RETRY_NUMBER

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_do_api_call_retries_with_retryable_error(self, mock_get):
Expand Down

0 comments on commit 6984b31

Please sign in to comment.