diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py index 7ece90122e7a7..919e21c3287ca 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py @@ -121,6 +121,9 @@ def __init__( self.oauth_tokens: dict[str, dict] = {} self.token_timeout_seconds = 10 self.caller = caller + self._metadata_cache: dict[str, Any] = {} + self._metadata_expiry: float = 0 + self._metadata_ttl: int = 300 def my_after_func(retry_state): self._log_request_error(retry_state.attempt_number, retry_state.outcome) @@ -515,43 +518,64 @@ def _is_oauth_token_valid(token: dict, time_key="expires_on") -> bool: return int(token[time_key]) > (int(time.time()) + TOKEN_REFRESH_LEAD_TIME) - @staticmethod - def _check_azure_metadata_service() -> None: + def _check_azure_metadata_service(self) -> None: """ - Check for Azure Metadata Service. + Check for Azure Metadata Service (with caching). https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service """ + if self._metadata_cache and time.time() < self._metadata_expiry: + return try: - jsn = requests.get( - AZURE_METADATA_SERVICE_INSTANCE_URL, - params={"api-version": "2021-02-01"}, - headers={"Metadata": "true"}, - timeout=2, - ).json() - if "compute" not in jsn or "azEnvironment" not in jsn["compute"]: - raise AirflowException( - f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" - ) + for attempt in self._get_retry_object(): + with attempt: + response = requests.get( + AZURE_METADATA_SERVICE_INSTANCE_URL, + params={"api-version": "2021-02-01"}, + headers={"Metadata": "true"}, + timeout=2, + ) + response.raise_for_status() + response_json = response.json() + + self._validate_azure_metadata_service(response_json) + self._metadata_cache = response_json + self._metadata_expiry = time.time() + self._metadata_ttl + break + except RetryError: + raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.") except (requests_exceptions.RequestException, ValueError) as e: - raise AirflowException(f"Can't reach Azure Metadata Service: {e}") + raise ConnectionError(f"Can't reach Azure Metadata Service: {e}") async def _a_check_azure_metadata_service(self): """Async version of `_check_azure_metadata_service()`.""" + if self._metadata_cache and time.time() < self._metadata_expiry: + return try: - async with self._session.get( - url=AZURE_METADATA_SERVICE_INSTANCE_URL, - params={"api-version": "2021-02-01"}, - headers={"Metadata": "true"}, - timeout=2, - ) as resp: - jsn = await resp.json() - if "compute" not in jsn or "azEnvironment" not in jsn["compute"]: - raise AirflowException( - f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" - ) - except (requests_exceptions.RequestException, ValueError) as e: - raise AirflowException(f"Can't reach Azure Metadata Service: {e}") + async for attempt in self._a_get_retry_object(): + with attempt: + async with self._session.get( + url=AZURE_METADATA_SERVICE_INSTANCE_URL, + params={"api-version": "2021-02-01"}, + headers={"Metadata": "true"}, + timeout=2, + ) as resp: + resp.raise_for_status() + response_json = await resp.json() + self._validate_azure_metadata_service(response_json) + self._metadata_cache = response_json + self._metadata_expiry = time.time() + self._metadata_ttl + break + except RetryError: + raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.") + except (aiohttp.ClientError, ValueError) as e: + raise ConnectionError(f"Can't reach Azure Metadata Service: {e}") + + def _validate_azure_metadata_service(self, response_json: dict) -> None: + if "compute" not in response_json or "azEnvironment" not in response_json["compute"]: + raise ValueError( + f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {response_json}" + ) def _get_token(self, raise_error: bool = False) -> str | None: if "token" in self.databricks_conn.extra_dejson: diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py index 2d6df615dca18..34b952f7c841f 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py @@ -26,7 +26,7 @@ from aiohttp.client_exceptions import ClientConnectorError from requests import exceptions as requests_exceptions from requests.auth import HTTPBasicAuth -from tenacity import Future, RetryError +from tenacity import AsyncRetrying, Future, RetryError, retry_if_exception, stop_after_attempt, wait_fixed from airflow.exceptions import AirflowException from airflow.models import Connection @@ -768,3 +768,145 @@ def test_get_error_code_with_http_error_and_valid_error_code(self): exception.response = mock_response hook = BaseDatabricksHook() assert hook._get_error_code(exception) == "INVALID_REQUEST" + + @mock.patch("requests.get") + @time_machine.travel("2025-07-12 12:00:00") + def test_check_azure_metadata_service_normal(self, mock_get): + travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp()) + hook = BaseDatabricksHook() + mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}} + mock_get.return_value.json.return_value = mock_response + + hook._check_azure_metadata_service() + + assert hook._metadata_cache == mock_response + assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl + + @mock.patch("requests.get") + @time_machine.travel("2025-07-12 12:00:00") + def test_check_azure_metadata_service_cached(self, mock_get): + travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp()) + hook = BaseDatabricksHook() + mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}} + hook._metadata_cache = mock_response + hook._metadata_expiry = travel_time + 1000 + + hook._check_azure_metadata_service() + mock_get.assert_not_called() + + @mock.patch("requests.get") + def test_check_azure_metadata_service_http_error(self, mock_get): + hook = BaseDatabricksHook() + mock_get.side_effect = requests_exceptions.RequestException("Fail") + + with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"): + hook._check_azure_metadata_service() + assert hook._metadata_cache == {} + assert hook._metadata_expiry == 0 + + @mock.patch("requests.get") + def test_check_azure_metadata_service_retry_error(self, mock_get): + hook = BaseDatabricksHook() + + resp_429 = mock.Mock() + resp_429.status_code = 429 + resp_429.content = b"Too many requests" + http_error = requests_exceptions.HTTPError(response=resp_429) + mock_get.side_effect = http_error + + with pytest.raises(ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries."): + hook._check_azure_metadata_service() + assert mock_get.call_count == 3 + + @pytest.mark.asyncio + @mock.patch("aiohttp.ClientSession.get") + async def test_a_check_azure_metadata_service_normal(self, mock_get): + hook = BaseDatabricksHook() + + async_mock = mock.AsyncMock() + async_mock.__aenter__.return_value = async_mock + async_mock.__aexit__.return_value = None + async_mock.json.return_value = {"compute": {"azEnvironment": "AzurePublicCloud"}} + + mock_get.return_value = async_mock + + async with aiohttp.ClientSession() as session: + hook._session = session + mock_attempt = mock.Mock() + mock_attempt.__enter__ = mock.Mock(return_value=None) + mock_attempt.__exit__ = mock.Mock(return_value=None) + + async def mock_retry_generator(): + yield mock_attempt + + hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator()) + await hook._a_check_azure_metadata_service() + + assert hook._metadata_cache["compute"]["azEnvironment"] == "AzurePublicCloud" + assert hook._metadata_expiry > 0 + + @pytest.mark.asyncio + @mock.patch("aiohttp.ClientSession.get") + @time_machine.travel("2025-07-12 12:00:00") + async def test_a_check_azure_metadata_service_cached(self, mock_get): + travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp()) + hook = BaseDatabricksHook() + hook._metadata_cache = {"compute": {"azEnvironment": "AzurePublicCloud"}} + hook._metadata_expiry = travel_time + 1000 + + async with aiohttp.ClientSession() as session: + hook._session = session + await hook._a_check_azure_metadata_service() + mock_get.assert_not_called() + + @pytest.mark.asyncio + @mock.patch("aiohttp.ClientSession.get") + async def test_a_check_azure_metadata_service_http_error(self, mock_get): + hook = BaseDatabricksHook() + + async_mock = mock.AsyncMock() + async_mock.__aenter__.side_effect = aiohttp.ClientError("Fail") + async_mock.__aexit__.return_value = None + mock_get.return_value = async_mock + + async with aiohttp.ClientSession() as session: + hook._session = session + mock_attempt = mock.Mock() + mock_attempt.__enter__ = mock.Mock(return_value=None) + mock_attempt.__exit__ = mock.Mock(return_value=None) + + async def mock_retry_generator(): + yield mock_attempt + + hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator()) + + with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"): + await hook._a_check_azure_metadata_service() + assert hook._metadata_cache == {} + assert hook._metadata_expiry == 0 + + @pytest.mark.asyncio + @mock.patch("aiohttp.ClientSession.get") + async def test_a_check_azure_metadata_service_retry_error(self, mock_get): + hook = BaseDatabricksHook() + + mock_get.side_effect = aiohttp.ClientResponseError( + request_info=mock.Mock(), history=(), status=429, message="429 Too Many Requests" + ) + + async with aiohttp.ClientSession() as session: + hook._session = session + + hook._a_get_retry_object = lambda: AsyncRetrying( + stop=stop_after_attempt(hook.retry_limit), + wait=wait_fixed(0), + retry=retry_if_exception(hook._retryable_error), + ) + + hook._validate_azure_metadata_service = mock.Mock() + + with pytest.raises( + ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries." + ): + await hook._a_check_azure_metadata_service() + assert mock_get.call_count == 3