diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst index 9c8c6a4a87576..e15280ceb773f 100644 --- a/providers/snowflake/docs/connections/snowflake.rst +++ b/providers/snowflake/docs/connections/snowflake.rst @@ -58,6 +58,8 @@ Extra (optional) * ``warehouse``: Snowflake warehouse name. * ``role``: Snowflake role. * ``authenticator``: To connect using OAuth set this parameter ``oauth``. + * ``token_endpoint``: Specify token endpoint for external OAuth provider. + * ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``. * ``refresh_token``: Specify refresh_token for OAuth connection. * ``private_key_file``: Specify the path to the private key file. * ``private_key_content``: Specify the content of the private key file in base64 encoded format. You can use the following Python code to encode the private key: diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 5535a5e35492b..88cb5f331e259 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -136,6 +136,9 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: "session_parameters": "session parameters", "client_request_mfa_token": "client request mfa token", "client_store_temporary_credential": "client store temporary credential (externalbrowser mode)", + "grant_type": "refresh_token client_credentials", + "token_endpoint": "token endpoint", + "refresh_token": "refresh token", }, indent=1, ), @@ -200,18 +203,32 @@ def account_identifier(self) -> str: return account_identifier - def get_oauth_token(self, conn_config: dict | None = None) -> str: + def get_oauth_token( + self, + conn_config: dict | None = None, + token_endpoint: str | None = None, + grant_type: str = "refresh_token", + ) -> str: """Generate temporary OAuth access token using refresh token in connection details.""" if conn_config is None: conn_config = self._get_conn_params - url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request" + url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request" data = { - "grant_type": "refresh_token", - "refresh_token": conn_config["refresh_token"], + "grant_type": grant_type, "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"), } + + if grant_type == "refresh_token": + data |= { + "refresh_token": conn_config["refresh_token"], + } + elif grant_type == "client_credentials": + pass # no setup necessary for client credentials grant. + else: + raise ValueError(f"Unknown grant_type: {grant_type}") + response = requests.post( url, data=data, @@ -226,7 +243,8 @@ def get_oauth_token(self, conn_config: dict | None = None) -> str: except requests.exceptions.HTTPError as e: # pragma: no cover msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" raise AirflowException(msg) - return response.json()["access_token"] + token = response.json()["access_token"] + return token @cached_property def _get_conn_params(self) -> dict[str, str | None]: @@ -329,14 +347,21 @@ def _get_conn_params(self) -> dict[str, str | None]: if refresh_token: conn_config["refresh_token"] = refresh_token conn_config["authenticator"] = "oauth" + + if conn_config.get("authenticator") == "oauth": + token_endpoint = self._get_field(extra_dict, "token_endpoint") or "" conn_config["client_id"] = conn.login conn_config["client_secret"] = conn.password + conn_config["token"] = self.get_oauth_token( + conn_config=conn_config, + token_endpoint=token_endpoint, + grant_type=extra_dict.get("grant_type", "refresh_token"), + ) + conn_config.pop("login", None) conn_config.pop("user", None) conn_config.pop("password", None) - conn_config["token"] = self.get_oauth_token(conn_config=conn_config) - # configure custom target hostname and port, if specified snowflake_host = extra_dict.get("host") snowflake_port = extra_dict.get("port") diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index c7b9765c60968..b5b4f0ffc99b8 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -222,14 +222,21 @@ def get_headers(self) -> dict[str, Any]: } return headers - def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> str: + def get_oauth_token( + self, + conn_config: dict[str, Any] | None = None, + token_endpoint: str | None = None, + grant_type: str = "refresh_token", + ) -> str: """Generate temporary OAuth access token using refresh token in connection details.""" warnings.warn( "This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ", AirflowProviderDeprecationWarning, stacklevel=2, ) - return super().get_oauth_token(conn_config=conn_config) + return super().get_oauth_token( + conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type + ) def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]: """ diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index ec384ddc5c8ce..e28d0bca28e7d 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -53,14 +53,13 @@ }, } -CONN_PARAMS_OAUTH = { +CONN_PARAMS_OAUTH_BASE = { "account": "airflow", "application": "AIRFLOW", "authenticator": "oauth", "database": "db", "client_id": "test_client_id", "client_secret": "test_client_pw", - "refresh_token": "secrettoken", "region": "af_region", "role": "af_role", "schema": "public", @@ -68,6 +67,8 @@ "warehouse": "af_wh", } +CONN_PARAMS_OAUTH = CONN_PARAMS_OAUTH_BASE | {"refresh_token": "secrettoken"} + @pytest.fixture def unencrypted_temporary_private_key(tmp_path: Path) -> Path: @@ -559,6 +560,112 @@ def test_get_conn_params_should_support_oauth(self, mock_get_conn_params, reques assert "region" in conn_params_extra_keys assert "account" in conn_params_extra_keys + @mock.patch("requests.post") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", + new_callable=PropertyMock, + ) + def test_get_conn_params_should_support_oauth_with_token_endpoint( + self, mock_get_conn_params, requests_post + ): + requests_post.return_value = Mock( + status_code=200, + json=lambda: { + "access_token": "supersecretaccesstoken", + "expires_in": 600, + "refresh_token": "secrettoken", + "token_type": "Bearer", + "username": "test_user", + }, + ) + connection_kwargs = { + **BASE_CONNECTION_KWARGS, + "login": "test_client_id", + "password": "test_client_secret", + "extra": { + "database": "db", + "account": "airflow", + "warehouse": "af_wh", + "region": "af_region", + "role": "af_role", + "refresh_token": "secrettoken", + "authenticator": "oauth", + "token_endpoint": "https://www.example.com/oauth/token", + }, + } + mock_get_conn_params.return_value = connection_kwargs + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + conn_params = hook._get_conn_params + + conn_params_keys = conn_params.keys() + conn_params_extra = conn_params.get("extra", {}) + conn_params_extra_keys = conn_params_extra.keys() + + assert "authenticator" in conn_params_extra_keys + assert conn_params_extra["authenticator"] == "oauth" + assert conn_params_extra["token_endpoint"] == "https://www.example.com/oauth/token" + + assert "user" not in conn_params_keys + assert "password" in conn_params_keys + assert "refresh_token" in conn_params_extra_keys + # Mandatory fields to generate account_identifier `https://.` + assert "region" in conn_params_extra_keys + assert "account" in conn_params_extra_keys + + @mock.patch("requests.post") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", + new_callable=PropertyMock, + ) + def test_get_conn_params_should_support_oauth_with_client_credentials( + self, mock_get_conn_params, requests_post + ): + requests_post.return_value = Mock( + status_code=200, + json=lambda: { + "access_token": "supersecretaccesstoken", + "expires_in": 600, + "refresh_token": "secrettoken", + "token_type": "Bearer", + "username": "test_user", + }, + ) + connection_kwargs = { + **BASE_CONNECTION_KWARGS, + "login": "test_client_id", + "password": "test_client_secret", + "extra": { + "database": "db", + "account": "airflow", + "warehouse": "af_wh", + "region": "af_region", + "role": "af_role", + "authenticator": "oauth", + "token_endpoint": "https://www.example.com/oauth/token", + "grant_type": "client_credentials", + }, + } + mock_get_conn_params.return_value = connection_kwargs + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + conn_params = hook._get_conn_params + + conn_params_keys = conn_params.keys() + conn_params_extra = conn_params.get("extra", {}) + conn_params_extra_keys = conn_params_extra.keys() + + assert "authenticator" in conn_params_extra_keys + assert conn_params_extra["authenticator"] == "oauth" + assert conn_params_extra["grant_type"] == "client_credentials" + + assert "user" not in conn_params_keys + assert "password" in conn_params_keys + assert "refresh_token" not in conn_params_extra_keys + # Mandatory fields to generate account_identifier `https://.` + assert "region" in conn_params_extra_keys + assert "account" in conn_params_extra_keys + def test_should_add_partner_info(self): with mock.patch.dict( "os.environ", @@ -917,3 +1024,31 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): headers={"Content-Type": "application/x-www-form-urlencoded"}, auth=basic_auth, ) + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth") + @mock.patch("requests.post") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params", + new_callable=PropertyMock, + ) + def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, requests_post, mock_auth): + """Test get_oauth_token method makes the right http request""" + basic_auth = {"Authorization": "Basic usernamepassword"} + token_endpoint = "https://example.com/oauth/token" + mock_conn_param.return_value = CONN_PARAMS_OAUTH + requests_post.return_value.status_code = 200 + mock_auth.return_value = basic_auth + + hook = SnowflakeHook(snowflake_conn_id="mock_conn_id") + hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH, token_endpoint=token_endpoint) + + requests_post.assert_called_once_with( + token_endpoint, + data={ + "grant_type": "refresh_token", + "refresh_token": CONN_PARAMS_OAUTH["refresh_token"], + "redirect_uri": "https://localhost.com", + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + auth=basic_auth, + )