From 6880bf48d9b51786f8438bcad6fbfccd0d5d578f Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Sat, 20 Nov 2021 12:27:29 +0100 Subject: [PATCH 1/2] Databricks - allow Azure SP authentication on other clouds When other Azure clouds are used (US GovCloud, China, ...) other authentication endpoints should be used. This PR allows to overwrite the authentication endpoint when using other clouds --- .../providers/databricks/hooks/databricks.py | 9 ++- .../connections/databricks.rst | 1 + .../databricks/hooks/test_databricks.py | 63 ++++++++++++++++++- 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 48368452c6a4a..e26f51879067e 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -49,7 +49,9 @@ USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'} # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token -AZURE_TOKEN_SERVICE_URL = "https://login.microsoftonline.com/{}/oauth2/token" +# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints +AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com" +AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token" # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token" AZURE_METADATA_SERVICE_INSTANCE_URL = "http://169.254.169.254/metadata/instance" @@ -200,8 +202,11 @@ def _get_aad_token(self, resource: str) -> str: "resource": resource, "client_secret": self.databricks_conn.password, } + azure_ad_endpoint = self.databricks_conn.extra_dejson.get( + "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT + ) resp = requests.post( - AZURE_TOKEN_SERVICE_URL.format(tenant_id), + AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), data=data, headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'}, timeout=self.aad_timeout_seconds, diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst index cb04ef97e1bcf..437a1c3730355 100644 --- a/docs/apache-airflow-providers-databricks/connections/databricks.rst +++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst @@ -72,6 +72,7 @@ Extra (optional) * ``azure_tenant_id``: ID of the Azure Active Directory tenant * ``azure_resource_id``: optional Resource ID of the Azure Databricks workspace (required if Service Principal isn't a user inside workspace) + * ``azure_ad_endpoint``: optional host name of Azure AD endpoint if you're using special `Azure Cloud (GovCloud, China, Germany) `_. Should be specified as hostname with protocol. For example: ``https://login.microsoftonline.de``. Following parameters are necessary if using authentication with AAD token for Azure managed identity: diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 9341fad57cc61..13430c9438997 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -29,7 +29,9 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.databricks.hooks.databricks import ( + AZURE_DEFAULT_AD_ENDPOINT, AZURE_MANAGEMENT_ENDPOINT, + AZURE_TOKEN_SERVICE_URL, DEFAULT_DATABRICKS_SCOPE, SUBMIT_RUN_ENDPOINT, DatabricksHook, @@ -638,6 +640,53 @@ def test_submit_run(self, mock_requests): assert kwargs['auth'].token == TOKEN +class TestDatabricksHookAadTokenOtherClouds(unittest.TestCase): + """ + Tests for DatabricksHook when auth is done with AAD token for SP as user inside workspace and + using non-global Azure cloud (China, GovCloud, Germany) + """ + + @provide_session + def setUp(self, session=None): + self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d' + self.ad_endpoint = 'https://login.microsoftonline.de' + self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.login = self.client_id + conn.password = 'secret' + conn.extra = json.dumps( + { + 'host': HOST, + 'azure_tenant_id': self.tenant_id, + 'azure_ad_endpoint': self.ad_endpoint, + } + ) + session.commit() + self.hook = DatabricksHook() + + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') + def test_submit_run(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.side_effect = [ + create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)), + create_successful_response_mock({'run_id': '1'}), + ] + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER} + run_id = self.hook.submit_run(data) + + ad_call_args = mock_requests.method_calls[0] + assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(self.ad_endpoint, self.tenant_id) + assert ad_call_args[2]['data']['client_id'] == self.client_id + assert ad_call_args[2]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE + + assert run_id == '1' + args = mock_requests.post.call_args + kwargs = args[1] + assert kwargs['auth'].token == TOKEN + + class TestDatabricksHookAadTokenSpOutside(unittest.TestCase): """ Tests for DatabricksHook when auth is done with AAD token for SP outside of workspace. @@ -646,7 +695,9 @@ class TestDatabricksHookAadTokenSpOutside(unittest.TestCase): @provide_session def setUp(self, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() - conn.login = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + self.tenant_id = '3ff810a6-5504-4ab8-85cb-cd0e6f879c1d' + self.client_id = '9ff815a6-4404-4ab8-85cb-cd0e6f879c1d' + conn.login = self.client_id conn.password = 'secret' conn.host = HOST conn.extra = json.dumps( @@ -671,6 +722,16 @@ def test_submit_run(self, mock_requests): data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER} run_id = self.hook.submit_run(data) + ad_call_args = mock_requests.method_calls[0] + assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id) + assert ad_call_args[2]['data']['client_id'] == self.client_id + assert ad_call_args[2]['data']['resource'] == AZURE_MANAGEMENT_ENDPOINT + + ad_call_args = mock_requests.method_calls[1] + assert ad_call_args[1][0] == AZURE_TOKEN_SERVICE_URL.format(AZURE_DEFAULT_AD_ENDPOINT, self.tenant_id) + assert ad_call_args[2]['data']['client_id'] == self.client_id + assert ad_call_args[2]['data']['resource'] == DEFAULT_DATABRICKS_SCOPE + assert run_id == '1' args = mock_requests.post.call_args kwargs = args[1] From 3dd6ba30a8bf52a60fbe1d6fb91ff91cf40c6373 Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Mon, 22 Nov 2021 08:25:01 +0100 Subject: [PATCH 2/2] Update docs/apache-airflow-providers-databricks/connections/databricks.rst Co-authored-by: Tzu-ping Chung --- .../connections/databricks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst index 437a1c3730355..504c850bb1e9a 100644 --- a/docs/apache-airflow-providers-databricks/connections/databricks.rst +++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst @@ -72,7 +72,7 @@ Extra (optional) * ``azure_tenant_id``: ID of the Azure Active Directory tenant * ``azure_resource_id``: optional Resource ID of the Azure Databricks workspace (required if Service Principal isn't a user inside workspace) - * ``azure_ad_endpoint``: optional host name of Azure AD endpoint if you're using special `Azure Cloud (GovCloud, China, Germany) `_. Should be specified as hostname with protocol. For example: ``https://login.microsoftonline.de``. + * ``azure_ad_endpoint``: optional host name of Azure AD endpoint if you're using special `Azure Cloud (GovCloud, China, Germany) `_. The value must contain a protocol. For example: ``https://login.microsoftonline.de``. Following parameters are necessary if using authentication with AAD token for Azure managed identity: