diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 48368452c6a4a..e26f51879067e 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -49,7 +49,9 @@ USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'} # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token -AZURE_TOKEN_SERVICE_URL = "https://login.microsoftonline.com/{}/oauth2/token" +# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints +AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com" +AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token" # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token" AZURE_METADATA_SERVICE_INSTANCE_URL = "http://169.254.169.254/metadata/instance" @@ -200,8 +202,11 @@ def _get_aad_token(self, resource: str) -> str: "resource": resource, "client_secret": self.databricks_conn.password, } + azure_ad_endpoint = self.databricks_conn.extra_dejson.get( + "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT + ) resp = requests.post( - AZURE_TOKEN_SERVICE_URL.format(tenant_id), + AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), data=data, headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'}, timeout=self.aad_timeout_seconds, diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst index cb04ef97e1bcf..504c850bb1e9a 100644 --- a/docs/apache-airflow-providers-databricks/connections/databricks.rst +++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst @@ -72,6 +72,7 @@ Extra (optional) * ``azure_tenant_id``: ID of the Azure Active Directory tenant * ``azure_resource_id``: optional Resource ID of the Azure Databricks workspace (required if Service Principal isn't a user inside workspace) + * ``azure_ad_endpoint``: optional host name of Azure AD endpoint if you're using special `Azure Cloud (GovCloud, China, Germany) `_. The value must contain a protocol. For example: ``https://login.microsoftonline.de``. Following parameters are necessary if using authentication with AAD token for Azure managed identity: diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 9341fad57cc61..13430c9438997 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -29,7 +29,9 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.databricks.hooks.databricks import ( + AZURE_DEFAULT_AD_ENDPOINT, AZURE_MANAGEMENT_ENDPOINT, + AZURE_TOKEN_SERVICE_URL, DEFAULT_DATABRICKS_SCOPE, SUBMIT_RUN_ENDPOINT, DatabricksHook, @@ -638,6 +640,53 @@ def test_submit_run(self, mock_requests): assert kwargs['auth'].token == TOKEN +class TestDatabricksHookAadTokenOtherClouds(unittest.TestCase): + """ + Tests for DatabricksHook when auth is done with AAD token for SP as user inside workspace and + using non-global Azure cloud (China, GovCloud, Germany) + """ + + @provide_session + def setUp(self, session=None): + self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d' + self.ad_endpoint = 'https://login.microsoftonline.de' + self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.login = self.client_id + conn.password = 'secret' + conn.extra = json.dumps( + { + 'host': HOST, + 'azure_tenant_id': self.tenant_id, + 'azure_ad_endpoint': self.ad_endpoint, + } + ) + session.commit() + self.hook = DatabricksHook() + + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') + def test_submit_run(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.side_effect = [ + create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)), + create_successful_response_mock({'run_id': '1'}), + ] + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER} + run_id = self.hook.submit_run(data) + + ad_call_args = mock_requests.method_calls[0] + assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id) + assert ad_call_args[2]['data']['client_id'] == self.client_id + assert ad_call_args[2]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE + + assert run_id == '1' + args = mock_requests.post.call_args + kwargs = args[1] + assert kwargs['auth'].token == TOKEN + + class TestDatabricksHookAadTokenSpOutside(unittest.TestCase): """ Tests for DatabricksHook when auth is done with AAD token for SP outside of workspace. @@ -646,7 +695,9 @@ class TestDatabricksHookAadTokenSpOutside(unittest.TestCase): @provide_session def setUp(self, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() - conn.login = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d' + self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + conn.login = self.client_id conn.password = 'secret' conn.host = HOST conn.extra = json.dumps( @@ -671,6 +722,16 @@ def test_submit_run(self, mock_requests): data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER} run_id = self.hook.submit_run(data) + ad_call_args = mock_requests.method_calls[0] + assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id) + assert ad_call_args[2]['data']['client_id'] == self.client_id + assert ad_call_args[2]['data']['resource'] == AZURE_MANAGEMENT_ENDPOINT + + ad_call_args = mock_requests.method_calls[1] + assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id) + assert ad_call_args[2]['data']['client_id'] == self.client_id + assert ad_call_args[2]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE + assert run_id == '1' args = mock_requests.post.call_args kwargs = args[1]