diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst index 2d7076d120f55..903bc17c5da9e 100644 --- a/providers/snowflake/docs/connections/snowflake.rst +++ b/providers/snowflake/docs/connections/snowflake.rst @@ -39,10 +39,11 @@ Configuring the Connection -------------------------- Login - Specify the snowflake username. + Specify the snowflake username. For OAuth, the OAuth Client ID. Password Specify the snowflake password. For public key authentication, the passphrase for the private key. + For OAuth, the OAuth Client Secret. Schema (optional) Specify the snowflake schema to be used. diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 088c7177171e1..9cbb554249181 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -26,8 +26,10 @@ from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload from urllib.parse import urlparse +import requests from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from requests.auth import HTTPBasicAuth from snowflake import connector from snowflake.connector import DictCursor, SnowflakeConnection, util_text from snowflake.sqlalchemy import URL @@ -185,6 +187,30 @@ def _get_field(self, extra_dict, field_name): return extra_dict[field_name] or None return extra_dict.get(backcompat_key) or None + def get_oauth_token(self, conn_config: dict) -> str: + """Generate temporary OAuth access token using refresh token in connection details.""" + url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request" + data = { + "grant_type": "refresh_token", + "refresh_token": conn_config["refresh_token"], + "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"), + } + response = requests.post( + url, + data=data, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + }, + auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type] + ) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: # pragma: no cover + msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" + raise AirflowException(msg) + return response.json()["access_token"] + @cached_property def _get_conn_params(self) -> dict[str, str | None]: """ @@ -289,8 +315,11 @@ def _get_conn_params(self) -> dict[str, str | None]: conn_config["client_id"] = conn.login conn_config["client_secret"] = conn.password conn_config.pop("login", None) + conn_config.pop("user", None) conn_config.pop("password", None) + conn_config["token"] = self.get_oauth_token(conn_config=conn_config) + # configure custom target hostname and port, if specified snowflake_host = extra_dict.get("host") snowflake_port = extra_dict.get("port") diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index 2897921dc0cf7..b682e2e851190 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -22,6 +22,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any from unittest import mock +from unittest.mock import Mock, PropertyMock import pytest from cryptography.hazmat.backends import default_backend @@ -51,6 +52,21 @@ }, } +CONN_PARAMS_OAUTH = { + "account": "airflow", + "application": "AIRFLOW", + "authenticator": "oauth", + "database": "db", + "client_id": "test_client_id", + "client_secret": "test_client_pw", + "refresh_token": "secrettoken", + "region": "af_region", + "role": "af_role", + "schema": "public", + "session_parameters": None, + "warehouse": "af_wh", +} + @pytest.fixture def non_encrypted_temporary_private_key(tmp_path: Path) -> Path: @@ -483,6 +499,39 @@ def test_get_conn_params_should_fail_on_invalid_key(self): ): SnowflakeHook(snowflake_conn_id="test_conn").get_conn() + @mock.patch("requests.post") + def test_get_conn_params_should_support_oauth(self, requests_post): + requests_post.return_value = Mock( + status_code=200, + json=lambda: { + "access_token": "supersecretaccesstoken", + "expires_in": 600, + "refresh_token": "secrettoken", + "token_type": "Bearer", + "username": "test_user", + }, + ) + connection_kwargs = { + **BASE_CONNECTION_KWARGS, + "login": "test_client_id", + "password": "test_client_secret", + "extra": { + "database": "db", + "account": "airflow", + "warehouse": "af_wh", + "region": "af_region", + "role": "af_role", + "refresh_token": "secrettoken", + }, + } + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + assert "user" not in hook._get_conn_params + assert "password" not in hook._get_conn_params + assert "refresh_token" in hook._get_conn_params + assert "token" in hook._get_conn_params + assert hook._get_conn_params["authenticator"] == "oauth" + def test_should_add_partner_info(self): with mock.patch.dict( "os.environ", @@ -816,3 +865,28 @@ def test_get_snowpark_session(self, mock_session_builder): "airflow_provider_version": provider_version, } ) + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth") + @mock.patch("requests.post") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", + new_callable=PropertyMock, + ) + def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): + """Test get_oauth_token method makes the right http request""" + BASIC_AUTH = {"Authorization": "Basic usernamepassword"} + mock_conn_param.return_value = CONN_PARAMS_OAUTH + requests_post.return_value.status_code = 200 + mock_auth.return_value = BASIC_AUTH + hook = SnowflakeHook(snowflake_conn_id="mock_conn_id") + hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH) + requests_post.assert_called_once_with( + f"https://{CONN_PARAMS_OAUTH['account']}.snowflakecomputing.com/oauth/token-request", + data={ + "grant_type": "refresh_token", + "refresh_token": CONN_PARAMS_OAUTH["refresh_token"], + "redirect_uri": "https://localhost.com", + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + auth=BASIC_AUTH, + )