diff --git a/providers/postgres/pyproject.toml b/providers/postgres/pyproject.toml index 626d6c78b7eeb..c850d4b115acf 100644 --- a/providers/postgres/pyproject.toml +++ b/providers/postgres/pyproject.toml @@ -72,7 +72,7 @@ dependencies = [ "apache-airflow-providers-amazon>=2.6.0", ] "microsoft.azure" = [ - "apache-airflow-providers-microsoft-azure" + "apache-airflow-providers-microsoft-azure>=12.8.0" ] "openlineage" = [ "apache-airflow-providers-openlineage" diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index 1b678160591bb..f3e33c42ff4f5 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -522,19 +522,17 @@ def get_azure_iam_token(self, conn: Connection) -> tuple[str, str, int]: azure_conn = Connection.get(azure_conn_id) except AttributeError: azure_conn = Connection.get_connection_from_secrets(azure_conn_id) # type: ignore[attr-defined] - azure_base_hook: AzureBaseHook = azure_conn.get_hook() - scope = conf.get("postgres", "azure_oauth_scope", fallback=self.default_azure_oauth_scope) try: - token = azure_base_hook.get_token(scope).token - except AttributeError as e: - if e.name == "get_token" and e.obj == azure_base_hook: - raise AttributeError( - "'AzureBaseHook' object has no attribute 'get_token'. " - "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0", - name=e.name, - obj=e.obj, + azure_base_hook: AzureBaseHook = azure_conn.get_hook() + except TypeError as e: + if "required positional argument: 'sdk_client'" in str(e): + raise AirflowOptionalProviderFeatureException( + "Getting azure token is not supported by current version of 'AzureBaseHook'. " + "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0" ) from e raise + scope = conf.get("postgres", "azure_oauth_scope", fallback=self.default_azure_oauth_scope) + token = azure_base_hook.get_token(scope).token return cast("str", conn.login or azure_conn.login), token, conn.port or 5432 def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None: diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index 47fffe0b6bb0b..ee634c0f35542 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -27,7 +27,7 @@ import pytest import sqlalchemy -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException from airflow.models import Connection from airflow.providers.postgres.dialects.postgres import PostgresDialect from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook @@ -471,22 +471,22 @@ def test_get_conn_azure_iam(self, mocker, mock_connect): assert mock_db_token in self.db_hook.sqlalchemy_url - def test_get_azure_iam_token_expect_failure_on_get_token(self, mocker): - """Test get_azure_iam_token method gets token from provided connection id""" + def test_get_azure_iam_token_expect_failure_on_older_azure_provider_package(self, mocker): + class MockAzureBaseHookOldVersion: + """Simulate an old version of AzureBaseHook where sdk_client is required.""" - class MockAzureBaseHookWithoutGetToken: - def __init__(self): + def __init__(self, sdk_client, conn_id="azure_default"): pass azure_conn_id = "azure_test_conn" mock_connection_class = mocker.patch("airflow.providers.postgres.hooks.postgres.Connection") - mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken() + mock_connection_class.get.return_value.get_hook = MockAzureBaseHookOldVersion self.connection.extra = json.dumps({"iam": True, "azure_conn_id": azure_conn_id}) with pytest.raises( - AttributeError, + AirflowOptionalProviderFeatureException, match=( - "'AzureBaseHook' object has no attribute 'get_token'. " + "Getting azure token is not supported.*" "Please upgrade apache-airflow-providers-microsoft-azure>=" ), ): diff --git a/providers/snowflake/pyproject.toml b/providers/snowflake/pyproject.toml index 727b51bca9a20..91745ebe5bf6f 100644 --- a/providers/snowflake/pyproject.toml +++ b/providers/snowflake/pyproject.toml @@ -76,7 +76,7 @@ dependencies = [ # Any change in the dependencies is preserved when the file is regenerated [project.optional-dependencies] "microsoft.azure" = [ - "apache-airflow-providers-microsoft-azure" + "apache-airflow-providers-microsoft-azure>=12.8.0" ] "openlineage" = [ "apache-airflow-providers-openlineage>=2.3.0" diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 8cabe3fc4af12..a1aa025f72607 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -37,7 +37,7 @@ from sqlalchemy import create_engine from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException from airflow.providers.common.compat.sdk import Connection from airflow.providers.common.sql.hooks.handlers import return_single_query_results from airflow.providers.common.sql.hooks.sql import DbApiHook @@ -268,19 +268,17 @@ def get_azure_oauth_token(self, azure_conn_id: str) -> str: azure_conn = Connection.get(azure_conn_id) except AttributeError: azure_conn = Connection.get_connection_from_secrets(azure_conn_id) # type: ignore[attr-defined] - azure_base_hook: AzureBaseHook = azure_conn.get_hook() - scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope) try: - token = azure_base_hook.get_token(scope).token - except AttributeError as e: - if e.name == "get_token" and e.obj == azure_base_hook: - raise AttributeError( - "'AzureBaseHook' object has no attribute 'get_token'. " - "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0", - name=e.name, - obj=e.obj, + azure_base_hook: AzureBaseHook = azure_conn.get_hook() + except TypeError as e: + if "required positional argument: 'sdk_client'" in str(e): + raise AirflowOptionalProviderFeatureException( + "Getting azure token is not supported by current version of 'AzureBaseHook'. " + "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0" ) from e raise + scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope) + token = azure_base_hook.get_token(scope).token return token @cached_property diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index 6c7b0df885585..c6577b67de8d4 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -1110,22 +1110,22 @@ def test_get_azure_oauth_token(self, mocker): mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope) assert token == mock_azure_token - def test_get_azure_oauth_token_expect_failure_on_get_token(self, mocker): - """Test get_azure_oauth_token method gets token from provided connection id""" + def test_get_azure_oauth_token_expect_failure_on_older_azure_provider_package(self, mocker): + class MockAzureBaseHookOldVersion: + """Simulate an old version of AzureBaseHook where sdk_client is required.""" - class MockAzureBaseHookWithoutGetToken: - def __init__(self): + def __init__(self, sdk_client, conn_id="azure_default"): pass azure_conn_id = "azure_test_conn" mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection") - mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken() + mock_connection_class.get.return_value.get_hook = MockAzureBaseHookOldVersion hook = SnowflakeHook(snowflake_conn_id="mock_conn_id") with pytest.raises( - AttributeError, + AirflowOptionalProviderFeatureException, match=( - "'AzureBaseHook' object has no attribute 'get_token'. " + "Getting azure token is not supported.*" "Please upgrade apache-airflow-providers-microsoft-azure>=" ), ):