diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 7e4ef8dfe0288..185a49ffaba06 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -25,8 +25,18 @@ import aiohttp import requests +from aiohttp import ClientConnectionError, ClientResponseError from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from requests.exceptions import ConnectionError, HTTPError, Timeout +from tenacity import ( + AsyncRetrying, + Retrying, + before_sleep_log, + retry_if_exception, + stop_after_attempt, + wait_exponential, +) from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -65,6 +75,7 @@ class SnowflakeSqlApiHook(SnowflakeHook): :param token_life_time: lifetime of the JWT Token in timedelta :param token_renewal_delta: Renewal time of the JWT Token in timedelta :param deferrable: Run operator in the deferrable mode. + :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes. 
""" LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute lifetime @@ -75,15 +86,27 @@ def __init__( snowflake_conn_id: str, token_life_time: timedelta = LIFETIME, token_renewal_delta: timedelta = RENEWAL_DELTA, + api_retry_args: dict[Any, Any] | None = None, # Optional retry arguments passed to tenacity.retry *args: Any, **kwargs: Any, ): self.snowflake_conn_id = snowflake_conn_id self.token_life_time = token_life_time self.token_renewal_delta = token_renewal_delta + super().__init__(snowflake_conn_id, *args, **kwargs) self.private_key: Any = None + self.retry_config = { + "retry": retry_if_exception(self._should_retry_on_error), + "wait": wait_exponential(multiplier=1, min=1, max=60), + "stop": stop_after_attempt(5), + "before_sleep": before_sleep_log(self.log, log_level=20), # INFO level + "reraise": True, + } + if api_retry_args: + self.retry_config.update(api_retry_args) + def get_private_key(self) -> None: """Get the private key from snowflake connection.""" conn = self.get_connection(self.snowflake_conn_id) @@ -168,13 +191,8 @@ def execute_query( "query_tag": query_tag, }, } - response = requests.post(url, json=data, headers=headers, params=params) - try: - response.raise_for_status() - except requests.exceptions.HTTPError as e: # pragma: no cover - msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}" - raise AirflowException(msg) - json_response = response.json() + + _, json_response = self._make_api_call_with_retries("POST", url, headers, params, data) self.log.info("Snowflake SQL POST API response: %s", json_response) if "statementHandles" in json_response: self.query_ids = json_response["statementHandles"] @@ -259,13 +277,10 @@ def check_query_output(self, query_ids: list[str]) -> None: """ for query_id in query_ids: header, params, url = self.get_request_url_header_params(query_id) - try: - response = requests.get(url, headers=header, params=params) - response.raise_for_status() - 
self.log.info(response.json()) - except requests.exceptions.HTTPError as e: - msg = f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}" - raise AirflowException(msg) + _, response_json = self._make_api_call_with_retries( + method="GET", url=url, headers=header, params=params + ) + self.log.info(response_json) def _process_response(self, status_code, resp): self.log.info("Snowflake SQL GET statements status API response: %s", resp) @@ -295,9 +310,7 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]: """ self.log.info("Retrieving status for query id %s", query_id) header, params, url = self.get_request_url_header_params(query_id) - response = requests.get(url, params=params, headers=header) - status_code = response.status_code - resp = response.json() + status_code, resp = self._make_api_call_with_retries("GET", url, header, params) return self._process_response(status_code, resp) async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]: @@ -308,10 +321,85 @@ async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | """ self.log.info("Retrieving status for query id %s", query_id) header, params, url = self.get_request_url_header_params(query_id) - async with ( - aiohttp.ClientSession(headers=header) as session, - session.get(url, params=params) as response, + status_code, resp = await self._make_api_call_with_retries_async("GET", url, header, params) + return self._process_response(status_code, resp) + + @staticmethod + def _should_retry_on_error(exception) -> bool: + """ + Determine if the exception should trigger a retry based on error type and status code. + + Retries on HTTP errors 429 (Too Many Requests), 503 (Service Unavailable), + and 504 (Gateway Timeout) as recommended by Snowflake error handling docs. + Retries on connection errors and timeouts. 
+ + :param exception: The exception to check + :return: True if the request should be retried, False otherwise + """ + if isinstance(exception, HTTPError): + return exception.response.status_code in [429, 503, 504] + if isinstance(exception, ClientResponseError): + return exception.status in [429, 503, 504] + if isinstance( + exception, + ( + ConnectionError, + Timeout, + ClientConnectionError, + ), ): - status_code = response.status - resp = await response.json() - return self._process_response(status_code, resp) + return True + return False + + def _make_api_call_with_retries( + self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None + ): + """ + Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors. + + Error handling implemented based on Snowflake error handling docs: + https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors + + :param method: The HTTP method to use for the API call. + :param url: The URL for the API endpoint. + :param headers: The headers to include in the API call. + :param params: (Optional) The query parameters to include in the API call. + :param json: (Optional) The JSON body to include in the API call. + :return: A tuple of the HTTP status code and the decoded JSON response body. + """ + with requests.Session() as session: + for attempt in Retrying(**self.retry_config): # type: ignore + with attempt: + if method.upper() in ("GET", "POST"): + response = session.request( + method=method.lower(), url=url, headers=headers, params=params, json=json + ) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + response.raise_for_status() + return response.status_code, response.json() + + async def _make_api_call_with_retries_async(self, method, url, headers, params=None): + """ + Make an API call to the Snowflake SQL API asynchronously with retry logic for specific HTTP errors. 
+ + Error handling implemented based on Snowflake error handling docs: + https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors + + :param method: The HTTP method to use for the API call. Only GET is supported in async mode. + :param url: The URL for the API endpoint. + :param headers: The headers to include in the API call. + :param params: (Optional) The query parameters to include in the API call. + :return: A tuple of the HTTP status code and the decoded JSON response body. + """ + async with aiohttp.ClientSession(headers=headers) as session: + async for attempt in AsyncRetrying(**self.retry_config): # type: ignore + with attempt: + if method.upper() == "GET": + async with session.request(method=method.lower(), url=url, params=params) as response: + response.raise_for_status() + # Return status and json content for async processing + content = await response.json() + return response.status, content + else: + raise ValueError(f"Unsupported HTTP method: {method}") diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index 2c8db1391e82b..48086e53cd8c7 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -355,6 +355,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator): When executing the statement, Snowflake replaces placeholders (? and :name) in the statement with these specified values. :param deferrable: Run operator in the deferrable mode. + :param snowflake_api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes. 
""" LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime @@ -381,6 +382,7 @@ def __init__( token_renewal_delta: timedelta = RENEWAL_DELTA, bindings: dict[str, Any] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + snowflake_api_retry_args: dict[str, Any] | None = None, **kwargs: Any, ) -> None: self.snowflake_conn_id = snowflake_conn_id @@ -390,6 +392,7 @@ def __init__( self.token_renewal_delta = token_renewal_delta self.bindings = bindings self.execute_async = False + self.snowflake_api_retry_args = snowflake_api_retry_args or {} self.deferrable = deferrable self.query_ids: list[str] = [] if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover @@ -412,6 +415,7 @@ def _hook(self): token_life_time=self.token_life_time, token_renewal_delta=self.token_renewal_delta, deferrable=self.deferrable, + api_retry_args=self.snowflake_api_retry_args, **self.hook_params, ) diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 21a3fa7a999a8..410588a7a71bb 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -21,14 +21,15 @@ import uuid from typing import TYPE_CHECKING, Any from unittest import mock -from unittest.mock import AsyncMock, PropertyMock +from unittest.mock import AsyncMock, PropertyMock, call +import aiohttp import pytest import requests +import tenacity from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -from responses import RequestsMock from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection @@ -149,6 +150,28 @@ "role": "airflow_role", } 
+API_URL = "https://test.snowflakecomputing.com/api/v2/statements/test" + + +@pytest.fixture +def mock_requests(): + with mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.requests.Session" + ) as mock_session_cls: + mock_session = mock.MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + yield mock_session + + +@pytest.fixture +def mock_async_request(): + with mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.request" + ) as mock_session_cls: + mock_request = mock.MagicMock() + mock_session_cls.return_value = mock_request + yield mock_request + def create_successful_response_mock(content): """Create mock response for success state""" @@ -167,6 +190,35 @@ def create_post_side_effect(status_code=429): return response +def create_async_request_client_response_error(request_info=None, history=None, status_code=429): + """Create mock response for async request side effect""" + response = mock.MagicMock() + request_info = mock.MagicMock() if request_info is None else request_info + history = mock.MagicMock() if history is None else history + response.status = status_code + response.reason = f"{status_code} Error" + response.raise_for_status.side_effect = aiohttp.ClientResponseError( + request_info=request_info, history=history, status=status_code, message=response.reason + ) + return response + + +def create_async_connection_error(): + response = mock.MagicMock() + response.raise_for_status.side_effect = aiohttp.ClientConnectionError() + return response + + +def create_async_request_client_response_success(json=GET_RESPONSE, status_code=200): + """Create mock response for async request side effect""" + response = mock.MagicMock() + response.status = status_code + response.reason = "test" + response.json = AsyncMock(return_value=json) + response.raise_for_status.side_effect = None + return response + + class TestSnowflakeSqlApiHook: @pytest.mark.parametrize( 
"sql,statement_count,expected_response, expected_query_ids", @@ -175,7 +227,6 @@ class TestSnowflakeSqlApiHook: (SQL_MULTIPLE_STMTS, 4, {"statementHandles": ["uuid", "uuid1"]}, ["uuid", "uuid1"]), ], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -185,25 +236,24 @@ def test_execute_query( self, mock_get_header, mock_conn_param, - mock_requests, sql, statement_count, expected_response, expected_query_ids, + mock_requests, ): """Test execute_query method, run query by mocking post request method and return the query ids""" mock_requests.codes.ok = 200 - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock + type(mock_requests.request.return_value).status_code = status_code_mock hook = SnowflakeSqlApiHook("mock_conn_id") query_ids = hook.execute_query(sql, statement_count) assert query_ids == expected_query_ids - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -221,11 +271,11 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time( ) mock_requests.codes.ok = 200 - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock + type(mock_requests.request.return_value).status_code = status_code_mock hook = SnowflakeSqlApiHook("mock_conn_id") query_ids = hook.execute_query(sql, statement_count) @@ -237,7 +287,7 @@ def 
test_execute_query_multiple_times_give_fresh_query_ids_each_time( {"statementHandle": "uuid"}, ["uuid"], ) - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] query_ids = hook.execute_query(sql, statement_count) @@ -247,7 +297,6 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time( "sql,statement_count,expected_response, expected_query_ids", [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -257,18 +306,18 @@ def test_execute_query_exception_without_statement_handle( self, mock_get_header, mock_conn_param, - mock_requests, sql, statement_count, expected_response, expected_query_ids, + mock_requests, ): """ Test execute_query method by mocking the exception response and raise airflow exception without statementHandle in the response """ side_effect = create_post_side_effect() - mock_requests.post.side_effect = side_effect + mock_requests.request.side_effect = side_effect hook = SnowflakeSqlApiHook("mock_conn_id") with pytest.raises(AirflowException) as exception_info: @@ -281,7 +330,6 @@ def test_execute_query_exception_without_statement_handle( (SQL_MULTIPLE_STMTS, 4, {"1": {"type": "FIXED", "value": "123"}}), ], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -291,15 +339,15 @@ def test_execute_query_bindings_warning( self, mock_get_headers, mock_conn_params, - mock_requests, sql, statement_count, bindings, + mock_requests, ): """Test execute_query method logs warning when bindings are provided for multi-statement queries""" mock_conn_params.return_value = CONN_PARAMS 
mock_get_headers.return_value = HEADERS - mock_requests.post.return_value = create_successful_response_mock( + mock_requests.request.return_value = create_successful_response_mock( {"statementHandles": ["uuid", "uuid1"]} ) @@ -316,28 +364,32 @@ def test_execute_query_bindings_warning( (["uuid", "uuid1"]), ], ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" ) - def test_check_query_output(self, mock_geturl_header_params, mock_requests, query_ids): + def test_check_query_output(self, mock_geturl_header_params, query_ids, mock_requests): """Test check_query_output by passing query ids as params and mock get_request_url_header_params""" req_id = uuid.uuid4() params = {"requestId": str(req_id), "page": 2, "pageSize": 10} mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/" - mock_requests.get.return_value.json.return_value = GET_RESPONSE + mock_requests.request.return_value.json.return_value = GET_RESPONSE hook = SnowflakeSqlApiHook("mock_conn_id") with mock.patch.object(hook.log, "info") as mock_log_info: hook.check_query_output(query_ids) mock_log_info.assert_called_with(GET_RESPONSE) - @pytest.mark.parametrize("query_ids", [(["uuid", "uuid1"])]) + @pytest.mark.parametrize("query_ids", [["uuid", "uuid1"]]) @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." 
"get_request_url_header_params" ) - def test_check_query_output_exception(self, mock_geturl_header_params, query_ids): + def test_check_query_output_exception( + self, + mock_geturl_header_params, + query_ids, + mock_requests, + ): """ Test check_query_output by passing query ids as params and mock get_request_url_header_params to raise airflow exception and mock with http error @@ -345,11 +397,13 @@ def test_check_query_output_exception(self, mock_geturl_header_params, query_ids req_id = uuid.uuid4() params = {"requestId": str(req_id), "page": 2, "pageSize": 10} mock_geturl_header_params.return_value = HEADERS, params, "https://test/airflow/" - hook = SnowflakeSqlApiHook("mock_conn_id") - with mock.patch.object(hook.log, "error"), RequestsMock() as requests_mock: - requests_mock.get(url="https://test/airflow/", json={"foo": "bar"}, status=500) - with pytest.raises(AirflowException, match='Response: {"foo": "bar"}, Status Code: 500'): - hook.check_query_output(query_ids) + custom_retry_args = { + "stop": tenacity.stop_after_attempt(2), # Only 2 attempts instead of default 5 + } + hook = SnowflakeSqlApiHook("mock_conn_id", api_retry_args=custom_retry_args) + mock_requests.request.side_effect = [create_post_side_effect(status_code=500)] * 3 + with pytest.raises(requests.exceptions.HTTPError): + hook.check_query_output(query_ids) @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", @@ -605,9 +659,8 @@ def test_get_private_key_should_support_private_auth_with_unencrypted_key( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." 
"get_request_url_header_params" ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") def test_get_sql_api_query_status( - self, mock_requests, mock_geturl_header_params, status_code, response, expected_response + self, mock_geturl_header_params, status_code, response, expected_response, mock_requests ): """Test get_sql_api_query_status function by mocking the status, response and expected response""" @@ -623,7 +676,10 @@ def __init__(self, status_code, data): def json(self): return self.data - mock_requests.get.return_value = MockResponse(status_code, response) + def raise_for_status(self): + return + + mock_requests.request.return_value = MockResponse(status_code, response) hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") assert hook.get_sql_api_query_status("uuid") == expected_response @@ -666,17 +722,16 @@ def json(self): "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook." "get_request_url_header_params" ) - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get") async def test_get_sql_api_query_status_async( - self, mock_get, mock_geturl_header_params, status_code, response, expected_response + self, mock_geturl_header_params, status_code, response, expected_response, mock_async_request ): """Test Async get_sql_api_query_status_async function by mocking the status, response and expected response""" req_id = uuid.uuid4() params = {"requestId": str(req_id), "page": 2, "pageSize": 10} mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/" - mock_get.return_value.__aenter__.return_value.status = status_code - mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=response) + mock_async_request.__aenter__.return_value.status = status_code + mock_async_request.__aenter__.return_value.json = AsyncMock(return_value=response) hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") response = await 
hook.get_sql_api_query_status_async("uuid") assert response == expected_response @@ -774,7 +829,6 @@ def test_hook_parameter_propagation(self, hook_params): ], ) @mock.patch("uuid.uuid4") - @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") @mock.patch( "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", new_callable=PropertyMock, @@ -784,13 +838,13 @@ def test_proper_parametrization_of_execute_query_api_request( self, mock_get_headers, mock_conn_param, - mock_requests, mock_uuid, test_hook_params, sql, statement_count, expected_payload, expected_response, + mock_requests, ): """ This tests if the query execution ordered by POST request to Snowflake API @@ -801,7 +855,7 @@ def test_proper_parametrization_of_execute_query_api_request( mock_conn_param.return_value = CONN_PARAMS mock_get_headers.return_value = HEADERS mock_requests.codes.ok = 200 - mock_requests.post.side_effect = [ + mock_requests.request.side_effect = [ create_successful_response_mock(expected_response), ] status_code_mock = mock.PropertyMock(return_value=200) @@ -812,4 +866,316 @@ def test_proper_parametrization_of_execute_query_api_request( hook.execute_query(sql, statement_count) - mock_requests.post.assert_called_once_with(url, headers=HEADERS, json=expected_payload, params=params) + mock_requests.request.assert_called_once_with( + method="post", url=url, headers=HEADERS, json=expected_payload, params=params + ) + + @pytest.mark.parametrize( + "status_code,should_retry", + [ + (429, True), # Too Many Requests - should retry + (503, True), # Service Unavailable - should retry + (504, True), # Gateway Timeout - should retry + (500, False), # Internal Server Error - should not retry + (400, False), # Bad Request - should not retry + (401, False), # Unauthorized - should not retry + (404, False), # Not Found - should not retry + ], + ) + def test_make_api_call_with_retries_http_errors(self, status_code, should_retry, mock_requests): + 
""" + Test that _make_api_call_with_retries method only retries on specific HTTP status codes. + Should retry on 429, 503, 504 but not on other error codes. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock failed response + failed_response = mock.MagicMock() + failed_response.status_code = status_code + failed_response.json.return_value = {"error": "test error"} + failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) + + # Mock successful response for retries + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + if should_retry: + # For retryable errors, first call fails, second succeeds + mock_requests.request.side_effect = [failed_response, success_response] + status_code, resp_json = hook._make_api_call_with_retries( + method="GET", + url=API_URL, + headers=HEADERS, + ) + assert status_code == 200 + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.request.call_count == 2 + mock_requests.request.assert_has_calls( + [ + call( + method="get", + json=None, + url=API_URL, + params=None, + headers=HEADERS, + ), + call( + method="get", + json=None, + url=API_URL, + params=None, + headers=HEADERS, + ), + ] + ) + else: + # For non-retryable errors, should fail immediately + mock_requests.request.side_effect = [failed_response] + with pytest.raises(requests.exceptions.HTTPError): + hook._make_api_call_with_retries( + method="GET", + url=API_URL, + headers=HEADERS, + ) + assert mock_requests.request.call_count == 1 + mock_requests.request.assert_called_with( + method="get", + json=None, + url=API_URL, + params=None, + headers=HEADERS, + ) + + def test_make_api_call_with_retries_connection_errors(self, mock_requests): + """ + Test that _make_api_call_with_retries method retries on connection errors. 
+ """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock connection error then success + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + mock_requests.request.side_effect = [ + requests.exceptions.ConnectionError("Connection failed"), + success_response, + ] + + status_code, resp_json = hook._make_api_call_with_retries( + "POST", API_URL, HEADERS, json={"test": "data"} + ) + + assert status_code == 200 + mock_requests.request.assert_called_with( + method="post", + url=API_URL, + params=None, + headers=HEADERS, + json={"test": "data"}, + ) + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.request.call_count == 2 + + def test_make_api_call_with_retries_timeout_errors(self, mock_requests): + """ + Test that _make_api_call_with_retries method retries on timeout errors. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock timeout error then success + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + mock_requests.request.side_effect = [ + requests.exceptions.Timeout("Request timed out"), + success_response, + ] + + status_code, resp_json = hook._make_api_call_with_retries("GET", API_URL, HEADERS) + + assert status_code == 200 + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.request.call_count == 2 + + def test_make_api_call_with_retries_max_attempts(self, mock_requests): + """ + Test that _make_api_call_with_retries method respects max retry attempts. 
+ """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock response that always fails with retryable error + failed_response = mock.MagicMock() + failed_response.status_code = 429 + failed_response.json.return_value = {"error": "rate limited"} + failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) + + mock_requests.request.side_effect = [failed_response] * 10 # More failures than max retries + + with pytest.raises(requests.exceptions.HTTPError): + hook._make_api_call_with_retries("GET", API_URL, HEADERS) + + # Should attempt 5 times (initial + 4 retries) based on default retry config + assert mock_requests.request.call_count == 5 + + def test_make_api_call_with_retries_success_no_retry(self, mock_requests): + """ + Test that _make_api_call_with_retries method doesn't retry on successful requests. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # Mock successful response + success_response = mock.MagicMock() + success_response.status_code = 200 + success_response.json.return_value = {"statementHandle": "uuid"} + success_response.raise_for_status.return_value = None + + mock_requests.request.return_value = success_response + + status_code, resp_json = hook._make_api_call_with_retries( + "POST", API_URL, HEADERS, json={"test": "data"} + ) + + assert status_code == 200 + assert resp_json == {"statementHandle": "uuid"} + assert mock_requests.request.call_count == 1 + + def test_make_api_call_with_retries_unsupported_method(self): + """ + Test that _make_api_call_with_retries method raises ValueError for unsupported HTTP methods. 
+ """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + with pytest.raises(ValueError, match="Unsupported HTTP method: PUT"): + hook._make_api_call_with_retries("PUT", API_URL, HEADERS) + + def test_make_api_call_with_retries_custom_retry_config(self, mock_requests): + """ + Test that _make_api_call_with_retries method respects custom retry configuration. + """ + + # Create hook with custom retry config + custom_retry_args = { + "stop": tenacity.stop_after_attempt(2), # Only 2 attempts instead of default 5 + } + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn", api_retry_args=custom_retry_args) + + # Mock response that always fails with retryable error + failed_response = mock.MagicMock() + failed_response.status_code = 503 + failed_response.json.return_value = {"error": "service unavailable"} + failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=failed_response) + + mock_requests.request.side_effect = [failed_response] * 3 + + with pytest.raises(requests.exceptions.HTTPError): + hook._make_api_call_with_retries("GET", API_URL, HEADERS) + + # Should attempt only 2 times due to custom config + assert mock_requests.request.call_count == 2 + + @pytest.mark.asyncio + async def test_make_api_call_with_retries_async_success(self, mock_async_request): + """ + Test that _make_api_call_with_retries_async returns response on success. 
+ """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + mock_response = create_async_request_client_response_success() + mock_async_request.__aenter__.return_value = mock_response + status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + assert status_code == 200 + assert resp_json == GET_RESPONSE + assert mock_async_request.__aenter__.call_count == 1 + + @pytest.mark.asyncio + async def test_make_api_call_with_retries_async_retryable_http_error(self, mock_async_request): + """ + Test that _make_api_call_with_retries_async retries on retryable HTTP errors (429, 503, 504). + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # First response: 429, then 200 + mock_response_429 = create_async_request_client_response_error() + mock_response_200 = create_async_request_client_response_success() + # Side effect for request context manager + mock_async_request.__aenter__.side_effect = [ + mock_response_429, + mock_response_200, + ] + + status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + assert status_code == 200 + assert resp_json == GET_RESPONSE + assert mock_async_request.__aenter__.call_count == 2 + + @pytest.mark.asyncio + async def test_make_api_call_with_retries_async_non_retryable_http_error(self, mock_async_request): + """ + Test that _make_api_call_with_retries_async does not retry on non-retryable HTTP errors. 
+ """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + mock_response_400 = create_async_request_client_response_error(status_code=400) + + mock_async_request.__aenter__.return_value = mock_response_400 + + with pytest.raises(aiohttp.ClientResponseError): + await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + assert mock_async_request.__aenter__.call_count == 1 + + @pytest.mark.asyncio + async def test_make_api_call_with_retries_async_connection_error(self, mock_async_request): + """ + Test that _make_api_call_with_retries_async retries on connection errors. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + # First: connection error, then: success + failed_conn = create_async_connection_error() + + mock_request_200 = create_async_request_client_response_success() + + # Side effect for request context manager + mock_async_request.__aenter__.side_effect = [ + failed_conn, + mock_request_200, + ] + + status_code, resp_json = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + assert status_code == 200 + assert resp_json == GET_RESPONSE + assert mock_async_request.__aenter__.call_count == 2 + + @pytest.mark.asyncio + async def test_make_api_call_with_retries_async_max_attempts(self, mock_async_request): + """ + Test that _make_api_call_with_retries_async respects max retry attempts. 
+ """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + mock_request_429 = create_async_request_client_response_error(status_code=429) + + # Always returns 429 + mock_async_request.__aenter__.side_effect = [mock_request_429] * 5 + + with pytest.raises(aiohttp.ClientResponseError): + await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + # Should attempt 5 times (default max retries) + assert mock_async_request.__aenter__.call_count == 5 + + @pytest.mark.asyncio + async def test_make_api_call_with_retries_async_unsupported_method(self, mock_async_request): + """ + Test that _make_api_call_with_retries_async raises ValueError for unsupported HTTP methods. + """ + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + with pytest.raises(ValueError, match="Unsupported HTTP method: PATCH"): + await hook._make_api_call_with_retries_async("PATCH", API_URL, HEADERS) + # No HTTP call should be made + assert mock_async_request.__aenter__.call_count == 0