Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/postgres/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ dependencies = [
"apache-airflow-providers-amazon>=2.6.0",
]
"microsoft.azure" = [
"apache-airflow-providers-microsoft-azure"
"apache-airflow-providers-microsoft-azure>=12.8.0"
]
"openlineage" = [
"apache-airflow-providers-openlineage"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,19 +522,17 @@ def get_azure_iam_token(self, conn: Connection) -> tuple[str, str, int]:
azure_conn = Connection.get(azure_conn_id)
except AttributeError:
azure_conn = Connection.get_connection_from_secrets(azure_conn_id) # type: ignore[attr-defined]
azure_base_hook: AzureBaseHook = azure_conn.get_hook()
scope = conf.get("postgres", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
try:
token = azure_base_hook.get_token(scope).token
except AttributeError as e:
if e.name == "get_token" and e.obj == azure_base_hook:
raise AttributeError(
"'AzureBaseHook' object has no attribute 'get_token'. "
"Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0",
name=e.name,
obj=e.obj,
azure_base_hook: AzureBaseHook = azure_conn.get_hook()
except TypeError as e:
if "required positional argument: 'sdk_client'" in str(e):
raise AirflowOptionalProviderFeatureException(
"Getting azure token is not supported by current version of 'AzureBaseHook'. "
"Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0"
) from e
raise
scope = conf.get("postgres", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
token = azure_base_hook.get_token(scope).token
return cast("str", conn.login or azure_conn.login), token, conn.port or 5432

def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None:
Expand Down
16 changes: 8 additions & 8 deletions providers/postgres/tests/unit/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import pytest
import sqlalchemy

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException
from airflow.models import Connection
from airflow.providers.postgres.dialects.postgres import PostgresDialect
from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook
Expand Down Expand Up @@ -471,22 +471,22 @@ def test_get_conn_azure_iam(self, mocker, mock_connect):

assert mock_db_token in self.db_hook.sqlalchemy_url

def test_get_azure_iam_token_expect_failure_on_get_token(self, mocker):
"""Test get_azure_iam_token method gets token from provided connection id"""
def test_get_azure_iam_token_expect_failure_on_older_azure_provider_package(self, mocker):
class MockAzureBaseHookOldVersion:
"""Simulate an old version of AzureBaseHook where sdk_client is required."""

class MockAzureBaseHookWithoutGetToken:
def __init__(self):
def __init__(self, sdk_client, conn_id="azure_default"):
pass

azure_conn_id = "azure_test_conn"
mock_connection_class = mocker.patch("airflow.providers.postgres.hooks.postgres.Connection")
mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken()
mock_connection_class.get.return_value.get_hook = MockAzureBaseHookOldVersion

self.connection.extra = json.dumps({"iam": True, "azure_conn_id": azure_conn_id})
with pytest.raises(
AttributeError,
AirflowOptionalProviderFeatureException,
match=(
"'AzureBaseHook' object has no attribute 'get_token'. "
"Getting azure token is not supported.*"
"Please upgrade apache-airflow-providers-microsoft-azure>="
),
):
Expand Down
2 changes: 1 addition & 1 deletion providers/snowflake/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ dependencies = [
# Any change in the dependencies is preserved when the file is regenerated
[project.optional-dependencies]
"microsoft.azure" = [
"apache-airflow-providers-microsoft-azure"
"apache-airflow-providers-microsoft-azure>=12.8.0"
]
"openlineage" = [
"apache-airflow-providers-openlineage>=2.3.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from sqlalchemy import create_engine

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException
from airflow.providers.common.compat.sdk import Connection
from airflow.providers.common.sql.hooks.handlers import return_single_query_results
from airflow.providers.common.sql.hooks.sql import DbApiHook
Expand Down Expand Up @@ -268,19 +268,17 @@ def get_azure_oauth_token(self, azure_conn_id: str) -> str:
azure_conn = Connection.get(azure_conn_id)
except AttributeError:
azure_conn = Connection.get_connection_from_secrets(azure_conn_id) # type: ignore[attr-defined]
azure_base_hook: AzureBaseHook = azure_conn.get_hook()
scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
try:
token = azure_base_hook.get_token(scope).token
except AttributeError as e:
if e.name == "get_token" and e.obj == azure_base_hook:
raise AttributeError(
"'AzureBaseHook' object has no attribute 'get_token'. "
"Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0",
name=e.name,
obj=e.obj,
azure_base_hook: AzureBaseHook = azure_conn.get_hook()
except TypeError as e:
if "required positional argument: 'sdk_client'" in str(e):
raise AirflowOptionalProviderFeatureException(
"Getting azure token is not supported by current version of 'AzureBaseHook'. "
"Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0"
) from e
raise
scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
token = azure_base_hook.get_token(scope).token
return token

@cached_property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1110,22 +1110,22 @@ def test_get_azure_oauth_token(self, mocker):
mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope)
assert token == mock_azure_token

def test_get_azure_oauth_token_expect_failure_on_get_token(self, mocker):
"""Test get_azure_oauth_token method gets token from provided connection id"""
def test_get_azure_oauth_token_expect_failure_on_older_azure_provider_package(self, mocker):
class MockAzureBaseHookOldVersion:
"""Simulate an old version of AzureBaseHook where sdk_client is required."""

class MockAzureBaseHookWithoutGetToken:
def __init__(self):
def __init__(self, sdk_client, conn_id="azure_default"):
pass

azure_conn_id = "azure_test_conn"
mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection")
mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken()
mock_connection_class.get.return_value.get_hook = MockAzureBaseHookOldVersion

hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
with pytest.raises(
AttributeError,
AirflowOptionalProviderFeatureException,
match=(
"'AzureBaseHook' object has no attribute 'get_token'. "
"Getting azure token is not supported.*"
"Please upgrade apache-airflow-providers-microsoft-azure>="
),
):
Expand Down