diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index c3521e036412b..f5f34b34153e7 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -476,6 +476,13 @@ def _should_retry_on_error(exception) -> bool: return True return False + @staticmethod + def _should_raise_for_status(status: int) -> bool: + # _process_response handles HTTP 422 to provide richer error context. + # The response payload must be passed through even when the status is 422. + # See https://docs.snowflake.com/en/developer-guide/sql-api/reference + return status >= 400 and status != 422 + def _make_api_call_with_retries( self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None ): @@ -516,7 +523,8 @@ def _make_api_call_with_retries( # user first, base second => base wins even if guard misses something request_kwargs: dict[str, Any] = {**user_kwargs, **base_request_kwargs} response = session.request(**request_kwargs) - response.raise_for_status() + if self._should_raise_for_status(response.status_code): + response.raise_for_status() return response.status_code, response.json() async def _make_api_call_with_retries_async(self, method, url, headers, params=None): @@ -560,7 +568,8 @@ async def _make_api_call_with_retries_async(self, method, url, headers, params=N } request_kwargs: dict[str, Any] = {**user_request_kwargs, **base_request_kwargs} async with session.request(**request_kwargs) as response: - response.raise_for_status() + if self._should_raise_for_status(response.status): + response.raise_for_status() # Return status and json content for async processing content = await response.json() return response.status, content diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index f553be54cf963..4e89a8900a4c7 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -20,6 +20,7 @@ import base64 import unittest import uuid +from collections.abc import Mapping from typing import TYPE_CHECKING, Any from unittest import mock from unittest.mock import AsyncMock, call @@ -180,13 +181,44 @@ def create_successful_response_mock(content): return response -def create_post_side_effect(status_code=429): - """create mock response for post side effect""" - response = mock.MagicMock() - response.status_code = status_code - response.reason = "test" - response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=response) - return response +def create_post_response( + status_code: int = 429, + *, + json_body: Mapping[str, Any] | None = None, + reason: str = "test", + http_error: BaseException | None = None, + raise_for_status: bool | None = None, +): + """ + Build a mock response object for requests.request/post. + Defaults: + - 2xx/3xx: raise_for_status() does nothing. + - 4xx/5xx: raise_for_status() raises requests.exceptions.HTTPError(response=resp). + Customization: + - json_body: controls resp.json() output. + - http_error: explicitly set what raise_for_status raises (overrides default behavior). + - raise_for_status: force-enable/disable raising regardless of status_code. + """ + resp = mock.MagicMock() + resp.status_code = status_code + resp.reason = reason + resp.json.return_value = dict(json_body) if json_body is not None else {} + + _default_should_raise = status_code >= 400 + + if raise_for_status is None: + should_raise = _default_should_raise + else: + should_raise = raise_for_status + + if http_error is not None: + resp.raise_for_status.side_effect = http_error + elif should_raise: + resp.raise_for_status.side_effect = requests.exceptions.HTTPError(response=resp) + else: + resp.raise_for_status.return_value = None + + return resp def create_async_request_client_response_error(request_info=None, history=None, status_code=429): @@ -203,9 +235,7 @@ def create_async_request_client_response_error(request_info=None, history=None, def create_async_connection_error(): - response = mock.MagicMock() - response.raise_for_status.side_effect = aiohttp.ClientConnectionError() - return response + return aiohttp.ClientConnectionError() def create_async_request_client_response_success(json=GET_RESPONSE, status_code=200): @@ -290,12 +320,14 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time( ("sql", "statement_count", "expected_response", "expected_query_ids"), [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])], ) + @mock.patch(f"{HOOK_PATH}._make_api_call_with_retries") @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_execute_query_exception_without_statement_handle( self, mock_get_header, mock_conn_param, + mock_make_api_call, sql, statement_count, expected_response, @@ -306,13 +338,12 @@ def test_execute_query_exception_without_statement_handle( Test execute_query method by mocking the exception response and raise airflow exception without statementHandle in the response """ - side_effect = create_post_side_effect() - mock_requests.request.side_effect = side_effect + # status_code, json payload without statementHandle + mock_make_api_call.return_value = (None, {"foo": "bar"}) hook = SnowflakeSqlApiHook("mock_conn_id") - with pytest.raises(AirflowException) as exception_info: + with pytest.raises(AirflowException): hook.execute_query(sql, statement_count) - assert exception_info @pytest.mark.parametrize( ("sql", "statement_count", "bindings"), @@ -358,6 +389,8 @@ def test_check_query_output(self, mock_geturl_header_params, query_ids, mock_req params = {"requestId": str(req_id), "page": 2, "pageSize": 10} mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/" mock_requests.request.return_value.json.return_value = GET_RESPONSE + # Make sure status code 200 when query is success. + mock_requests.request.return_value.status_code = 200 hook = SnowflakeSqlApiHook("mock_conn_id") with mock.patch.object(hook.log, "info") as mock_log_info: hook.check_query_output(query_ids) @@ -382,7 +415,7 @@ def test_check_query_output_exception( "stop": tenacity.stop_after_attempt(2), # Only 2 attempts instead of default 5 } hook = SnowflakeSqlApiHook("mock_conn_id", api_retry_args=custom_retry_args) - mock_requests.request.side_effect = [create_post_side_effect(status_code=500)] * 3 + mock_requests.request.side_effect = [create_post_response(status_code=500)] * 3 with pytest.raises(requests.exceptions.HTTPError): hook.check_query_output(query_ids) @@ -1514,6 +1547,40 @@ def test_make_api_call_with_retries_rejects_http_request_kwargs_overriding_ident ): hook._make_api_call_with_retries("GET", API_URL, HEADERS) + def test_make_api_call_with_422_does_not_raise_for_status(self, mock_requests): + """Test that HTTP 422 responses do not call raise_for_status and pass through the response body.""" + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + response = create_post_response( + status_code=422, json_body={"code": "Query was failed when runtime..", "message": "sync job"} + ) + mock_requests.request.return_value = response + + status, body = hook._make_api_call_with_retries("GET", API_URL, HEADERS) + + assert status == 422 + assert body == {"code": "Query was failed when runtime..", "message": "sync job"} + + # Validate 422 don't raise http error. + response.raise_for_status.assert_not_called() + # Decode should call once. + response.json.assert_called_once() + + def test_make_api_call_with_500_raises_for_status(self, mock_requests): + """Test that HTTP 500 responses call raise_for_status and do not pass through the response body.""" + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + response = create_post_response(status_code=500, json_body={"error": "internal error"}) + mock_requests.request.return_value = response + + with pytest.raises(requests.exceptions.HTTPError): + hook._make_api_call_with_retries("GET", API_URL, HEADERS) + + # 500 status code should raise HTTPError. + response.raise_for_status.assert_called_once() + # After raise _make_api_call_with_retries will return control immediately. + response.json.assert_not_called() + @pytest.mark.asyncio async def test_make_api_call_with_retries_async_passes_timeout_to_clientsession(self): """ @@ -1611,3 +1678,48 @@ async def test_make_api_call_with_retries_async_rejects_aiohttp_request_kwargs_o match=r"aiohttp_request_kwargs must not override request identity fields", ): await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + + @pytest.mark.asyncio + async def test_make_api_call_with_422_does_not_raise_for_status_async(self, mock_async_request): + """Test that HTTP 422 responses do not call raise_for_status and pass through the response body (async).""" + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + response = create_async_request_client_response_error(status_code=422) + + # 422 should NOT call raise_for_status! We have to provide a valid JSON payload + response.json = AsyncMock( + return_value={"code": "Query was failed when runtime..", "message": "async job!"} + ) + + # If raise_for_status() is called, this side effect will explode the test! + response.raise_for_status.side_effect = AssertionError( + "raise_for_status should not be called for 422" + ) + + mock_async_request.__aenter__.return_value = response + + status, body = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + + assert status == 422 + assert body == {"code": "Query was failed when runtime..", "message": "async job!"} + + response.raise_for_status.assert_not_called() + response.json.assert_awaited_once() + + @pytest.mark.asyncio + async def test_make_api_call_with_500_raises_for_status_async(self, mock_async_request): + """Test that HTTP 500 responses call raise_for_status and do not pass through the response body (async).""" + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + + response = create_async_request_client_response_error(status_code=500) + + # If json() is called, the test must fail.. + response.json = AsyncMock(side_effect=AssertionError("json should not be called on 500")) + + mock_async_request.__aenter__.return_value = response + + with pytest.raises(aiohttp.ClientResponseError): + await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS) + + response.raise_for_status.assert_called_once() + response.json.assert_not_called()