Skip to content

Commit

Permalink
Add azure-identity to get the Azure AD bearer token
Browse files: browse the repository at this point in the history
  • Loading branch information
Marcel Martinelli authored and marcel-martinelli committed Jun 20, 2024
1 parent c3752e2 commit 86b3545
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 71 deletions.
96 changes: 27 additions & 69 deletions airflow/providers/databricks/hooks/databricks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@

import aiohttp
import requests
from azure.identity import ClientSecretCredential, ManagedIdentityCredential
from azure.identity.aio import (
ClientSecretCredential as AsyncClientSecretCredential,
ManagedIdentityCredential as AsyncManagedIdentityCredential,
)
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import JSONDecodeError
Expand All @@ -54,10 +59,6 @@
if TYPE_CHECKING:
from airflow.models import Connection

# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token"
# https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token
AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token"
AZURE_METADATA_SERVICE_INSTANCE_URL = "http://169.254.169.254/metadata/instance"
Expand Down Expand Up @@ -304,40 +305,19 @@ def _get_aad_token(self, resource: str) -> str:
for attempt in self._get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
resp = requests.get(
AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
headers={**self.user_agent_header, "Metadata": "true"},
timeout=self.token_timeout_seconds,
)
token = ManagedIdentityCredential().get_token(f"{resource}/.default")
else:
tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
credential = ClientSecretCredential(
client_id=self.databricks_conn.login,
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
)
resp = requests.post(
AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
**self.user_agent_header,
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
)

resp.raise_for_status()
jsn = resp.json()

token = credential.get_token(f"{resource}/.default")
jsn = {
"access_token": token.token,
"token_type": "Bearer",
"expires_on": token.expires_on,
}
self._is_oauth_token_valid(jsn)
self.oauth_tokens[resource] = jsn
break
Expand Down Expand Up @@ -365,41 +345,19 @@ async def _a_get_aad_token(self, resource: str) -> str:
async for attempt in self._a_get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
async with self._session.get(
url=AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
headers={**self.user_agent_header, "Metadata": "true"},
timeout=self.token_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
token = await AsyncManagedIdentityCredential().get_token(f"{resource}/.default")
else:
tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
credential = AsyncClientSecretCredential(
client_id=self.databricks_conn.login,
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
)
async with self._session.post(
url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
**self.user_agent_header,
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()

token = await credential.get_token(f"{resource}/.default")
jsn = {
"access_token": token.token,
"token_type": "Bearer",
"expires_on": token.expires_on,
}
self._is_oauth_token_valid(jsn)
self.oauth_tokens[resource] = jsn
break
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/databricks/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ dependencies:
- pandas>=2.1.2,<2.2;python_version>="3.9"
- pandas>=1.5.3,<2.2;python_version<"3.9"
- pyarrow>=14.0.1
- azure-identity>=1.3.1

additional-extras:
# pip install apache-airflow-providers-databricks[sdk]
Expand Down
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@
"aiohttp>=3.9.2, <4",
"apache-airflow-providers-common-sql>=1.10.0",
"apache-airflow>=2.7.0",
"azure-identity>=1.3.1",
"databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
"mergedeep>=1.3.4",
"pandas>=1.5.3,<2.2;python_version<\"3.9\"",
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@
RunState,
)
from airflow.providers.databricks.hooks.databricks_base import (
AZURE_DEFAULT_AD_ENDPOINT,
AZURE_MANAGEMENT_ENDPOINT,
AZURE_METADATA_SERVICE_INSTANCE_URL,
AZURE_TOKEN_SERVICE_URL,
DEFAULT_DATABRICKS_SCOPE,
OIDC_TOKEN_SERVICE_URL,
TOKEN_REFRESH_LEAD_TIME,
Expand Down Expand Up @@ -71,6 +69,8 @@
LOGIN = "login"
PASSWORD = "password"
TOKEN = "token"
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token"
RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
LIFE_CYCLE_STATE = "PENDING"
STATE_MESSAGE = "Waiting for cluster"
Expand Down

0 comments on commit 86b3545

Please sign in to comment.