diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 2fb6c6b14e496..d4b4f16e5a70d 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -189,9 +189,24 @@ def _get_field(self, extra_dict, field_name): return extra_dict[field_name] or None return extra_dict.get(backcompat_key) or None - def get_oauth_token(self, conn_config: dict) -> str: + @property + def account_identifier(self) -> str: + """Returns snowflake account identifier.""" + conn_config = self._get_conn_params + account_identifier = f"https://{conn_config['account']}" + + if conn_config["region"]: + account_identifier += f".{conn_config['region']}" + + return account_identifier + + def get_oauth_token(self, conn_config: dict | None = None) -> str: """Generate temporary OAuth access token using refresh token in connection details.""" - url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request" + if conn_config is None: + conn_config = self._get_conn_params + + url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request" + data = { "grant_type": "refresh_token", "refresh_token": conn_config["refresh_token"], diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index c5ee2989cc9ae..c7b9765c60968 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -18,6 +18,7 @@ import base64 import uuid +import warnings from datetime import timedelta from pathlib import Path from typing import Any @@ -26,9 +27,8 @@ import requests from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from requests.auth import HTTPBasicAuth -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator @@ -84,17 +84,6 @@ def __init__( super().__init__(snowflake_conn_id, *args, **kwargs) self.private_key: Any = None - @property - def account_identifier(self) -> str: - """Returns snowflake account identifier.""" - conn_config = self._get_conn_params - account_identifier = f"https://{conn_config['account']}" - - if conn_config["region"]: - account_identifier += f".{conn_config['region']}" - - return account_identifier - def get_private_key(self) -> None: """Get the private key from snowflake connection.""" conn = self.get_connection(self.snowflake_conn_id) @@ -233,29 +222,14 @@ def get_headers(self) -> dict[str, Any]: } return headers - def get_oauth_token(self, conn_config: dict[str, Any]) -> str: + def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> str: """Generate temporary OAuth access token using refresh token in connection details.""" - url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request" - data = { - "grant_type": "refresh_token", - "refresh_token": conn_config["refresh_token"], - "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"), - } - response = requests.post( - url, - data=data, - headers={ - "Content-Type": "application/x-www-form-urlencoded", - }, - auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type] + warnings.warn( + "This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ", + AirflowProviderDeprecationWarning, + stacklevel=2, ) - - try: - response.raise_for_status() - except requests.exceptions.HTTPError as e: # pragma: no cover - msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" - raise AirflowException(msg) - return response.json()["access_token"] + return super().get_oauth_token(conn_config=conn_config) def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]: """ diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index 30fbd50289b6c..f3278cd990c9b 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -511,7 +511,11 @@ def test_get_conn_params_should_fail_on_invalid_key(self): SnowflakeHook(snowflake_conn_id="test_conn").get_conn() @mock.patch("requests.post") - def test_get_conn_params_should_support_oauth(self, requests_post): + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", + new_callable=PropertyMock, + ) + def test_get_conn_params_should_support_oauth(self, mock_get_conn_params, requests_post): requests_post.return_value = Mock( status_code=200, json=lambda: { @@ -533,15 +537,27 @@ def test_get_conn_params_should_support_oauth(self, requests_post): "region": "af_region", "role": "af_role", "refresh_token": "secrettoken", + "authenticator": "oauth", }, } + mock_get_conn_params.return_value = connection_kwargs with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): hook = SnowflakeHook(snowflake_conn_id="test_conn") - assert "user" not in hook._get_conn_params - assert "password" not in hook._get_conn_params - assert "refresh_token" in hook._get_conn_params - assert "token" in hook._get_conn_params - assert hook._get_conn_params["authenticator"] == "oauth" + conn_params = hook._get_conn_params + + conn_params_keys = conn_params.keys() + conn_params_extra = conn_params.get("extra", {}) + conn_params_extra_keys = conn_params_extra.keys() + + assert "authenticator" in conn_params_extra_keys + assert conn_params_extra["authenticator"] == "oauth" + + assert "user" not in conn_params_keys + assert "password" in conn_params_keys + assert "refresh_token" in conn_params_extra_keys + # Mandatory fields to generate account_identifier `https://.` + assert "region" in conn_params_extra_keys + assert "account" in conn_params_extra_keys def test_should_add_partner_info(self): with mock.patch.dict( @@ -885,19 +901,19 @@ def test_get_snowpark_session(self, mock_session_builder): ) def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): """Test get_oauth_token method makes the right http request""" - BASIC_AUTH = {"Authorization": "Basic usernamepassword"} + basic_auth = {"Authorization": "Basic usernamepassword"} mock_conn_param.return_value = CONN_PARAMS_OAUTH requests_post.return_value.status_code = 200 - mock_auth.return_value = BASIC_AUTH + mock_auth.return_value = basic_auth hook = SnowflakeHook(snowflake_conn_id="mock_conn_id") hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH) requests_post.assert_called_once_with( - f"https://{CONN_PARAMS_OAUTH['account']}.snowflakecomputing.com/oauth/token-request", + f"https://{CONN_PARAMS_OAUTH['account']}.{CONN_PARAMS_OAUTH['region']}.snowflakecomputing.com/oauth/token-request", data={ "grant_type": "refresh_token", "refresh_token": CONN_PARAMS_OAUTH["refresh_token"], "redirect_uri": "https://localhost.com", }, headers={"Content-Type": "application/x-www-form-urlencoded"}, - auth=BASIC_AUTH, + auth=basic_auth, ) diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index d90c38098f98b..7ef4172d54670 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -30,7 +30,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa from responses import RequestsMock -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection from airflow.providers.snowflake.hooks.snowflake_sql_api import ( SnowflakeSqlApiHook, @@ -352,7 +352,7 @@ def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_toke result = hook.get_headers() assert result == HEADERS_OAUTH - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.HTTPBasicAuth") + @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth") @mock.patch("requests.post") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", @@ -360,12 +360,13 @@ def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_toke ) def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): """Test get_oauth_token method makes the right http request""" - BASIC_AUTH = {"Authorization": "Basic usernamepassword"} + basic_auth = {"Authorization": "Basic usernamepassword"} mock_conn_param.return_value = CONN_PARAMS_OAUTH requests_post.return_value.status_code = 200 - mock_auth.return_value = BASIC_AUTH + mock_auth.return_value = basic_auth hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id") - hook.get_oauth_token(CONN_PARAMS_OAUTH) + with pytest.warns(expected_warning=AirflowProviderDeprecationWarning): + hook.get_oauth_token(CONN_PARAMS_OAUTH) requests_post.assert_called_once_with( f"https://{CONN_PARAMS_OAUTH['account']}.{CONN_PARAMS_OAUTH['region']}.snowflakecomputing.com/oauth/token-request", data={ @@ -374,7 +375,7 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): "redirect_uri": "https://localhost.com", }, headers={"Content-Type": "application/x-www-form-urlencoded"}, - auth=BASIC_AUTH, + auth=basic_auth, ) @pytest.fixture