Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,13 @@ def _should_retry_on_error(exception) -> bool:
return True
return False

@staticmethod
def _should_raise_for_status(status: int) -> bool:
# _process_response handles HTTP 422 to provide richer error context.
# The response payload must be passed through even when the status is 422.
# See https://docs.snowflake.com/en/developer-guide/sql-api/reference
return status >= 400 and status != 422

def _make_api_call_with_retries(
self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None
):
Expand Down Expand Up @@ -516,7 +523,8 @@ def _make_api_call_with_retries(
# user first, base second => base wins even if guard misses something
request_kwargs: dict[str, Any] = {**user_kwargs, **base_request_kwargs}
response = session.request(**request_kwargs)
response.raise_for_status()
if self._should_raise_for_status(response.status_code):
response.raise_for_status()
return response.status_code, response.json()

async def _make_api_call_with_retries_async(self, method, url, headers, params=None):
Expand Down Expand Up @@ -560,7 +568,8 @@ async def _make_api_call_with_retries_async(self, method, url, headers, params=N
}
request_kwargs: dict[str, Any] = {**user_request_kwargs, **base_request_kwargs}
async with session.request(**request_kwargs) as response:
response.raise_for_status()
if self._should_raise_for_status(response.status):
response.raise_for_status()
# Return status and json content for async processing
content = await response.json()
return response.status, content
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import base64
import unittest
import uuid
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from unittest import mock
from unittest.mock import AsyncMock, call
Expand Down Expand Up @@ -180,13 +181,44 @@ def create_successful_response_mock(content):
return response


def create_post_side_effect(status_code=429):
"""create mock response for post side effect"""
response = mock.MagicMock()
response.status_code = status_code
response.reason = "test"
response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=response)
return response
def create_post_response(
status_code: int = 429,
*,
json_body: Mapping[str, Any] | None = None,
reason: str = "test",
http_error: BaseException | None = None,
raise_for_status: bool | None = None,
):
"""
Build a mock response object for requests.request/post.
Defaults:
- 2xx/3xx: raise_for_status() does nothing.
- 4xx/5xx: raise_for_status() raises requests.exceptions.HTTPError(response=resp).
Customization:
- json_body: controls resp.json() output.
- http_error: explicitly set what raise_for_status raises (overrides default behavior).
- raise_for_status: force-enable/disable raising regardless of status_code.
"""
resp = mock.MagicMock()
resp.status_code = status_code
resp.reason = reason
resp.json.return_value = dict(json_body) if json_body is not None else {}

_default_should_raise = status_code >= 400

if raise_for_status is None:
should_raise = _default_should_raise
else:
should_raise = raise_for_status

if http_error is not None:
resp.raise_for_status.side_effect = http_error
elif should_raise:
resp.raise_for_status.side_effect = requests.exceptions.HTTPError(response=resp)
else:
resp.raise_for_status.return_value = None

return resp


def create_async_request_client_response_error(request_info=None, history=None, status_code=429):
Expand All @@ -203,9 +235,7 @@ def create_async_request_client_response_error(request_info=None, history=None,


def create_async_connection_error():
response = mock.MagicMock()
response.raise_for_status.side_effect = aiohttp.ClientConnectionError()
return response
return aiohttp.ClientConnectionError()


def create_async_request_client_response_success(json=GET_RESPONSE, status_code=200):
Expand Down Expand Up @@ -290,12 +320,14 @@ def test_execute_query_multiple_times_give_fresh_query_ids_each_time(
("sql", "statement_count", "expected_response", "expected_query_ids"),
[(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])],
)
@mock.patch(f"{HOOK_PATH}._make_api_call_with_retries")
@mock.patch(f"{HOOK_PATH}._get_conn_params")
@mock.patch(f"{HOOK_PATH}.get_headers")
def test_execute_query_exception_without_statement_handle(
self,
mock_get_header,
mock_conn_param,
mock_make_api_call,
sql,
statement_count,
expected_response,
Expand All @@ -306,13 +338,12 @@ def test_execute_query_exception_without_statement_handle(
Test execute_query method by mocking the exception response and raise airflow exception
without statementHandle in the response
"""
side_effect = create_post_side_effect()
mock_requests.request.side_effect = side_effect
# status_code, json payload without statementHandle
mock_make_api_call.return_value = (None, {"foo": "bar"})
hook = SnowflakeSqlApiHook("mock_conn_id")

with pytest.raises(AirflowException) as exception_info:
with pytest.raises(AirflowException):
hook.execute_query(sql, statement_count)
assert exception_info

@pytest.mark.parametrize(
("sql", "statement_count", "bindings"),
Expand Down Expand Up @@ -358,6 +389,8 @@ def test_check_query_output(self, mock_geturl_header_params, query_ids, mock_req
params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
mock_requests.request.return_value.json.return_value = GET_RESPONSE
# Make sure status code 200 when query is success.
mock_requests.request.return_value.status_code = 200
hook = SnowflakeSqlApiHook("mock_conn_id")
with mock.patch.object(hook.log, "info") as mock_log_info:
hook.check_query_output(query_ids)
Expand All @@ -382,7 +415,7 @@ def test_check_query_output_exception(
"stop": tenacity.stop_after_attempt(2), # Only 2 attempts instead of default 5
}
hook = SnowflakeSqlApiHook("mock_conn_id", api_retry_args=custom_retry_args)
mock_requests.request.side_effect = [create_post_side_effect(status_code=500)] * 3
mock_requests.request.side_effect = [create_post_response(status_code=500)] * 3
with pytest.raises(requests.exceptions.HTTPError):
hook.check_query_output(query_ids)

Expand Down Expand Up @@ -1514,6 +1547,40 @@ def test_make_api_call_with_retries_rejects_http_request_kwargs_overriding_ident
):
hook._make_api_call_with_retries("GET", API_URL, HEADERS)

def test_make_api_call_with_422_does_not_raise_for_status(self, mock_requests):
"""Test that HTTP 422 responses do not call raise_for_status and pass through the response body."""
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")

response = create_post_response(
status_code=422, json_body={"code": "Query was failed when runtime..", "message": "sync job"}
)
mock_requests.request.return_value = response

status, body = hook._make_api_call_with_retries("GET", API_URL, HEADERS)

assert status == 422
assert body == {"code": "Query was failed when runtime..", "message": "sync job"}

# Validate 422 don't raise http error.
response.raise_for_status.assert_not_called()
# Decode should call once.
response.json.assert_called_once()

def test_make_api_call_with_500_raises_for_status(self, mock_requests):
"""Test that HTTP 500 responses call raise_for_status and do not pass through the response body."""
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")

response = create_post_response(status_code=500, json_body={"error": "internal error"})
mock_requests.request.return_value = response

with pytest.raises(requests.exceptions.HTTPError):
hook._make_api_call_with_retries("GET", API_URL, HEADERS)

# 500 status code should raise HTTPError.
response.raise_for_status.assert_called_once()
# After raise _make_api_call_with_retries will return control immediately.
response.json.assert_not_called()

@pytest.mark.asyncio
async def test_make_api_call_with_retries_async_passes_timeout_to_clientsession(self):
"""
Expand Down Expand Up @@ -1611,3 +1678,48 @@ async def test_make_api_call_with_retries_async_rejects_aiohttp_request_kwargs_o
match=r"aiohttp_request_kwargs must not override request identity fields",
):
await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS)

@pytest.mark.asyncio
async def test_make_api_call_with_422_does_not_raise_for_status_async(self, mock_async_request):
"""Test that HTTP 422 responses do not call raise_for_status and pass through the response body (async)."""
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")

response = create_async_request_client_response_error(status_code=422)

# 422 should NOT call raise_for_status! We have to provide a valid JSON payload
response.json = AsyncMock(
return_value={"code": "Query was failed when runtime..", "message": "async job!"}
)

# If raise_for_status() is called, this side effect will explode the test!
response.raise_for_status.side_effect = AssertionError(
"raise_for_status should not be called for 422"
)

mock_async_request.__aenter__.return_value = response

status, body = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS)

assert status == 422
assert body == {"code": "Query was failed when runtime..", "message": "async job!"}

response.raise_for_status.assert_not_called()
response.json.assert_awaited_once()

@pytest.mark.asyncio
async def test_make_api_call_with_500_raises_for_status_async(self, mock_async_request):
"""Test that HTTP 500 responses call raise_for_status and do not pass through the response body (async)."""
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")

response = create_async_request_client_response_error(status_code=500)

# If json() is called, the test must fail..
response.json = AsyncMock(side_effect=AssertionError("json should not be called on 500"))

mock_async_request.__aenter__.return_value = response

with pytest.raises(aiohttp.ClientResponseError):
await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS)

response.raise_for_status.assert_called_once()
response.json.assert_not_called()