Skip to content

Commit

Permalink
rest_api: allow specifying custom session (feat/1843) (#1844)
Browse files Browse the repository at this point in the history
* allows requests.Session in ClientConfig

* allows requests_oauthlib by using `requests.Session` instead of `dlt.sources.helpers.requests.Session`
  • Loading branch information
willi-mueller authored Sep 23, 2024
1 parent 7b4209e commit a39ebad
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 1 deletion.
1 change: 1 addition & 0 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def create_resources(
headers=client_config.get("headers"),
auth=create_auth(client_config.get("auth")),
paginator=create_paginator(client_config.get("paginator")),
session=client_config.get("session"),
)

hooks = create_response_hooks(endpoint_config.get("response_actions"))
Expand Down
3 changes: 3 additions & 0 deletions dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from dlt.extract.items import TTableHintTemplate
from dlt.extract.incremental.typing import LastValueFunc

from requests import Session

from dlt.sources.helpers.rest_client.typing import HTTPMethodBasic

from dlt.sources.helpers.rest_client.paginators import (
Expand Down Expand Up @@ -187,6 +189,7 @@ class ClientConfig(TypedDict, total=False):
headers: Optional[Dict[str, str]]
auth: Optional[AuthConfig]
paginator: Optional[PaginatorConfig]
session: Optional[Session]


class IncrementalRESTArgs(IncrementalArgs, total=False):
Expand Down
20 changes: 20 additions & 0 deletions tests/sources/rest_api/configurations/source_configs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from collections import namedtuple
from typing import cast, List

import requests
import dlt
import dlt.common
from dlt.common.typing import TSecretStrValue
from dlt.common.exceptions import DictValidationException
from dlt.common.configuration.specs import configspec

import dlt.sources.helpers
import dlt.sources.helpers.requests
from dlt.sources.helpers.rest_client.paginators import HeaderLinkPaginator
from dlt.sources.helpers.rest_client.auth import OAuth2AuthBase

Expand Down Expand Up @@ -304,6 +309,21 @@ class CustomOAuthAuth(OAuth2AuthBase):
},
],
},
{
"client": {
"base_url": "https://example.com",
"session": requests.Session(),
},
"resources": ["users"],
},
{
"client": {
"base_url": "https://example.com",
# This is a subclass of requests.Session and is thus also allowed
"session": dlt.sources.helpers.requests.Session(),
},
"resources": ["users"],
},
]


Expand Down
33 changes: 32 additions & 1 deletion tests/sources/rest_api/integration/test_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest import mock

import pytest
from requests import Request, Response
from requests import Request, Response, Session

import dlt
from dlt.common import pendulum
Expand Down Expand Up @@ -327,3 +327,34 @@ def test_posts_with_inremental_date_conversion(mock_api_server) -> None:
_, called_kwargs = mock_paginate.call_args_list[0]
assert called_kwargs["params"] == {"since": "1970-01-01", "until": "1970-01-02"}
assert called_kwargs["path"] == "posts"


def test_multiple_response_actions_on_every_response(mock_api_server, mocker):
class CustomSession(Session):
pass

def send_spy(*args, **kwargs):
return original_send(*args, **kwargs)

my_session = CustomSession()
original_send = my_session.send
mocked_send = mocker.patch.object(my_session, "send", side_effect=send_spy)

source = rest_api_source(
{
"client": {
"base_url": "https://api.example.com",
"session": my_session,
},
"resources": [
{
"name": "posts",
},
],
}
)

list(source.with_resources("posts").add_limit(1))

mocked_send.assert_called_once()
assert mocked_send.call_args[0][0].url == "https://api.example.com/posts"

0 comments on commit a39ebad

Please sign in to comment.