Skip to content

Commit

Permalink
Bug fix: support Azure managed identities in Databricks operator (#40332)
Browse files Browse the repository at this point in the history
* Add azure identity to get bearer token

* Add azure-identity as an additional-extras dependency

* Fix tests

---------

Co-authored-by: Marcel Martinelli <marcel.martinelli@rabobank.nl>
  • Loading branch information
marcel-martinelli and Marcel Martinelli committed Jun 22, 2024
1 parent 054e9fc commit de5c751
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 142 deletions.
104 changes: 34 additions & 70 deletions airflow/providers/databricks/hooks/databricks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,13 @@
)

from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException
from airflow.hooks.base import BaseHook
from airflow.providers_manager import ProvidersManager

if TYPE_CHECKING:
from airflow.models import Connection

# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token"
# https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token
AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token"
AZURE_METADATA_SERVICE_INSTANCE_URL = "http://169.254.169.254/metadata/instance"
Expand Down Expand Up @@ -301,46 +297,29 @@ def _get_aad_token(self, resource: str) -> str:

self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
try:
from azure.identity import ClientSecretCredential, ManagedIdentityCredential

for attempt in self._get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
resp = requests.get(
AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
headers={**self.user_agent_header, "Metadata": "true"},
timeout=self.token_timeout_seconds,
)
token = ManagedIdentityCredential().get_token(f"{resource}/.default")
else:
tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
)
resp = requests.post(
AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
**self.user_agent_header,
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
credential = ClientSecretCredential(
client_id=self.databricks_conn.login,
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
)

resp.raise_for_status()
jsn = resp.json()

token = credential.get_token(f"{resource}/.default")
jsn = {
"access_token": token.token,
"token_type": "Bearer",
"expires_on": token.expires_on,
}
self._is_oauth_token_valid(jsn)
self.oauth_tokens[resource] = jsn
break
except ImportError as e:
raise AirflowOptionalProviderFeatureException(e)
except RetryError:
raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
except requests_exceptions.HTTPError as e:
Expand All @@ -362,47 +341,32 @@ async def _a_get_aad_token(self, resource: str) -> str:

self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
try:
from azure.identity.aio import (
ClientSecretCredential as AsyncClientSecretCredential,
ManagedIdentityCredential as AsyncManagedIdentityCredential,
)

async for attempt in self._a_get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
async with self._session.get(
url=AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
headers={**self.user_agent_header, "Metadata": "true"},
timeout=self.token_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
token = await AsyncManagedIdentityCredential().get_token(f"{resource}/.default")
else:
tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
credential = AsyncClientSecretCredential(
client_id=self.databricks_conn.login,
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
)
async with self._session.post(
url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
**self.user_agent_header,
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()

token = await credential.get_token(f"{resource}/.default")
jsn = {
"access_token": token.token,
"token_type": "Bearer",
"expires_on": token.expires_on,
}
self._is_oauth_token_valid(jsn)
self.oauth_tokens[resource] = jsn
break
except ImportError as e:
raise AirflowOptionalProviderFeatureException(e)
except RetryError:
raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
except aiohttp.ClientResponseError as err:
Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/databricks/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,17 @@ dependencies:
- pandas>=1.5.3,<2.2;python_version<"3.9"
- pyarrow>=14.0.1


additional-extras:
# pip install apache-airflow-providers-databricks[sdk]
- name: sdk
description: Install Databricks SDK
dependencies:
- databricks-sdk==0.10.0
- name: azure-identity
description: Install Azure Identity client library
dependencies:
- azure-identity>=1.3.1

devel-dependencies:
- deltalake>=0.12.0
Expand Down
Loading

0 comments on commit de5c751

Please sign in to comment.