From 644cc82837e106973bf86768789cd04ba8a0ff2d Mon Sep 17 00:00:00 2001 From: Justin Li Date: Mon, 13 Jan 2025 21:13:02 -0500 Subject: [PATCH] allow setting custom headers --- tests/test_backend.py | 21 +++++++++++++++++++++ turbopuffer/backend.py | 8 ++++++-- turbopuffer/namespace.py | 4 ++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 64814aa..7927b0f 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -40,3 +40,24 @@ def test_429_retried(): with pytest.raises(tpuf.error.APIError): backend.make_api_request('namespaces', payload={}) assert sleep.call_count == tpuf.max_retries - 1 + + +def test_custom_headers(): + backend = tpuf_backend.Backend("fake_api_key", headers = {"foo": "bar"}) + assert backend.session.headers["foo"] == "bar" + + ns = tpuf.Namespace('fake_namespace', headers = {"foo": "bar"}) + assert ns.backend.session.headers["foo"] == "bar" + + +def test_backend_eq(): + backend = tpuf_backend.Backend("fake_api_key", headers = {"foo": "bar"}) + + backend2 = tpuf_backend.Backend("fake_api_key", headers = {"foo": "notbar"}) + assert backend != backend2 + + backend2 = tpuf_backend.Backend("fake_api_key", headers = {"foo": "bar"}) + assert backend == backend2 + + backend2 = tpuf_backend.Backend("fake_api_key2", headers = {"foo": "bar"}) + assert backend != backend2 diff --git a/turbopuffer/backend.py b/turbopuffer/backend.py index 5bfaa57..b8e5ced 100644 --- a/turbopuffer/backend.py +++ b/turbopuffer/backend.py @@ -31,18 +31,22 @@ class Backend: api_base_url: str session: requests.Session - def __init__(self, api_key: Optional[str] = None): + def __init__(self, api_key: Optional[str] = None, headers: Optional[dict] = None): self.api_key = find_api_key(api_key) self.api_base_url = clean_api_base_url(tpuf.api_base_url) + self.headers = headers self.session = requests.Session() self.session.headers.update({ 'Authorization': f'Bearer {self.api_key}', 'User-Agent': f'tpuf-python/{tpuf.VERSION} {requests.utils.default_headers()["User-Agent"]}', }) + if headers is not None: + self.session.headers.update(headers) + def __eq__(self, other): if isinstance(other, Backend): - return self.api_key == other.api_key and self.api_base_url == other.api_base_url + return self.api_key == other.api_key and self.api_base_url == other.api_base_url and self.headers == other.headers else: return False diff --git a/turbopuffer/namespace.py b/turbopuffer/namespace.py index b1aa29a..9c0a38a 100644 --- a/turbopuffer/namespace.py +++ b/turbopuffer/namespace.py @@ -92,7 +92,7 @@ class Namespace: metadata: Optional[dict] = None - def __init__(self, name: str, api_key: Optional[str] = None): + def __init__(self, name: str, api_key: Optional[str] = None, headers: Optional[dict] = None): """ Creates a new turbopuffer.Namespace object for querying the turbopuffer API. @@ -101,7 +101,7 @@ def __init__(self, name: str, api_key: Optional[str] = None): Specifying an api_key here will override the global configuration for API calls to this namespace. """ self.name = name - self.backend = Backend(api_key) + self.backend = Backend(api_key, headers) def __str__(self) -> str: return f'tpuf-namespace:{self.name}'