diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 372fcbfe4a13c..df1b584ff9dc7 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -21,6 +21,7 @@ import os from collections.abc import Callable, Iterable, Mapping from contextlib import closing, contextmanager +from datetime import datetime, timedelta from functools import cached_property from io import StringIO from pathlib import Path @@ -41,9 +42,14 @@ from airflow.providers.common.sql.hooks.handlers import return_single_query_results from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri +from airflow.utils import timezone from airflow.utils.strings import to_boolean +OAUTH_REQUEST_TIMEOUT = 30 # seconds, avoid hanging tasks on token request +OAUTH_EXPIRY_BUFFER = 30 T = TypeVar("T") + + if TYPE_CHECKING: from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import DatabaseInfo @@ -173,6 +179,11 @@ def __init__(self, *args, **kwargs) -> None: self.client_store_temporary_credential = kwargs.pop("client_store_temporary_credential", None) self.query_ids: list[str] = [] + # Access token and expiration timestamp persisted + # to handle premature expiry. + self._oauth_token: str | None = None + self._oauth_token_expires_at: datetime | None = None + def _get_field(self, extra_dict, field_name): backcompat_prefix = "extra__snowflake__" backcompat_key = f"{backcompat_prefix}{field_name}" @@ -198,7 +209,7 @@ def _get_field(self, extra_dict, field_name): @property def account_identifier(self) -> str: """Get snowflake account identifier.""" - conn_config = self._get_conn_params + conn_config = self._get_conn_params() account_identifier = f"https://{conn_config['account']}" if conn_config["region"]: @@ -214,46 +225,15 @@ def get_oauth_token( ) -> str: """Generate temporary OAuth access token using refresh token in connection details.""" if conn_config is None: - conn_config = self._get_conn_params + conn_config = self._get_static_conn_params - url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request" + if token_endpoint is None: + token_endpoint = conn_config.get("token_endpoint") - data = { - "grant_type": grant_type, - "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"), - } - - scope = conn_config.get("scope") - - if scope: - data["scope"] = scope - - if grant_type == "refresh_token": - data |= { - "refresh_token": conn_config["refresh_token"], - } - elif grant_type == "client_credentials": - pass # no setup necessary for client credentials grant. - else: - raise ValueError(f"Unknown grant_type: {grant_type}") - - response = requests.post( - url, - data=data, - headers={ - "Content-Type": "application/x-www-form-urlencoded", - }, - auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type] + return self._get_valid_oauth_token( + conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type ) - try: - response.raise_for_status() - except requests.exceptions.HTTPError as e: # pragma: no cover - msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" - raise AirflowException(msg) - token = response.json()["access_token"] - return token - def get_azure_oauth_token(self, azure_conn_id: str) -> str: """ Generate OAuth access token using Azure connection id. @@ -286,12 +266,42 @@ def get_azure_oauth_token(self, azure_conn_id: str) -> str: token = azure_base_hook.get_token(scope).token return token - @cached_property def _get_conn_params(self) -> dict[str, str | None]: """ Fetch connection params as a dict. - This is used in ``get_uri()`` and ``get_connection()``. + This is used in ``get_uri()`` and ``get_conn()``. + """ + static_config = self._get_static_conn_params + conn_config = dict(static_config) + + if conn_config.get("authenticator") == "oauth": + azure_conn_id = conn_config.get("azure_conn_id") + if azure_conn_id: + conn_config["token"] = self.get_azure_oauth_token(azure_conn_id) + else: + grant_type = conn_config.get("grant_type") + if not grant_type: + raise ValueError("Grant_type not provided") + conn_config["token"] = self._get_valid_oauth_token( + conn_config=conn_config, + token_endpoint=conn_config.get("token_endpoint"), + grant_type=grant_type, + ) + + conn_config.pop("login", None) + conn_config.pop("user", None) + conn_config.pop("password", None) + return conn_config + + @cached_property + def _get_static_conn_params(self) -> dict[str, str | None]: + """ + Return static Snowflake connection parameters. + + These parameters are cached for the lifetime of the hook and exclude + time-sensitive values such as OAuth access tokens. This is used in + ``_get_valid_oauth_token()`` and ``get_conn_params()``. """ conn = self.get_connection(self.get_conn_id()) extra_dict = conn.extra_dejson @@ -388,25 +398,21 @@ def _get_conn_params(self) -> dict[str, str | None]: conn_config["refresh_token"] = refresh_token conn_config["authenticator"] = "oauth" + grant_type = self._get_field(extra_dict, "grant_type") or "" + if grant_type: + conn_config["grant_type"] = grant_type + elif refresh_token: + conn_config["grant_type"] = "refresh_token" + if conn_config.get("authenticator") == "oauth": - if extra_dict.get("azure_conn_id"): - conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"]) - else: - token_endpoint = self._get_field(extra_dict, "token_endpoint") or "" + conn_config["azure_conn_id"] = extra_dict.get("azure_conn_id") + + if not extra_dict.get("azure_conn_id"): + conn_config["token_endpoint"] = self._get_field(extra_dict, "token_endpoint") or "" conn_config["scope"] = self._get_field(extra_dict, "scope") conn_config["client_id"] = conn.login conn_config["client_secret"] = conn.password - conn_config["token"] = self.get_oauth_token( - conn_config=conn_config, - token_endpoint=token_endpoint, - grant_type=extra_dict.get("grant_type", "refresh_token"), - ) - - conn_config.pop("login", None) - conn_config.pop("user", None) - conn_config.pop("password", None) - # configure custom target hostname and port, if specified snowflake_host = extra_dict.get("host") snowflake_port = extra_dict.get("port") @@ -423,9 +429,80 @@ def _get_conn_params(self) -> dict[str, str | None]: return conn_config + def _get_valid_oauth_token( + self, + *, + conn_config: dict[str, Any], + token_endpoint: str | None, + grant_type: str, + ) -> str: + """ + Return a valid OAuth access token. + + This also updates the internal OAuth token cache and token expiry timestamp. + """ + # Check validity using current timestamp. + now = timezone.utcnow() + + if ( + self._oauth_token is not None + and self._oauth_token_expires_at is not None + and now < self._oauth_token_expires_at + ): + return self._oauth_token + + url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request" + + data = { + "grant_type": grant_type, + "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"), + } + + scope = conn_config.get("scope") + + if scope: + data["scope"] = scope + + if grant_type == "refresh_token": + data |= { + "refresh_token": conn_config["refresh_token"], + } + elif grant_type == "client_credentials": + pass # no setup necessary for client credentials grant. + else: + raise ValueError(f"Unknown grant_type: {grant_type}") + + response = requests.post( + url, + data=data, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + }, + auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type] + timeout=OAUTH_REQUEST_TIMEOUT, + ) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: # pragma: no cover + msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" + raise AirflowException(msg) + + token = response.json()["access_token"] + expires_in = int(response.json()["expires_in"]) + + # Capture issue timestamp after access token is retrieved. + issued_at = timezone.utcnow() + + # Persist retrieved access token and expiration timestamp. + self._oauth_token = token + self._oauth_token_expires_at = issued_at + timedelta(seconds=max(expires_in - OAUTH_EXPIRY_BUFFER, 0)) + + return token + def get_uri(self) -> str: """Override DbApiHook get_uri method for get_sqlalchemy_engine().""" - conn_params = self._get_conn_params + conn_params = self._get_conn_params() return self._conn_params_to_sqlalchemy_uri(conn_params) def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: @@ -449,7 +526,7 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: def get_conn(self) -> SnowflakeConnection: """Return a snowflake.connection object.""" - conn_config = self._get_conn_params + conn_config = self._get_conn_params() conn = connector.connect(**conn_config) return conn @@ -461,7 +538,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): :return: the created engine. """ engine_kwargs = engine_kwargs or {} - conn_params = self._get_conn_params + conn_params = self._get_conn_params() if "insecure_mode" in conn_params: engine_kwargs.setdefault("connect_args", {}) engine_kwargs["connect_args"]["insecure_mode"] = True @@ -488,7 +565,7 @@ def get_snowpark_session(self): from airflow import __version__ as airflow_version from airflow.providers.snowflake import __version__ as provider_version - conn_config = self._get_conn_params + conn_config = self._get_conn_params() session = Session.builder.configs(conn_config).create() # add query tag for observability session.update_query_tag( @@ -654,7 +731,7 @@ def get_openlineage_database_dialect(self, _) -> str: return "snowflake" def get_openlineage_default_schema(self) -> str | None: - return self._get_conn_params["schema"] + return self._get_conn_params()["schema"] def _get_openlineage_authority(self, _) -> str | None: uri = fix_snowflake_sqlalchemy_uri(self.get_uri()) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 1d9e6f235bd2a..d682cb4aa5b85 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -163,7 +163,7 @@ def execute_query( the statement with these specified values. """ self.query_ids = [] - conn_config = self._get_conn_params + conn_config = self._get_conn_params() req_id = uuid.uuid4() url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements" @@ -206,7 +206,7 @@ def execute_query( def get_headers(self) -> dict[str, Any]: """Form auth headers based on either OAuth token or JWT token from private key.""" - conn_config = self._get_conn_params + conn_config = self._get_conn_params() # Use OAuth if refresh_token and client_id and client_secret are provided if all( diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index d92ca12752d91..b44703c546937 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -21,6 +21,7 @@ import json import sys from copy import deepcopy +from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any from unittest import mock from unittest.mock import Mock, PropertyMock @@ -33,6 +34,7 @@ from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.models import Connection from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.utils import timezone if TYPE_CHECKING: from pathlib import Path @@ -367,7 +369,7 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri( ): with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri - assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == expected_conn_params + assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params def test_get_conn_params_should_support_private_auth_in_connection( self, base64_encoded_encrypted_private_key: Path @@ -385,7 +387,7 @@ def test_get_conn_params_should_support_private_auth_in_connection( }, } with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): - assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params + assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() @pytest.mark.parametrize("include_params", [True, False]) def test_hook_param_beats_extra(self, include_params): @@ -408,7 +410,7 @@ def test_hook_param_beats_extra(self, include_params): assert hook_params != extras assert SnowflakeHook( snowflake_conn_id="test_conn", **(hook_params if include_params else {}) - )._get_conn_params == { + )._get_conn_params() == { "user": None, "password": "", "application": "AIRFLOW", @@ -437,7 +439,7 @@ def test_extra_short_beats_long(self, include_unprefixed): ).get_uri(), ): assert list(extras.values()) != list(extras_prefixed.values()) - assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == { + assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == { "user": None, "password": "", "application": "AIRFLOW", @@ -463,7 +465,7 @@ def test_get_conn_params_should_support_private_auth_with_encrypted_key( }, } with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): - assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params + assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() def test_get_conn_params_should_support_private_auth_with_unencrypted_key( self, unencrypted_temporary_private_key @@ -481,16 +483,16 @@ def test_get_conn_params_should_support_private_auth_with_unencrypted_key( }, } with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): - assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params + assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() connection_kwargs["password"] = "" with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): - assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params + assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() connection_kwargs["password"] = _PASSWORD with ( mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), pytest.raises(TypeError, match="Password was given but private key is not encrypted."), ): - SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params + SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() def test_get_conn_params_should_fail_on_invalid_key(self): connection_kwargs = { @@ -512,10 +514,7 @@ def test_get_conn_params_should_fail_on_invalid_key(self): SnowflakeHook(snowflake_conn_id="test_conn").get_conn() @mock.patch("requests.post") - @mock.patch( - "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", - new_callable=PropertyMock, - ) + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params") def test_get_conn_params_should_support_oauth(self, mock_get_conn_params, requests_post): requests_post.return_value = Mock( status_code=200, @@ -544,7 +543,7 @@ def test_get_conn_params_should_support_oauth(self, mock_get_conn_params, reques mock_get_conn_params.return_value = connection_kwargs with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): hook = SnowflakeHook(snowflake_conn_id="test_conn") - conn_params = hook._get_conn_params + conn_params = hook._get_conn_params() conn_params_keys = conn_params.keys() conn_params_extra = conn_params.get("extra", {}) @@ -563,7 +562,6 @@ def test_get_conn_params_should_support_oauth(self, mock_get_conn_params, reques @mock.patch("requests.post") @mock.patch( "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", - new_callable=PropertyMock, ) def test_get_conn_params_should_support_oauth_with_token_endpoint( self, mock_get_conn_params, requests_post @@ -596,7 +594,7 @@ def test_get_conn_params_should_support_oauth_with_token_endpoint( mock_get_conn_params.return_value = connection_kwargs with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): hook = SnowflakeHook(snowflake_conn_id="test_conn") - conn_params = hook._get_conn_params + conn_params = hook._get_conn_params() conn_params_keys = conn_params.keys() conn_params_extra = conn_params.get("extra", {}) @@ -616,7 +614,6 @@ def test_get_conn_params_should_support_oauth_with_token_endpoint( @mock.patch("requests.post") @mock.patch( "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", - new_callable=PropertyMock, ) def test_get_conn_params_should_support_oauth_with_client_credentials( self, mock_get_conn_params, requests_post @@ -649,7 +646,7 @@ def test_get_conn_params_should_support_oauth_with_client_credentials( mock_get_conn_params.return_value = connection_kwargs with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): hook = SnowflakeHook(snowflake_conn_id="test_conn") - conn_params = hook._get_conn_params + conn_params = hook._get_conn_params() conn_params_keys = conn_params.keys() conn_params_extra = conn_params.get("extra", {}) @@ -686,7 +683,7 @@ def test_get_conn_params_should_support_oauth_with_azure_conn_id(self, mocker): with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): hook = SnowflakeHook(snowflake_conn_id="test_conn") - conn_params = hook._get_conn_params + conn_params = hook._get_conn_params() # Check AzureBaseHook initialization and get_token call args mock_connection_class.get.assert_called_once_with(azure_conn_id) @@ -737,7 +734,7 @@ def test_get_conn_params_include_scope(self, mock_requests_post): {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()}, ): hook = SnowflakeHook(snowflake_conn_id="test_conn") - params = hook._get_conn_params + params = hook._get_conn_params() mock_requests_post.assert_called_once() assert "scope" in params assert params["scope"] == "default" @@ -749,7 +746,8 @@ def test_should_add_partner_info(self): AIRFLOW_SNOWFLAKE_PARTNER="PARTNER_NAME", ): assert ( - SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params["application"] == "PARTNER_NAME" + SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()["application"] + == "PARTNER_NAME" ) def test_get_conn_should_call_connect(self): @@ -761,7 +759,7 @@ def test_get_conn_should_call_connect(self): ): hook = SnowflakeHook(snowflake_conn_id="test_conn") conn = hook.get_conn() - mock_connector.connect.assert_called_once_with(**hook._get_conn_params) + mock_connector.connect.assert_called_once_with(**hook._get_conn_params()) assert mock_connector.connect.return_value == conn def test_get_sqlalchemy_engine_should_support_pass_auth(self): @@ -876,7 +874,7 @@ def test_hook_parameters_should_take_precedence(self): authenticator="TEST_AUTH", session_parameters={"AA": "AAA"}, ) - assert hook._get_conn_params == { + assert hook._get_conn_params() == { "account": "TEST_ACCOUNT", "application": "AIRFLOW", "authenticator": "TEST_AUTH", @@ -1068,7 +1066,7 @@ def test_get_snowpark_session(self, mock_session_builder): session = hook.get_snowpark_session() assert session == mock_session - mock_session_builder.configs.assert_called_once_with(hook._get_conn_params) + mock_session_builder.configs.assert_called_once_with(hook._get_conn_params()) # Verify that update_query_tag was called with the expected tag dictionary mock_session.update_query_tag.assert_called_once_with( @@ -1101,6 +1099,7 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): }, headers={"Content-Type": "application/x-www-form-urlencoded"}, auth=basic_auth, + timeout=30, ) @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth") @@ -1129,6 +1128,7 @@ def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, requests_pos }, headers={"Content-Type": "application/x-www-form-urlencoded"}, auth=basic_auth, + timeout=30, ) def test_get_azure_oauth_token(self, mocker): @@ -1181,7 +1181,7 @@ def test_get_oauth_token_with_scope(self, mock_requests_post): mock_requests_post.return_value = Mock( status_code=200, - json=lambda: {"access_token": "dummy_token"}, + json=lambda: {"access_token": "dummy_token", "expires_in": 600}, ) connection_kwargs = { @@ -1213,12 +1213,12 @@ def test_get_oauth_token_with_scope(self, mock_requests_post): @mock.patch("requests.post") def test_get_oauth_token_without_scope(self, mock_requests_post): """ - Verify that `get_oauth_token` returns an access token and sends `scope=None` + Verify that `get_oauth_token` returns an access token` when no scope is defined in the connection extras. """ mock_requests_post.return_value = Mock( status_code=200, - json=lambda: {"access_token": "dummy_token"}, + json=lambda: {"access_token": "dummy_token", "expires_in": 600}, ) connection_kwargs = { @@ -1241,3 +1241,71 @@ def test_get_oauth_token_without_scope(self, mock_requests_post): assert "scope" not in called_data assert called_data["grant_type"] == "client_credentials" + + @mock.patch("requests.post") + @mock.patch("airflow.providers.snowflake.hooks.snowflake.timezone.utcnow") + def test_oauth_token_refresh_after_expiry(self, mock_timezone_utcnow, mock_requests_post): + """ + Ensure OAuth tokens are refreshed after expiry for a reused SnowflakeHook, + without mutating static connection parameters. + """ + + t0 = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc) + + # _get_valid_oauth_token calls utcnow twice per refresh: + # 1) validity check + # 2) issued_at + mock_timezone_utcnow.side_effect = [ + t0, + t0, + t0 + timedelta(minutes=11), + t0 + timedelta(minutes=11), + ] + + mock_requests_post.side_effect = [ + Mock( + status_code=200, + json=lambda: {"access_token": "token1", "expires_in": 600}, + raise_for_status=lambda: None, + ), + Mock( + status_code=200, + json=lambda: {"access_token": "token2", "expires_in": 600}, + raise_for_status=lambda: None, + ), + ] + + connection_kwargs = { + **BASE_CONNECTION_KWARGS, + "login": "client_id", + "password": "client_secret", + "extra": { + "account": "airflow", + "authenticator": "oauth", + "grant_type": "refresh_token", + "refresh_token": "secret_token", + }, + } + + with mock.patch.dict( + "os.environ", + {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()}, + ): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + # First resolution (initial token) + conn_params_1 = hook._get_conn_params() + + # Second resolution (after token expiry) + conn_params_2 = hook._get_conn_params() + + # Token must be refreshed + assert conn_params_1["token"] == "token1" + assert conn_params_2["token"] == "token2" + + # Static params must not change + assert {k: v for k, v in conn_params_1.items() if k != "token"} == { + k: v for k, v in conn_params_2.items() if k != "token" + } + + # Ensure refresh actually happened + assert mock_requests_post.call_count == 2 diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index e47b835dd7310..c970041bc99a4 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -21,7 +21,7 @@ import uuid from typing import TYPE_CHECKING, Any from unittest import mock -from unittest.mock import AsyncMock, PropertyMock, call +from unittest.mock import AsyncMock, call import aiohttp import pytest @@ -225,7 +225,7 @@ class TestSnowflakeSqlApiHook: (SQL_MULTIPLE_STMTS, 4, {"statementHandles": ["uuid", "uuid1"]}, ["uuid", "uuid1"]), ], ) - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_execute_query( self, @@ -249,7 +249,7 @@ def test_execute_query( query_ids = hook.execute_query(sql, statement_count) assert query_ids == expected_query_ids - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_execute_query_multiple_times_give_fresh_query_ids_each_time( self, mock_get_header, mock_conn_param, mock_requests @@ -289,7 +289,7 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time( ("sql", "statement_count", "expected_response", "expected_query_ids"), [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])], ) - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_execute_query_exception_without_statement_handle( self, @@ -319,7 +319,7 @@ def test_execute_query_exception_without_statement_handle( (SQL_MULTIPLE_STMTS, 4, {"1": {"type": "FIXED", "value": "123"}}), ], ) - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_execute_query_bindings_warning( self, @@ -385,7 +385,7 @@ def test_check_query_output_exception( with pytest.raises(requests.exceptions.HTTPError): hook.check_query_output(query_ids) - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_get_request_url_header_params(self, mock_get_header, mock_conn_param): """Test get_request_url_header_params by mocking _get_conn_params and get_headers""" @@ -397,7 +397,7 @@ def test_get_request_url_header_params(self, mock_get_header, mock_conn_param): assert url == "https://airflow.af_region.snowflakecomputing.com/api/v2/statements/uuid" @mock.patch(f"{HOOK_PATH}.get_private_key") - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch("airflow.providers.snowflake.utils.sql_api_generate_jwt.JWTGenerator.get_token") def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_param, mock_private_key): """Test get_headers method by mocking get_private_key and _get_conn_params method""" @@ -408,7 +408,7 @@ def test_get_headers_should_support_private_key(self, mock_get_token, mock_conn_ assert result == HEADERS @mock.patch(f"{HOOK_PATH}.get_oauth_token") - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_token): """Test get_headers method by mocking get_oauth_token and _get_conn_params method""" mock_conn_param.return_value = CONN_PARAMS_OAUTH @@ -419,7 +419,7 @@ def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_toke @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth") @mock.patch("requests.post") - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): """Test get_oauth_token method makes the right http request""" basic_auth = {"Authorization": "Basic usernamepassword"} @@ -438,6 +438,7 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): }, headers={"Content-Type": "application/x-www-form-urlencoded"}, auth=basic_auth, + timeout=30, ) @pytest.fixture @@ -804,7 +805,7 @@ def test_hook_parameter_propagation(self, hook_params): ], ) @mock.patch("uuid.uuid4") - @mock.patch(f"{HOOK_PATH}._get_conn_params", new_callable=PropertyMock) + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_proper_parametrization_of_execute_query_api_request( self,