diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 86e72ccf4c..683a2a9fcd 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -78,10 +78,13 @@ def __init__( self.auth = auth if session: - # dlt.sources.helpers.requests.session.Session - # has raise_for_status=True by default + # If the `session` is provided (for example, an instance of + # dlt.sources.helpers.requests.session.Session), warn if + # it has raise_for_status=True by default self.session = _warn_if_raise_for_status_and_return(session) else: + # Otherwise, create a new Client with disabled raise_for_status + # to allow for custom error handling in the hooks from dlt.sources.helpers.requests.retry import Client self.session = Client(raise_for_status=False).session @@ -182,9 +185,9 @@ def paginate( **kwargs (Any): Optional arguments to that the Request library accepts, such as `stream`, `verify`, `proxies`, `cert`, `timeout`, and `allow_redirects`. - Yields: - PageData[Any]: A page of data from the paginated API response, along with request and response context. + PageData[Any]: A page of data from the paginated API response, along with request + and response context. Raises: HTTPError: If the response status code is not a success code. This is raised @@ -200,9 +203,9 @@ def paginate( data_selector = data_selector or self.data_selector hooks = hooks or {} - def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: - response.raise_for_status() - + # Add the raise_for_status hook to ensure an exception is raised on + # HTTP error status codes. This is a fallback to handle errors + # unless explicitly overridden in the provided hooks. if "response" not in hooks: hooks["response"] = [raise_for_status] @@ -305,6 +308,10 @@ def detect_paginator(self, response: Response, data: Any) -> BasePaginator: return paginator +def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: + response.raise_for_status() + + def _warn_if_raise_for_status_and_return(session: BaseSession) -> BaseSession: """A generic function to warn if the session has raise_for_status enabled.""" if getattr(session, "raise_for_status", False): @@ -312,3 +319,4 @@ def _warn_if_raise_for_status_and_return(session: BaseSession) -> BaseSession: "The session provided has raise_for_status enabled. This may cause unexpected behavior." ) return session + diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index 0f9857b45a..b11f2799b9 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -50,6 +50,7 @@ APIKeyAuth, OAuth2ClientCredentials, ) +from dlt.sources.helpers.rest_client.client import raise_for_status from dlt.extract.resource import DltResource @@ -530,12 +531,6 @@ def response_action_hook(response: Response, *args: Any, **kwargs: Any) -> None: ) raise IgnoreResponseException - # If there are hooks, then the REST client does not raise for status - # If no action has been taken and the status code indicates an error, - # raise an HTTP error based on the response status - elif not action_type: - response.raise_for_status() - return response_action_hook @@ -570,7 +565,8 @@ def remove_field(response: Response, *args, **kwargs) -> Response: """ if response_actions: hooks = [_create_response_action_hook(action) for action in response_actions] - return {"response": hooks} + fallback_hooks = [raise_for_status] + return {"response": hooks + fallback_hooks} return None diff --git a/tests/sources/rest_api/conftest.py b/tests/sources/rest_api/conftest.py index 7f20dc2252..bc58a18e5c 100644 --- a/tests/sources/rest_api/conftest.py +++ b/tests/sources/rest_api/conftest.py @@ -139,7 +139,21 @@ def post_detail_404(request, context): return {"id": post_id, "body": f"Post body {post_id}"} else: context.status_code = 404 - return {"error": "Post not found"} + return {"error": f"Post with id {post_id} not found"} + + @router.get(r"/posts/\d+/some_details_404_others_422") + def post_detail_404_422(request, context): + """Return 404 No Content for post with id 1. Return 422 for post with id > 1. + Used to test ignoring 404 and 422 responses.""" + post_id = int(request.url.split("/")[-2]) + if post_id < 1: + return {"id": post_id, "body": f"Post body {post_id}"} + elif post_id == 1: + context.status_code = 404 + return {"error": f"Post with id {post_id} not found"} + else: + context.status_code = 422 + return None @router.get(r"/posts/\d+/some_details_204") def post_detail_204(request, context): diff --git a/tests/sources/rest_api/integration/test_offline.py b/tests/sources/rest_api/integration/test_offline.py index cb91e0d680..d91cf0c0aa 100644 --- a/tests/sources/rest_api/integration/test_offline.py +++ b/tests/sources/rest_api/integration/test_offline.py @@ -118,6 +118,54 @@ def test_ignoring_endpoint_returning_404(mock_api_server): } }, "response_actions": [ + { + "status_code": 422, + "action": "ignore", + }, + { + "status_code": 404, + "action": "ignore", + }, + ], + }, + }, + ], + } + ) + + res = list(mock_source.with_resources("posts", "post_details").add_limit(1)) + + assert res[:5] == [ + {"id": 0, "body": "Post body 0"}, + {"id": 0, "title": "Post 0"}, + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + ] + + +def test_ignoring_endpoint_returning_404_others_422(mock_api_server): + mock_source = rest_api_source( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "posts", + { + "name": "post_details", + "endpoint": { + "path": "posts/{post_id}/some_details_404_others_422", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + } + }, + "response_actions": [ + { + "status_code": 422, + "action": "ignore", + }, { "status_code": 404, "action": "ignore", diff --git a/tests/sources/rest_api/integration/test_response_actions.py b/tests/sources/rest_api/integration/test_response_actions.py index 36a7990db3..da5011077d 100644 --- a/tests/sources/rest_api/integration/test_response_actions.py +++ b/tests/sources/rest_api/integration/test_response_actions.py @@ -1,10 +1,14 @@ from dlt.common import json from dlt.sources.helpers.requests import Response +from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException from dlt.sources.rest_api import create_response_hooks, rest_api_source def test_response_action_on_status_code(mock_api_server, mocker): - mock_response_hook = mocker.Mock() + def custom_hook(response, *args, **kwargs): + raise IgnoreResponseException + + mock_response_hook = mocker.Mock(side_effect=custom_hook) mock_source = rest_api_source( { "client": {"base_url": "https://api.example.com"}, @@ -108,7 +112,7 @@ def add_field(response: Response, *args, **kwargs) -> Response: {"status_code": 200, "action": mock_response_hook_2}, ] hooks = create_response_hooks(response_actions) - assert len(hooks.get("response")) == 2 + assert len(hooks.get("response")) == 3 # 2 custom hooks + 1 fallback hook mock_source = rest_api_source( {