Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def __init__(
self.oauth_tokens: dict[str, dict] = {}
self.token_timeout_seconds = 10
self.caller = caller
self._metadata_cache: dict[str, Any] = {}
self._metadata_expiry: float = 0
self._metadata_ttl: int = 300

def my_after_func(retry_state):
self._log_request_error(retry_state.attempt_number, retry_state.outcome)
Expand Down Expand Up @@ -515,43 +518,64 @@ def _is_oauth_token_valid(token: dict, time_key="expires_on") -> bool:

return int(token[time_key]) > (int(time.time()) + TOKEN_REFRESH_LEAD_TIME)

@staticmethod
def _check_azure_metadata_service() -> None:
def _check_azure_metadata_service(self) -> None:
"""
Check for Azure Metadata Service.
Check for Azure Metadata Service (with caching).

https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
"""
if self._metadata_cache and time.time() < self._metadata_expiry:
return
try:
jsn = requests.get(
AZURE_METADATA_SERVICE_INSTANCE_URL,
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
timeout=2,
).json()
if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
raise AirflowException(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
)
for attempt in self._get_retry_object():
with attempt:
response = requests.get(
AZURE_METADATA_SERVICE_INSTANCE_URL,
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
timeout=2,
)
response.raise_for_status()
response_json = response.json()

self._validate_azure_metadata_service(response_json)
self._metadata_cache = response_json
self._metadata_expiry = time.time() + self._metadata_ttl
break
except RetryError:
raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.")
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
raise ConnectionError(f"Can't reach Azure Metadata Service: {e}")

async def _a_check_azure_metadata_service(self):
"""Async version of `_check_azure_metadata_service()`."""
if self._metadata_cache and time.time() < self._metadata_expiry:
return
try:
async with self._session.get(
url=AZURE_METADATA_SERVICE_INSTANCE_URL,
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
timeout=2,
) as resp:
jsn = await resp.json()
if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
raise AirflowException(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
)
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
async for attempt in self._a_get_retry_object():
with attempt:
async with self._session.get(
url=AZURE_METADATA_SERVICE_INSTANCE_URL,
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
timeout=2,
) as resp:
resp.raise_for_status()
response_json = await resp.json()
self._validate_azure_metadata_service(response_json)
self._metadata_cache = response_json
self._metadata_expiry = time.time() + self._metadata_ttl
break
except RetryError:
raise ConnectionError(f"Failed to reach Azure Metadata Service after {self.retry_limit} retries.")
except (aiohttp.ClientError, ValueError) as e:
raise ConnectionError(f"Can't reach Azure Metadata Service: {e}")

def _validate_azure_metadata_service(self, response_json: dict) -> None:
if "compute" not in response_json or "azEnvironment" not in response_json["compute"]:
raise ValueError(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {response_json}"
)

def _get_token(self, raise_error: bool = False) -> str | None:
if "token" in self.databricks_conn.extra_dejson:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from aiohttp.client_exceptions import ClientConnectorError
from requests import exceptions as requests_exceptions
from requests.auth import HTTPBasicAuth
from tenacity import Future, RetryError
from tenacity import AsyncRetrying, Future, RetryError, retry_if_exception, stop_after_attempt, wait_fixed

from airflow.exceptions import AirflowException
from airflow.models import Connection
Expand Down Expand Up @@ -768,3 +768,145 @@ def test_get_error_code_with_http_error_and_valid_error_code(self):
exception.response = mock_response
hook = BaseDatabricksHook()
assert hook._get_error_code(exception) == "INVALID_REQUEST"

@mock.patch("requests.get")
@time_machine.travel("2025-07-12 12:00:00")
def test_check_azure_metadata_service_normal(self, mock_get):
travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
hook = BaseDatabricksHook()
mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
mock_get.return_value.json.return_value = mock_response

hook._check_azure_metadata_service()

assert hook._metadata_cache == mock_response
assert int(hook._metadata_expiry) == travel_time + hook._metadata_ttl

@mock.patch("requests.get")
@time_machine.travel("2025-07-12 12:00:00")
def test_check_azure_metadata_service_cached(self, mock_get):
travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
hook = BaseDatabricksHook()
mock_response = {"compute": {"azEnvironment": "AzurePublicCloud"}}
hook._metadata_cache = mock_response
hook._metadata_expiry = travel_time + 1000

hook._check_azure_metadata_service()
mock_get.assert_not_called()

@mock.patch("requests.get")
def test_check_azure_metadata_service_http_error(self, mock_get):
hook = BaseDatabricksHook()
mock_get.side_effect = requests_exceptions.RequestException("Fail")

with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"):
hook._check_azure_metadata_service()
assert hook._metadata_cache == {}
assert hook._metadata_expiry == 0

@mock.patch("requests.get")
def test_check_azure_metadata_service_retry_error(self, mock_get):
hook = BaseDatabricksHook()

resp_429 = mock.Mock()
resp_429.status_code = 429
resp_429.content = b"Too many requests"
http_error = requests_exceptions.HTTPError(response=resp_429)
mock_get.side_effect = http_error

with pytest.raises(ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries."):
hook._check_azure_metadata_service()
assert mock_get.call_count == 3

@pytest.mark.asyncio
@mock.patch("aiohttp.ClientSession.get")
async def test_a_check_azure_metadata_service_normal(self, mock_get):
hook = BaseDatabricksHook()

async_mock = mock.AsyncMock()
async_mock.__aenter__.return_value = async_mock
async_mock.__aexit__.return_value = None
async_mock.json.return_value = {"compute": {"azEnvironment": "AzurePublicCloud"}}

mock_get.return_value = async_mock

async with aiohttp.ClientSession() as session:
hook._session = session
mock_attempt = mock.Mock()
mock_attempt.__enter__ = mock.Mock(return_value=None)
mock_attempt.__exit__ = mock.Mock(return_value=None)

async def mock_retry_generator():
yield mock_attempt

hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator())
await hook._a_check_azure_metadata_service()

assert hook._metadata_cache["compute"]["azEnvironment"] == "AzurePublicCloud"
assert hook._metadata_expiry > 0

@pytest.mark.asyncio
@mock.patch("aiohttp.ClientSession.get")
@time_machine.travel("2025-07-12 12:00:00")
async def test_a_check_azure_metadata_service_cached(self, mock_get):
travel_time = int(datetime(2025, 7, 12, 12, 0, 0).timestamp())
hook = BaseDatabricksHook()
hook._metadata_cache = {"compute": {"azEnvironment": "AzurePublicCloud"}}
hook._metadata_expiry = travel_time + 1000

async with aiohttp.ClientSession() as session:
hook._session = session
await hook._a_check_azure_metadata_service()
mock_get.assert_not_called()

@pytest.mark.asyncio
@mock.patch("aiohttp.ClientSession.get")
async def test_a_check_azure_metadata_service_http_error(self, mock_get):
hook = BaseDatabricksHook()

async_mock = mock.AsyncMock()
async_mock.__aenter__.side_effect = aiohttp.ClientError("Fail")
async_mock.__aexit__.return_value = None
mock_get.return_value = async_mock

async with aiohttp.ClientSession() as session:
hook._session = session
mock_attempt = mock.Mock()
mock_attempt.__enter__ = mock.Mock(return_value=None)
mock_attempt.__exit__ = mock.Mock(return_value=None)

async def mock_retry_generator():
yield mock_attempt

hook._a_get_retry_object = mock.Mock(return_value=mock_retry_generator())

with pytest.raises(ConnectionError, match="Can't reach Azure Metadata Service"):
await hook._a_check_azure_metadata_service()
assert hook._metadata_cache == {}
assert hook._metadata_expiry == 0

@pytest.mark.asyncio
@mock.patch("aiohttp.ClientSession.get")
async def test_a_check_azure_metadata_service_retry_error(self, mock_get):
hook = BaseDatabricksHook()

mock_get.side_effect = aiohttp.ClientResponseError(
request_info=mock.Mock(), history=(), status=429, message="429 Too Many Requests"
)

async with aiohttp.ClientSession() as session:
hook._session = session

hook._a_get_retry_object = lambda: AsyncRetrying(
stop=stop_after_attempt(hook.retry_limit),
wait=wait_fixed(0),
retry=retry_if_exception(hook._retryable_error),
)

hook._validate_azure_metadata_service = mock.Mock()

with pytest.raises(
ConnectionError, match="Failed to reach Azure Metadata Service after 3 retries."
):
await hook._a_check_azure_metadata_service()
assert mock_get.call_count == 3