diff --git a/providers/microsoft/azure/docs/connections/azure.rst b/providers/microsoft/azure/docs/connections/azure.rst
index f8d111fd34a10..abe425067eb84 100644
--- a/providers/microsoft/azure/docs/connections/azure.rst
+++ b/providers/microsoft/azure/docs/connections/azure.rst
@@ -74,6 +74,8 @@ Extra (optional)
       It specifies the json that contains the authentication information.
     * ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_.
     * ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.
+    * ``use_azure_identity_object``: If set to true, the hook uses the newer credential types (ClientSecretCredential or DefaultAzureCredential) instead of ServicePrincipalCredentials or AzureIdentityCredentialAdapter.
+      These newer credentials support the ``get_token`` method, which can be used to generate an OAuth token with a custom scope.
 
 The entire extra column can be left out to fall back on DefaultAzureCredential_.
 
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
index 2a59234e0a4c7..5693321676088 100644
--- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -16,18 +16,25 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
 from azure.common.credentials import ServicePrincipalCredentials
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
 
 from airflow.exceptions import AirflowException
 from airflow.providers.microsoft.azure.utils import (
     AzureIdentityCredentialAdapter,
     add_managed_identity_connection_widgets,
+    get_sync_default_azure_credential,
 )
 from airflow.providers.microsoft.azure.version_compat import BaseHook
 
+if TYPE_CHECKING:
+    from azure.core.credentials import AccessToken
+
+    from airflow.sdk import Connection
+
 
 class AzureBaseHook(BaseHook):
     """
@@ -85,7 +92,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
             },
         }
 
-    def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
+    def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
         self.sdk_client = sdk_client
         self.conn_id = conn_id
         super().__init__()
@@ -96,8 +103,9 @@ def get_conn(self) -> Any:
 
         :return: the authenticated client.
         """
+        if not self.sdk_client:
+            raise ValueError("`sdk_client` must be provided to AzureBaseHook to use `get_conn` method.")
         conn = self.get_connection(self.conn_id)
-        tenant = conn.extra_dejson.get("tenantId")
         subscription_id = conn.extra_dejson.get("subscriptionId")
         key_path = conn.extra_dejson.get("key_path")
         if key_path:
@@ -111,22 +119,85 @@ def get_conn(self) -> Any:
                 self.log.info("Getting connection using a JSON config.")
                 return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
 
-        credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
+        credentials = self.get_credential(conn=conn)
+
+        return self.sdk_client(
+            credentials=credentials,
+            subscription_id=subscription_id,
+        )
+
+    def get_credential(self, *, conn: Connection | None = None) -> Any:
+        """
+        Get Azure credential object for the connection.
+
+        Azure Identity based credential object (``ClientSecretCredential``, ``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` method.
+        Older Credential objects (``ServicePrincipalCredentials``, ``AzureIdentityCredentialAdapter``) are supported for backward compatibility.
+
+        :param conn: Connection to read credentials from; looked up via ``conn_id`` when omitted.
+        :return: The Azure credential object
+        """
+        if not conn:
+            conn = self.get_connection(self.conn_id)
+        tenant = conn.extra_dejson.get("tenantId")
+        credential: (
+            ServicePrincipalCredentials
+            | AzureIdentityCredentialAdapter
+            | ClientSecretCredential
+            | DefaultAzureCredential
+        )
         if all([conn.login, conn.password, tenant]):
-            self.log.info("Getting connection using specific credentials and subscription_id.")
-            credentials = ServicePrincipalCredentials(
-                client_id=conn.login, secret=conn.password, tenant=tenant
-            )
+            credential = self._get_client_secret_credential(conn)
         else:
-            self.log.info("Using DefaultAzureCredential as credential")
-            managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
-            workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
-            credentials = AzureIdentityCredentialAdapter(
+            credential = self._get_default_azure_credential(conn)
+        return credential
+
+    def _get_client_secret_credential(self, conn: Connection):
+        """Return the credential built from the connection's login, password and ``tenantId``."""
+        self.log.info("Getting credentials using specific credentials and subscription_id.")
+        extra_dejson = conn.extra_dejson
+        tenant = extra_dejson.get("tenantId")
+        use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
+        if use_azure_identity_object:
+            return ClientSecretCredential(
+                client_id=conn.login,  # type: ignore[arg-type]
+                client_secret=conn.password,  # type: ignore[arg-type]
+                tenant_id=tenant,  # type: ignore[arg-type]
+            )
+        return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant)
+
+    def _get_default_azure_credential(self, conn: Connection):
+        """Return the fallback credential used when login, password or ``tenantId`` is missing."""
+        self.log.info("Using DefaultAzureCredential as credential")
+        extra_dejson = conn.extra_dejson
+        managed_identity_client_id = extra_dejson.get("managed_identity_client_id")
+        workload_identity_tenant_id = extra_dejson.get("workload_identity_tenant_id")
+        use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
+        if use_azure_identity_object:
+            return get_sync_default_azure_credential(
                 managed_identity_client_id=managed_identity_client_id,
                 workload_identity_tenant_id=workload_identity_tenant_id,
             )
-
-        return self.sdk_client(
-            credentials=credentials,
-            subscription_id=subscription_id,
+        return AzureIdentityCredentialAdapter(
+            managed_identity_client_id=managed_identity_client_id,
+            workload_identity_tenant_id=workload_identity_tenant_id,
         )
+
+    def get_token(self, *scopes, **kwargs) -> AccessToken:
+        """
+        Request an access token for `scopes`.
+
+        :param scopes: Scopes forwarded to the credential's ``get_token``.
+        :param kwargs: Keyword arguments forwarded to the credential's ``get_token``.
+        :return: Whatever the credential's ``get_token`` returns.
+        :raises AttributeError: If the connection resolves to ``ServicePrincipalCredentials``
+            or ``AzureIdentityCredentialAdapter`` (``use_azure_identity_object`` not enabled).
+        """
+        credential = self.get_credential()
+        if isinstance(credential, AzureIdentityCredentialAdapter) or isinstance(
+            credential, ServicePrincipalCredentials
+        ):
+            raise AttributeError(
+                "The azure credential does not support get_token method. "
+                "Please set `use_azure_identity_object: True` in the connection extra field to use credential that support get_token method."
+            )
+        return credential.get_token(*scopes, **kwargs)
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
index 89881eae16513..98df1a4c08fcc 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
@@ -31,6 +31,7 @@
     Connection = MagicMock()  # type: ignore[misc]
 
 MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
+UTILS = "airflow.providers.microsoft.azure.utils"
 
 
 class TestBaseAzureHook:
@@ -111,7 +112,7 @@ def test_get_conn_with_credentials(self, mock_spc, mocked_connection):
         indirect=True,
     )
     @patch("azure.common.credentials.ServicePrincipalCredentials")
-    @patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter")
+    @patch(f"{MODULE}.AzureIdentityCredentialAdapter")
     def test_get_conn_fallback_to_azure_identity_credential_adapter(
         self,
         mock_credential_adapter,
@@ -133,3 +134,90 @@ def test_get_conn_fallback_to_azure_identity_credential_adapter(
             credentials=mock_credential,
             subscription_id="subscription_id",
         )
+
+    @patch(f"{MODULE}.ClientSecretCredential")
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_default",
+                login="my_login",
+                password="my_password",
+                extra={"tenantId": "my_tenant", "use_azure_identity_object": True},
+            ),
+        ],
+        indirect=True,
+    )
+    def test_get_credential_with_client_secret(self, mock_spc, mocked_connection):
+        mock_spc.return_value = "foo-bar"
+        cred = AzureBaseHook().get_credential()
+
+        mock_spc.assert_called_once_with(
+            client_id=mocked_connection.login,
+            client_secret=mocked_connection.password,
+            tenant_id=mocked_connection.extra_dejson["tenantId"],
+        )
+        assert cred == "foo-bar"
+
+    @patch(f"{UTILS}.DefaultAzureCredential")
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_default",
+                extra={"use_azure_identity_object": True},
+            ),
+        ],
+        indirect=True,
+    )
+    def test_get_credential_with_azure_default_credential(self, mock_spc, mocked_connection):
+        mock_spc.return_value = "foo-bar"
+        cred = AzureBaseHook().get_credential()
+
+        mock_spc.assert_called_once_with()
+        assert cred == "foo-bar"
+
+    @patch(f"{UTILS}.DefaultAzureCredential")
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_default",
+                extra={
+                    "managed_identity_client_id": "test_client_id",
+                    "workload_identity_tenant_id": "test_tenant_id",
+                    "use_azure_identity_object": True,
+                },
+            ),
+        ],
+        indirect=True,
+    )
+    def test_get_credential_with_azure_default_credential_with_extra(self, mock_spc, mocked_connection):
+        mock_spc.return_value = "foo-bar"
+        cred = AzureBaseHook().get_credential()
+
+        mock_spc.assert_called_once_with(
+            managed_identity_client_id=mocked_connection.extra_dejson.get("managed_identity_client_id"),
+            workload_identity_tenant_id=mocked_connection.extra_dejson.get("workload_identity_tenant_id"),
+            additionally_allowed_tenants=[mocked_connection.extra_dejson.get("workload_identity_tenant_id")],
+        )
+        assert cred == "foo-bar"
+
+    @patch(f"{UTILS}.DefaultAzureCredential")
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_default",
+                extra={"use_azure_identity_object": True},
+            ),
+        ],
+        indirect=True,
+    )
+    def test_get_token_with_azure_default_credential(self, mock_spc, mocked_connection):
+        mock_spc.return_value.get_token.return_value = "new-token"
+        scope = "custom_scope"
+        token = AzureBaseHook().get_token(scope)
+
+        mock_spc.assert_called_once_with()
+        assert token == "new-token"