diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 762f1b10991dc..1fe3e6cd83cc5 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -201,7 +201,7 @@ def get_headers(self) -> dict[str, Any]: if all( [conn_config.get("refresh_token"), conn_config.get("client_id"), conn_config.get("client_secret")] ): - oauth_token = self.get_oauth_token() + oauth_token = self.get_oauth_token(conn_config=conn_config) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {oauth_token}", @@ -232,9 +232,8 @@ def get_headers(self) -> dict[str, Any]: } return headers - def get_oauth_token(self) -> str: + def get_oauth_token(self, conn_config: dict[str, Any]) -> str: """Generate temporary OAuth access token using refresh token in connection details.""" - conn_config = self._get_conn_params url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request" data = { "grant_type": "refresh_token", diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 1247e3e031820..efb47f0c9d213 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -364,7 +364,7 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): requests_post.return_value.status_code = 200 mock_auth.return_value = BASIC_AUTH hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id") - hook.get_oauth_token() + hook.get_oauth_token(CONN_PARAMS_OAUTH) requests_post.assert_called_once_with( f"https://{CONN_PARAMS_OAUTH['account']}.{CONN_PARAMS_OAUTH['region']}.snowflakecomputing.com/oauth/token-request", data={