diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py index 0d6ce7ea7cd94..4dd95d755c72d 100644 --- a/airflow/providers/databricks/hooks/databricks_base.py +++ b/airflow/providers/databricks/hooks/databricks_base.py @@ -47,17 +47,13 @@ ) from airflow import __version__ -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException from airflow.hooks.base import BaseHook from airflow.providers_manager import ProvidersManager if TYPE_CHECKING: from airflow.models import Connection -# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token -# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints -AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com" -AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token" # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token" AZURE_METADATA_SERVICE_INSTANCE_URL = "http://169.254.169.254/metadata/instance" @@ -301,46 +297,29 @@ def _get_aad_token(self, resource: str) -> str: self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...") try: + from azure.identity import ClientSecretCredential, ManagedIdentityCredential + for attempt in self._get_retry_object(): with attempt: if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): - params = { - "api-version": "2018-02-01", - "resource": resource, - } - resp = requests.get( - AZURE_METADATA_SERVICE_TOKEN_URL, - params=params, - headers={**self.user_agent_header, "Metadata": "true"}, - timeout=self.token_timeout_seconds, - ) + token = ManagedIdentityCredential().get_token(f"{resource}/.default") else: - tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"] - data = { - "grant_type": "client_credentials", - "client_id": self.databricks_conn.login, - "resource": resource, - "client_secret": self.databricks_conn.password, - } - azure_ad_endpoint = self.databricks_conn.extra_dejson.get( - "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT - ) - resp = requests.post( - AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), - data=data, - headers={ - **self.user_agent_header, - "Content-Type": "application/x-www-form-urlencoded", - }, - timeout=self.token_timeout_seconds, + credential = ClientSecretCredential( + client_id=self.databricks_conn.login, + client_secret=self.databricks_conn.password, + tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"], ) - - resp.raise_for_status() - jsn = resp.json() - + token = credential.get_token(f"{resource}/.default") + jsn = { + "access_token": token.token, + "token_type": "Bearer", + "expires_on": token.expires_on, + } self._is_oauth_token_valid(jsn) self.oauth_tokens[resource] = jsn break + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) except RetryError: raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.") except requests_exceptions.HTTPError as e: @@ -362,47 +341,32 @@ async def _a_get_aad_token(self, resource: str) -> str: self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...") try: + from azure.identity.aio import ( + ClientSecretCredential as AsyncClientSecretCredential, + ManagedIdentityCredential as AsyncManagedIdentityCredential, + ) + async for attempt in self._a_get_retry_object(): with attempt: if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): - params = { - "api-version": "2018-02-01", - "resource": resource, - } - async with self._session.get( - url=AZURE_METADATA_SERVICE_TOKEN_URL, - params=params, - headers={**self.user_agent_header, "Metadata": "true"}, - timeout=self.token_timeout_seconds, - ) as resp: - resp.raise_for_status() - jsn = await resp.json() + token = await AsyncManagedIdentityCredential().get_token(f"{resource}/.default") else: - tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"] - data = { - "grant_type": "client_credentials", - "client_id": self.databricks_conn.login, - "resource": resource, - "client_secret": self.databricks_conn.password, - } - azure_ad_endpoint = self.databricks_conn.extra_dejson.get( - "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT + credential = AsyncClientSecretCredential( + client_id=self.databricks_conn.login, + client_secret=self.databricks_conn.password, + tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"], ) - async with self._session.post( - url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), - data=data, - headers={ - **self.user_agent_header, - "Content-Type": "application/x-www-form-urlencoded", - }, - timeout=self.token_timeout_seconds, - ) as resp: - resp.raise_for_status() - jsn = await resp.json() - + token = await credential.get_token(f"{resource}/.default") + jsn = { + "access_token": token.token, + "token_type": "Bearer", + "expires_on": token.expires_on, + } self._is_oauth_token_valid(jsn) self.oauth_tokens[resource] = jsn break + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) except RetryError: raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.") except aiohttp.ClientResponseError as err: diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 0be4611ed5487..d790d5707209f 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -77,12 +77,17 @@ dependencies: - pandas>=1.5.3,<2.2;python_version<"3.9" - pyarrow>=14.0.1 + additional-extras: # pip install apache-airflow-providers-databricks[sdk] - name: sdk description: Install Databricks SDK dependencies: - databricks-sdk==0.10.0 + - name: azure-identity + description: Install Azure Identity client library + dependencies: + - azure-identity>=1.3.1 devel-dependencies: - deltalake>=0.12.0 diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 7e61036973182..fb44fd0aaf4ae 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -24,8 +24,11 @@ from unittest.mock import AsyncMock import aiohttp +import azure.identity +import azure.identity.aio import pytest import tenacity +from azure.core.credentials import AccessToken from requests import exceptions as requests_exceptions from requests.auth import HTTPBasicAuth @@ -39,10 +42,8 @@ RunState, ) from airflow.providers.databricks.hooks.databricks_base import ( - AZURE_DEFAULT_AD_ENDPOINT, AZURE_MANAGEMENT_ENDPOINT, AZURE_METADATA_SERVICE_INSTANCE_URL, - AZURE_TOKEN_SERVICE_URL, DEFAULT_DATABRICKS_SCOPE, OIDC_TOKEN_SERVICE_URL, TOKEN_REFRESH_LEAD_TIME, @@ -71,6 +72,8 @@ LOGIN = "login" PASSWORD = "password" TOKEN = "token" +AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com" +AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token" RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1" LIFE_CYCLE_STATE = "PENDING" STATE_MESSAGE = "Waiting for cluster" @@ -1294,16 +1297,8 @@ def test_from_json(self): assert expected == ClusterState.from_json(json.dumps(state)) -def create_aad_token_for_resource(resource: str) -> dict: - return { - "token_type": "Bearer", - "expires_in": "599", - "ext_expires_in": "599", - "expires_on": "1575500666", - "not_before": "1575499766", - "resource": resource, - "access_token": TOKEN, - } +def create_aad_token_for_resource() -> AccessToken: + return AccessToken(expires_on=1575500666, token=TOKEN) @pytest.mark.db_test @@ -1327,12 +1322,11 @@ def setup_method(self, method, session=None): self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") - def test_submit_run(self, mock_requests): + @mock.patch.object(azure.identity, "ClientSecretCredential") + def test_submit_run(self, mock_azure_identity, mock_requests): mock_requests.codes.ok = 200 - mock_requests.post.side_effect = [ - create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)), - create_successful_response_mock({"run_id": "1"}), - ] + mock_requests.post.side_effect = [create_successful_response_mock({"run_id": "1"})] + mock_azure_identity().get_token.return_value = create_aad_token_for_resource() status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock data = {"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER} @@ -1370,21 +1364,23 @@ def setup_method(self, method, session=None): self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") - def test_submit_run(self, mock_requests): + @mock.patch.object(azure.identity, "ClientSecretCredential") + def test_submit_run(self, mock_azure_identity, mock_requests): mock_requests.codes.ok = 200 mock_requests.post.side_effect = [ - create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)), create_successful_response_mock({"run_id": "1"}), ] + mock_azure_identity().get_token.return_value = create_aad_token_for_resource() status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock data = {"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER} run_id = self.hook.submit_run(data) - ad_call_args = mock_requests.method_calls[0] - assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id) - assert ad_call_args[2]["data"]["client_id"] == self.client_id - assert ad_call_args[2]["data"]["resource"] == DEFAULT_DATABRICKS_SCOPE + azure_identity_args = mock_azure_identity.call_args.kwargs + assert azure_identity_args["tenant_id"] == self.tenant_id + assert azure_identity_args["client_id"] == self.client_id + get_token_args = mock_azure_identity.return_value.get_token.call_args_list + assert get_token_args == [mock.call(f"{DEFAULT_DATABRICKS_SCOPE}/.default")] assert run_id == "1" args = mock_requests.post.call_args @@ -1416,27 +1412,27 @@ def setup_method(self, method, session=None): self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") - def test_submit_run(self, mock_requests): + @mock.patch.object(azure.identity, "ClientSecretCredential") + def test_submit_run(self, mock_azure_identity, mock_requests): mock_requests.codes.ok = 200 mock_requests.post.side_effect = [ - create_successful_response_mock(create_aad_token_for_resource(AZURE_MANAGEMENT_ENDPOINT)), - create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)), create_successful_response_mock({"run_id": "1"}), ] status_code_mock = mock.PropertyMock(return_value=200) + mock_azure_identity().get_token.return_value = create_aad_token_for_resource() type(mock_requests.post.return_value).status_code = status_code_mock data = {"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER} run_id = self.hook.submit_run(data) - ad_call_args = mock_requests.method_calls[0] - assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id) - assert ad_call_args[2]["data"]["client_id"] == self.client_id - assert ad_call_args[2]["data"]["resource"] == AZURE_MANAGEMENT_ENDPOINT + azure_identity_args = mock_azure_identity.call_args.kwargs + assert azure_identity_args["tenant_id"] == self.tenant_id + assert azure_identity_args["client_id"] == self.client_id - ad_call_args = mock_requests.method_calls[1] - assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id) - assert ad_call_args[2]["data"]["client_id"] == self.client_id - assert ad_call_args[2]["data"]["resource"] == DEFAULT_DATABRICKS_SCOPE + get_token_args = mock_azure_identity.return_value.get_token.call_args_list + assert get_token_args == [ + mock.call(f"{AZURE_MANAGEMENT_ENDPOINT}/.default"), + mock.call(f"{DEFAULT_DATABRICKS_SCOPE}/.default"), + ] assert run_id == "1" args = mock_requests.post.call_args @@ -1465,15 +1461,16 @@ def setup_method(self, method, session=None): self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") - def test_submit_run(self, mock_requests): + @mock.patch.object(azure.identity, "ManagedIdentityCredential") + def test_submit_run(self, mock_azure_identity, mock_requests): mock_requests.codes.ok = 200 mock_requests.get.side_effect = [ create_successful_response_mock({"compute": {"azEnvironment": "AZUREPUBLICCLOUD"}}), - create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)), ] mock_requests.post.side_effect = [ create_successful_response_mock({"run_id": "1"}), ] + mock_azure_identity().get_token.return_value = create_aad_token_for_resource() status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock data = {"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER} @@ -1668,12 +1665,10 @@ def setup_method(self, method, session=None): @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") - @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post") - async def test_get_run_state(self, mock_post, mock_get): - mock_post.return_value.__aenter__.return_value.json = AsyncMock( - return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE) - ) + @mock.patch("azure.identity.aio.ClientSecretCredential.get_token") + async def test_get_run_state(self, mock_azure_identity, mock_get): mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) + mock_azure_identity.return_value = create_aad_token_for_resource() async with self.hook: run_state = await self.hook.a_get_run_state(RUN_ID) @@ -1715,11 +1710,11 @@ def setup_method(self, method, session=None): @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") - @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post") - async def test_get_run_state(self, mock_post, mock_get): - mock_post.return_value.__aenter__.return_value.json = AsyncMock( - return_value=create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE) - ) + @mock.patch("azure.identity.aio.ClientSecretCredential.__init__") + @mock.patch("azure.identity.aio.ClientSecretCredential.get_token") + async def test_get_run_state(self, mock_azure_identity_get_token, mock_azure_identity, mock_get): + mock_azure_identity.return_value = None + mock_azure_identity_get_token.return_value = create_aad_token_for_resource() mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) async with self.hook: @@ -1727,10 +1722,12 @@ async def test_get_run_state(self, mock_post, mock_get): assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) - ad_call_args = mock_post.call_args_list[0] - assert ad_call_args[1]["url"] == AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id) - assert ad_call_args[1]["data"]["client_id"] == self.client_id - assert ad_call_args[1]["data"]["resource"] == DEFAULT_DATABRICKS_SCOPE + azure_identity_args = mock_azure_identity.call_args.kwargs + assert azure_identity_args["tenant_id"] == self.tenant_id + assert azure_identity_args["client_id"] == self.client_id + + get_token_args = mock_azure_identity_get_token.call_args_list + assert get_token_args == [mock.call(f"{DEFAULT_DATABRICKS_SCOPE}/.default")] mock_get.assert_called_once_with( get_run_endpoint(HOST), @@ -1766,14 +1763,11 @@ def setup_method(self, method, session=None): @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") - @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post") - async def test_get_run_state(self, mock_post, mock_get): - mock_post.return_value.__aenter__.return_value.json.side_effect = AsyncMock( - side_effect=[ - create_aad_token_for_resource(AZURE_MANAGEMENT_ENDPOINT), - create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE), - ] - ) + @mock.patch("azure.identity.aio.ClientSecretCredential.__init__") + @mock.patch("azure.identity.aio.ClientSecretCredential.get_token") + async def test_get_run_state(self, mock_azure_identity_get_token, mock_azure_identity, mock_get): + mock_azure_identity.return_value = None + mock_azure_identity_get_token.return_value = create_aad_token_for_resource() mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) async with self.hook: @@ -1781,19 +1775,15 @@ async def test_get_run_state(self, mock_post, mock_get): assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) - ad_call_args = mock_post.call_args_list[0] - assert ad_call_args[1]["url"] == AZURE_TOKEN_SERVICE_URL.format( - AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id - ) - assert ad_call_args[1]["data"]["client_id"] == self.client_id - assert ad_call_args[1]["data"]["resource"] == AZURE_MANAGEMENT_ENDPOINT + azure_identity_args = mock_azure_identity.call_args.kwargs + assert azure_identity_args["tenant_id"] == self.tenant_id + assert azure_identity_args["client_id"] == self.client_id - ad_call_args = mock_post.call_args_list[1] - assert ad_call_args[1]["url"] == AZURE_TOKEN_SERVICE_URL.format( - AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id - ) - assert ad_call_args[1]["data"]["client_id"] == self.client_id - assert ad_call_args[1]["data"]["resource"] == DEFAULT_DATABRICKS_SCOPE + get_token_args = mock_azure_identity_get_token.call_args_list + assert get_token_args == [ + mock.call(f"{AZURE_MANAGEMENT_ENDPOINT}/.default"), + mock.call(f"{DEFAULT_DATABRICKS_SCOPE}/.default"), + ] mock_get.assert_called_once_with( get_run_endpoint(HOST), @@ -1830,14 +1820,15 @@ def setup_method(self, method, session=None): @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") - async def test_get_run_state(self, mock_get): + @mock.patch("azure.identity.aio.ManagedIdentityCredential.get_token") + async def test_get_run_state(self, mock_azure_identity, mock_get): mock_get.return_value.__aenter__.return_value.json.side_effect = AsyncMock( side_effect=[ {"compute": {"azEnvironment": "AZUREPUBLICCLOUD"}}, - create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE), GET_RUN_RESPONSE, ] ) + mock_azure_identity.return_value = create_aad_token_for_resource() async with self.hook: run_state = await self.hook.a_get_run_state(RUN_ID)