Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions providers/microsoft/azure/docs/connections/azure.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ Extra (optional)
It specifies the json that contains the authentication information.
* ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_.
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.
* ``use_azure_identity_object``: If set to true, it will use credential of newer type: ClientSecretCredential or DefaultAzureCredential instead of ServicePrincipalCredentials or AzureIdentityCredentialAdapter.
These newer credentials support get_token method which can be used to generate OAuth token with custom scope.

The entire extra column can be left out to fall back on DefaultAzureCredential_.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,18 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
from azure.common.credentials import ServicePrincipalCredentials
from azure.identity import ClientSecretCredential, DefaultAzureCredential

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.utils import (
AzureIdentityCredentialAdapter,
add_managed_identity_connection_widgets,
get_sync_default_azure_credential,
)
from airflow.providers.microsoft.azure.version_compat import BaseHook

if TYPE_CHECKING:
from azure.core.credentials import AccessToken

from airflow.sdk import Connection


class AzureBaseHook(BaseHook):
"""
Expand Down Expand Up @@ -92,7 +85,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
},
}

def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
self.sdk_client = sdk_client
self.conn_id = conn_id
super().__init__()
Expand All @@ -103,9 +96,8 @@ def get_conn(self) -> Any:

:return: the authenticated client.
"""
if not self.sdk_client:
raise ValueError("`sdk_client` must be provided to AzureBaseHook to use `get_conn` method.")
conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("tenantId")
subscription_id = conn.extra_dejson.get("subscriptionId")
key_path = conn.extra_dejson.get("key_path")
if key_path:
Expand All @@ -119,74 +111,22 @@ def get_conn(self) -> Any:
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)

credentials = self.get_credential(conn=conn)

return self.sdk_client(
credentials=credentials,
subscription_id=subscription_id,
)

def get_credential(self, *, conn: Connection | None = None) -> Any:
"""
Get Azure credential object for the connection.

Azure Identity based credential object (``ClientSecretCredential``, ``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` method.
Older Credential objects (``ServicePrincipalCredentials``, ``AzureIdentityCredentialAdapter``) are supported for backward compatibility.

:return: The Azure credential object
"""
if not conn:
conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("tenantId")
credential: (
ServicePrincipalCredentials
| AzureIdentityCredentialAdapter
| ClientSecretCredential
| DefaultAzureCredential
)
credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
if all([conn.login, conn.password, tenant]):
credential = self._get_client_secret_credential(conn)
else:
credential = self._get_default_azure_credential(conn)
return credential

def _get_client_secret_credential(self, conn: Connection):
self.log.info("Getting credentials using specific credentials and subscription_id.")
extra_dejson = conn.extra_dejson
tenant = extra_dejson.get("tenantId")
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
if use_azure_identity_object:
return ClientSecretCredential(
client_id=conn.login, # type: ignore[arg-type]
client_secret=conn.password, # type: ignore[arg-type]
tenant_id=tenant, # type: ignore[arg-type]
self.log.info("Getting connection using specific credentials and subscription_id.")
credentials = ServicePrincipalCredentials(
client_id=conn.login, secret=conn.password, tenant=tenant
)
return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant)

def _get_default_azure_credential(self, conn: Connection):
self.log.info("Using DefaultAzureCredential as credential")
extra_dejson = conn.extra_dejson
managed_identity_client_id = extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id = extra_dejson.get("workload_identity_tenant_id")
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
if use_azure_identity_object:
return get_sync_default_azure_credential(
else:
self.log.info("Using DefaultAzureCredential as credential")
managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
credentials = AzureIdentityCredentialAdapter(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)
return AzureIdentityCredentialAdapter(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)

def get_token(self, *scopes, **kwargs) -> AccessToken:
"""Request an access token for `scopes`."""
credential = self.get_credential()
if isinstance(credential, AzureIdentityCredentialAdapter) or isinstance(
credential, AzureIdentityCredentialAdapter
):
raise AttributeError(
"The azure credential does not support get_token method. "
"Please set `use_azure_identity_object: True` in the connection extra field to use credential that support get_token method."
)
return credential.get_token(*scopes, **kwargs)
return self.sdk_client(
credentials=credentials,
subscription_id=subscription_id,
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
Connection = MagicMock() # type: ignore[misc]

MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
UTILS = "airflow.providers.microsoft.azure.utils"


class TestBaseAzureHook:
Expand Down Expand Up @@ -112,7 +111,7 @@ def test_get_conn_with_credentials(self, mock_spc, mocked_connection):
indirect=True,
)
@patch("azure.common.credentials.ServicePrincipalCredentials")
@patch(f"{MODULE}.AzureIdentityCredentialAdapter")
@patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter")
def test_get_conn_fallback_to_azure_identity_credential_adapter(
self,
mock_credential_adapter,
Expand All @@ -134,90 +133,3 @@ def test_get_conn_fallback_to_azure_identity_credential_adapter(
credentials=mock_credential,
subscription_id="subscription_id",
)

@patch(f"{MODULE}.ClientSecretCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
login="my_login",
password="my_password",
extra={"tenantId": "my_tenant", "use_azure_identity_object": True},
),
],
indirect=True,
)
def test_get_credential_with_client_secret(self, mock_spc, mocked_connection):
mock_spc.return_value = "foo-bar"
cred = AzureBaseHook().get_credential()

mock_spc.assert_called_once_with(
client_id=mocked_connection.login,
client_secret=mocked_connection.password,
tenant_id=mocked_connection.extra_dejson["tenantId"],
)
assert cred == "foo-bar"

@patch(f"{UTILS}.DefaultAzureCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
extra={"use_azure_identity_object": True},
),
],
indirect=True,
)
def test_get_credential_with_azure_default_credential(self, mock_spc, mocked_connection):
mock_spc.return_value = "foo-bar"
cred = AzureBaseHook().get_credential()

mock_spc.assert_called_once_with()
assert cred == "foo-bar"

@patch(f"{UTILS}.DefaultAzureCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
extra={
"managed_identity_client_id": "test_client_id",
"workload_identity_tenant_id": "test_tenant_id",
"use_azure_identity_object": True,
},
),
],
indirect=True,
)
def test_get_credential_with_azure_default_credential_with_extra(self, mock_spc, mocked_connection):
mock_spc.return_value = "foo-bar"
cred = AzureBaseHook().get_credential()

mock_spc.assert_called_once_with(
managed_identity_client_id=mocked_connection.extra_dejson.get("managed_identity_client_id"),
workload_identity_tenant_id=mocked_connection.extra_dejson.get("workload_identity_tenant_id"),
additionally_allowed_tenants=[mocked_connection.extra_dejson.get("workload_identity_tenant_id")],
)
assert cred == "foo-bar"

@patch(f"{UTILS}.DefaultAzureCredential")
@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_default",
extra={"use_azure_identity_object": True},
),
],
indirect=True,
)
def test_get_token_with_azure_default_credential(self, mock_spc, mocked_connection):
mock_spc.get_token.return_value = "new-token"
scope = "custom_scope"
token = AzureBaseHook().get_token(scope)

mock_spc.assert_called_once_with()
assert token == "new-token"
Loading