Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
dfa5a5c
init
karunpoudel Jun 24, 2025
56dae4d
fx
karunpoudel-chr Jun 24, 2025
c841b59
fx
karunpoudel-chr Jun 24, 2025
2d63327
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel Jun 24, 2025
5807c86
fx
karunpoudel Jun 24, 2025
0b1c35e
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jun 24, 2025
3cb169b
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jun 24, 2025
c9a2838
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel Jun 30, 2025
50e03ba
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jun 30, 2025
be69a63
doc update
karunpoudel-chr Jun 30, 2025
aa6673e
Merge branch 'azurebasehook-creds-support-get_token' of https://githu…
karunpoudel-chr Jun 30, 2025
f7ffb15
Update azure.rst
karunpoudel-chr Jun 30, 2025
7ec4303
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel Jul 2, 2025
90d6c73
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 10, 2025
f67cd27
fix
karunpoudel Jul 10, 2025
9575103
fix
karunpoudel Jul 10, 2025
4e57ceb
fx
karunpoudel Jul 10, 2025
b9bfd18
fix import
karunpoudel Jul 10, 2025
1da12bb
fix import
karunpoudel Jul 10, 2025
649944a
fix import
karunpoudel Jul 11, 2025
b458680
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 14, 2025
2965ccf
Update base_azure.py
karunpoudel-chr Jul 16, 2025
982a27b
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 16, 2025
41836a8
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 17, 2025
74aca20
precommit fixes
karunpoudel-chr Jul 21, 2025
8167606
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 21, 2025
3657a4b
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 23, 2025
671d875
pre-commit and doc fix
karunpoudel-chr Jul 23, 2025
5b35b01
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 23, 2025
5f55827
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 23, 2025
5b84c8c
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 24, 2025
782a885
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 30, 2025
94d7a0e
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Jul 31, 2025
6c412c2
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Aug 10, 2025
2e213d5
refactor
karunpoudel-chr Aug 10, 2025
df5d7dc
included get_token method in hook for error handling
karunpoudel-chr Aug 10, 2025
2c9e007
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel Aug 10, 2025
69d906c
add test for get_token
karunpoudel-chr Aug 10, 2025
28c1ad2
Merge branch 'main' into azurebasehook-creds-support-get_token
karunpoudel-chr Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions providers/microsoft/azure/docs/connections/azure.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ Extra (optional)
It specifies the json that contains the authentication information.
* ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_.
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.
* ``use_azure_identity_object``: If set to true, it will use credential of newer type: ClientSecretCredential or DefaultAzureCredential instead of ServicePrincipalCredentials or AzureIdentityCredentialAdapter.
These newer credentials support get_token method which can be used to generate OAuth token with custom scope.

The entire extra column can be left out to fall back on DefaultAzureCredential_.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@
# under the License.
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
from azure.common.credentials import ServicePrincipalCredentials
from azure.identity import ClientSecretCredential, DefaultAzureCredential

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.utils import (
AzureIdentityCredentialAdapter,
add_managed_identity_connection_widgets,
get_sync_default_azure_credential,
)
from airflow.providers.microsoft.azure.version_compat import BaseHook

if TYPE_CHECKING:
from azure.core.credentials import AccessToken

from airflow.sdk import Connection


class AzureBaseHook(BaseHook):
"""
Expand Down Expand Up @@ -85,7 +92,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
},
}

def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
self.sdk_client = sdk_client
self.conn_id = conn_id
super().__init__()
Expand All @@ -96,8 +103,9 @@ def get_conn(self) -> Any:

:return: the authenticated client.
"""
if not self.sdk_client:
raise ValueError("`sdk_client` must be provided to AzureBaseHook to use `get_conn` method.")
conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("tenantId")
subscription_id = conn.extra_dejson.get("subscriptionId")
key_path = conn.extra_dejson.get("key_path")
if key_path:
Expand All @@ -111,22 +119,74 @@ def get_conn(self) -> Any:
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)

credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
credentials = self.get_credential(conn=conn)

return self.sdk_client(
credentials=credentials,
subscription_id=subscription_id,
)

def get_credential(self, *, conn: Connection | None = None) -> Any:
"""
Get Azure credential object for the connection.

Azure Identity based credential object (``ClientSecretCredential``, ``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` method.
Older Credential objects (``ServicePrincipalCredentials``, ``AzureIdentityCredentialAdapter``) are supported for backward compatibility.

:return: The Azure credential object
"""
if not conn:
conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("tenantId")
credential: (
ServicePrincipalCredentials
| AzureIdentityCredentialAdapter
| ClientSecretCredential
| DefaultAzureCredential
)
if all([conn.login, conn.password, tenant]):
self.log.info("Getting connection using specific credentials and subscription_id.")
credentials = ServicePrincipalCredentials(
client_id=conn.login, secret=conn.password, tenant=tenant
)
credential = self._get_client_secret_credential(conn)
else:
self.log.info("Using DefaultAzureCredential as credential")
managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
credentials = AzureIdentityCredentialAdapter(
credential = self._get_default_azure_credential(conn)
return credential

def _get_client_secret_credential(self, conn: Connection):
self.log.info("Getting credentials using specific credentials and subscription_id.")
extra_dejson = conn.extra_dejson
tenant = extra_dejson.get("tenantId")
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
if use_azure_identity_object:
return ClientSecretCredential(
client_id=conn.login, # type: ignore[arg-type]
client_secret=conn.password, # type: ignore[arg-type]
tenant_id=tenant, # type: ignore[arg-type]
)
return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant)

def _get_default_azure_credential(self, conn: Connection):
self.log.info("Using DefaultAzureCredential as credential")
extra_dejson = conn.extra_dejson
managed_identity_client_id = extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id = extra_dejson.get("workload_identity_tenant_id")
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
if use_azure_identity_object:
return get_sync_default_azure_credential(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)

return self.sdk_client(
credentials=credentials,
subscription_id=subscription_id,
return AzureIdentityCredentialAdapter(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)

def get_token(self, *scopes, **kwargs) -> AccessToken:
"""Request an access token for `scopes`."""
credential = self.get_credential()
if isinstance(credential, AzureIdentityCredentialAdapter) or isinstance(
credential, AzureIdentityCredentialAdapter
):
raise AttributeError(
"The azure credential does not support get_token method. "
"Please set `use_azure_identity_object: True` in the connection extra field to use credential that support get_token method."
)
return credential.get_token(*scopes, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Connection = MagicMock() # type: ignore[misc]

MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
UTILS = "airflow.providers.microsoft.azure.utils"


class TestBaseAzureHook:
Expand Down Expand Up @@ -111,7 +112,7 @@ def test_get_conn_with_credentials(self, mock_spc, mocked_connection):
indirect=True,
)
@patch("azure.common.credentials.ServicePrincipalCredentials")
@patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter")
@patch(f"{MODULE}.AzureIdentityCredentialAdapter")
def test_get_conn_fallback_to_azure_identity_credential_adapter(
self,
mock_credential_adapter,
Expand All @@ -133,3 +134,90 @@ def test_get_conn_fallback_to_azure_identity_credential_adapter(
credentials=mock_credential,
subscription_id="subscription_id",
)

@patch(f"{MODULE}.ClientSecretCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
login="my_login",
password="my_password",
extra={"tenantId": "my_tenant", "use_azure_identity_object": True},
),
],
indirect=True,
)
def test_get_credential_with_client_secret(self, mock_spc, mocked_connection):
mock_spc.return_value = "foo-bar"
cred = AzureBaseHook().get_credential()

mock_spc.assert_called_once_with(
client_id=mocked_connection.login,
client_secret=mocked_connection.password,
tenant_id=mocked_connection.extra_dejson["tenantId"],
)
assert cred == "foo-bar"

@patch(f"{UTILS}.DefaultAzureCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
extra={"use_azure_identity_object": True},
),
],
indirect=True,
)
def test_get_credential_with_azure_default_credential(self, mock_spc, mocked_connection):
mock_spc.return_value = "foo-bar"
cred = AzureBaseHook().get_credential()

mock_spc.assert_called_once_with()
assert cred == "foo-bar"

@patch(f"{UTILS}.DefaultAzureCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
extra={
"managed_identity_client_id": "test_client_id",
"workload_identity_tenant_id": "test_tenant_id",
"use_azure_identity_object": True,
},
),
],
indirect=True,
)
def test_get_credential_with_azure_default_credential_with_extra(self, mock_spc, mocked_connection):
mock_spc.return_value = "foo-bar"
cred = AzureBaseHook().get_credential()

mock_spc.assert_called_once_with(
managed_identity_client_id=mocked_connection.extra_dejson.get("managed_identity_client_id"),
workload_identity_tenant_id=mocked_connection.extra_dejson.get("workload_identity_tenant_id"),
additionally_allowed_tenants=[mocked_connection.extra_dejson.get("workload_identity_tenant_id")],
)
assert cred == "foo-bar"

@patch(f"{UTILS}.DefaultAzureCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
extra={"use_azure_identity_object": True},
),
],
indirect=True,
)
def test_get_token_with_azure_default_credential(self, mock_spc, mocked_connection):
mock_spc.get_token.return_value = "new-token"
scope = "custom_scope"
token = AzureBaseHook().get_token(scope)

mock_spc.assert_called_once_with()
assert token == "new-token"