diff --git a/providers/microsoft/azure/docs/connections/azure.rst b/providers/microsoft/azure/docs/connections/azure.rst index abe425067eb84..f8d111fd34a10 100644 --- a/providers/microsoft/azure/docs/connections/azure.rst +++ b/providers/microsoft/azure/docs/connections/azure.rst @@ -74,8 +74,6 @@ Extra (optional) It specifies the json that contains the authentication information. * ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_. * ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_. - * ``use_azure_identity_object``: If set to true, it will use credential of newer type: ClientSecretCredential or DefaultAzureCredential instead of ServicePrincipalCredentials or AzureIdentityCredentialAdapter. - These newer credentials support get_token method which can be used to generate OAuth token with custom scope. The entire extra column can be left out to fall back on DefaultAzureCredential_. diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py index 5693321676088..2a59234e0a4c7 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py @@ -16,25 +16,18 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict from azure.common.credentials import ServicePrincipalCredentials -from azure.identity import ClientSecretCredential, DefaultAzureCredential from airflow.exceptions import AirflowException from airflow.providers.microsoft.azure.utils import ( AzureIdentityCredentialAdapter, add_managed_identity_connection_widgets, - get_sync_default_azure_credential, ) from airflow.providers.microsoft.azure.version_compat import BaseHook -if TYPE_CHECKING: - from azure.core.credentials import AccessToken - - from airflow.sdk import Connection - class AzureBaseHook(BaseHook): """ @@ -92,7 +85,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: }, } - def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"): + def __init__(self, sdk_client: Any, conn_id: str = "azure_default"): self.sdk_client = sdk_client self.conn_id = conn_id super().__init__() @@ -103,9 +96,8 @@ def get_conn(self) -> Any: :return: the authenticated client. """ - if not self.sdk_client: - raise ValueError("`sdk_client` must be provided to AzureBaseHook to use `get_conn` method.") conn = self.get_connection(self.conn_id) + tenant = conn.extra_dejson.get("tenantId") subscription_id = conn.extra_dejson.get("subscriptionId") key_path = conn.extra_dejson.get("key_path") if key_path: @@ -119,74 +111,22 @@ def get_conn(self) -> Any: self.log.info("Getting connection using a JSON config.") return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json) - credentials = self.get_credential(conn=conn) - - return self.sdk_client( - credentials=credentials, - subscription_id=subscription_id, - ) - - def get_credential(self, *, conn: Connection | None = None) -> Any: - """ - Get Azure credential object for the connection. - - Azure Identity based credential object (``ClientSecretCredential``, ``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` method. - Older Credential objects (``ServicePrincipalCredentials``, ``AzureIdentityCredentialAdapter``) are supported for backward compatibility. - - :return: The Azure credential object - """ - if not conn: - conn = self.get_connection(self.conn_id) - tenant = conn.extra_dejson.get("tenantId") - credential: ( - ServicePrincipalCredentials - | AzureIdentityCredentialAdapter - | ClientSecretCredential - | DefaultAzureCredential - ) + credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter if all([conn.login, conn.password, tenant]): - credential = self._get_client_secret_credential(conn) - else: - credential = self._get_default_azure_credential(conn) - return credential - - def _get_client_secret_credential(self, conn: Connection): - self.log.info("Getting credentials using specific credentials and subscription_id.") - extra_dejson = conn.extra_dejson - tenant = extra_dejson.get("tenantId") - use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False) - if use_azure_identity_object: - return ClientSecretCredential( - client_id=conn.login, # type: ignore[arg-type] - client_secret=conn.password, # type: ignore[arg-type] - tenant_id=tenant, # type: ignore[arg-type] + self.log.info("Getting connection using specific credentials and subscription_id.") + credentials = ServicePrincipalCredentials( + client_id=conn.login, secret=conn.password, tenant=tenant ) - return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant) - - def _get_default_azure_credential(self, conn: Connection): - self.log.info("Using DefaultAzureCredential as credential") - extra_dejson = conn.extra_dejson - managed_identity_client_id = extra_dejson.get("managed_identity_client_id") - workload_identity_tenant_id = extra_dejson.get("workload_identity_tenant_id") - use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False) - if use_azure_identity_object: - return get_sync_default_azure_credential( + else: + self.log.info("Using DefaultAzureCredential as credential") + managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id") + workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id") + credentials = AzureIdentityCredentialAdapter( managed_identity_client_id=managed_identity_client_id, workload_identity_tenant_id=workload_identity_tenant_id, ) - return AzureIdentityCredentialAdapter( - managed_identity_client_id=managed_identity_client_id, - workload_identity_tenant_id=workload_identity_tenant_id, - ) - def get_token(self, *scopes, **kwargs) -> AccessToken: - """Request an access token for `scopes`.""" - credential = self.get_credential() - if isinstance(credential, AzureIdentityCredentialAdapter) or isinstance( - credential, AzureIdentityCredentialAdapter - ): - raise AttributeError( - "The azure credential does not support get_token method. " - "Please set `use_azure_identity_object: True` in the connection extra field to use credential that support get_token method." - ) - return credential.get_token(*scopes, **kwargs) + return self.sdk_client( + credentials=credentials, + subscription_id=subscription_id, + ) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py index 98df1a4c08fcc..89881eae16513 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py @@ -31,7 +31,6 @@ Connection = MagicMock() # type: ignore[misc] MODULE = "airflow.providers.microsoft.azure.hooks.base_azure" -UTILS = "airflow.providers.microsoft.azure.utils" class TestBaseAzureHook: @@ -112,7 +111,7 @@ def test_get_conn_with_credentials(self, mock_spc, mocked_connection): indirect=True, ) @patch("azure.common.credentials.ServicePrincipalCredentials") - @patch(f"{MODULE}.AzureIdentityCredentialAdapter") + @patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter") def test_get_conn_fallback_to_azure_identity_credential_adapter( self, mock_credential_adapter, @@ -134,90 +133,3 @@ def test_get_conn_fallback_to_azure_identity_credential_adapter( credentials=mock_credential, subscription_id="subscription_id", ) - - @patch(f"{MODULE}.ClientSecretCredential") - @pytest.mark.parametrize( - "mocked_connection", - [ - Connection( - conn_id="azure_default", - login="my_login", - password="my_password", - extra={"tenantId": "my_tenant", "use_azure_identity_object": True}, - ), - ], - indirect=True, - ) - def test_get_credential_with_client_secret(self, mock_spc, mocked_connection): - mock_spc.return_value = "foo-bar" - cred = AzureBaseHook().get_credential() - - mock_spc.assert_called_once_with( - client_id=mocked_connection.login, - client_secret=mocked_connection.password, - tenant_id=mocked_connection.extra_dejson["tenantId"], - ) - assert cred == "foo-bar" - - @patch(f"{UTILS}.DefaultAzureCredential") - @pytest.mark.parametrize( - "mocked_connection", - [ - Connection( - conn_id="azure_default", - extra={"use_azure_identity_object": True}, - ), - ], - indirect=True, - ) - def test_get_credential_with_azure_default_credential(self, mock_spc, mocked_connection): - mock_spc.return_value = "foo-bar" - cred = AzureBaseHook().get_credential() - - mock_spc.assert_called_once_with() - assert cred == "foo-bar" - - @patch(f"{UTILS}.DefaultAzureCredential") - @pytest.mark.parametrize( - "mocked_connection", - [ - Connection( - conn_id="azure_default", - extra={ - "managed_identity_client_id": "test_client_id", - "workload_identity_tenant_id": "test_tenant_id", - "use_azure_identity_object": True, - }, - ), - ], - indirect=True, - ) - def test_get_credential_with_azure_default_credential_with_extra(self, mock_spc, mocked_connection): - mock_spc.return_value = "foo-bar" - cred = AzureBaseHook().get_credential() - - mock_spc.assert_called_once_with( - managed_identity_client_id=mocked_connection.extra_dejson.get("managed_identity_client_id"), - workload_identity_tenant_id=mocked_connection.extra_dejson.get("workload_identity_tenant_id"), - additionally_allowed_tenants=[mocked_connection.extra_dejson.get("workload_identity_tenant_id")], - ) - assert cred == "foo-bar" - - @patch(f"{UTILS}.DefaultAzureCredential") - @pytest.mark.parametrize( - "mocked_connection", - [ - Connection( - conn_id="azure_default", - extra={"use_azure_identity_object": True}, - ), - ], - indirect=True, - ) - def test_get_token_with_azure_default_credential(self, mock_spc, mocked_connection): - mock_spc.get_token.return_value = "new-token" - scope = "custom_scope" - token = AzureBaseHook().get_token(scope) - - mock_spc.assert_called_once_with() - assert token == "new-token"