Custom headers/cookies in InferenceClient #1507

Merged 1 commit on Jun 14, 2023
2 changes: 1 addition & 1 deletion docs/source/guides/inference.mdx
@@ -198,7 +198,7 @@ You can catch it and handle it in your code:
 Some tasks require binary inputs, for example, when dealing with images or audio files. In this case, [`InferenceClient`]
 tries to be as permissive as possible and accept different types:
 - raw `bytes`
-- a file-like object, opened as binary (`with open("audio.wav", "rb") as f: ...`)
+- a file-like object, opened as binary (`with open("audio.flac", "rb") as f: ...`)
 - a path (`str` or `Path`) pointing to a local file
 - a URL (`str`) pointing to a remote file (e.g. `https://...`). In this case, the file will be downloaded locally before
   sending it to the Inference API.
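As a quick sketch of the permissiveness described in this docs change (the file name and URL are hypothetical placeholders, and `audio_classification` stands in for any binary-input task):

```py
from pathlib import Path

from huggingface_hub import InferenceClient

client = InferenceClient()

# All four input styles are accepted and end up as the same binary payload:
client.audio_classification(Path("audio.flac").read_bytes())    # raw bytes
with open("audio.flac", "rb") as f:
    client.audio_classification(f)                               # file-like object, opened as binary
client.audio_classification("audio.flac")                        # local path
client.audio_classification("https://example.com/audio.flac")   # remote URL, downloaded first
```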
34 changes: 27 additions & 7 deletions src/huggingface_hub/_inference.py
@@ -44,6 +44,7 @@
 from typing import TYPE_CHECKING, Any, BinaryIO, ContextManager, Dict, Generator, List, Optional, Union, overload

 from requests import HTTPError, Response
+from requests.structures import CaseInsensitiveDict

 from ._inference_types import ClassificationOutput, ConversationalOutput, ImageSegmentationOutput
 from .constants import INFERENCE_ENDPOINT
@@ -96,17 +97,31 @@ class InferenceClient:
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
         token (`str`, *optional*):
-            Hugging Face token. Will default to the locally saved token.
+            Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
+            your token to the server.
         timeout (`float`, `optional`):
             The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
             API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
+        headers (`Dict[str, str]`, `optional`):
+            Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
+            Values in this dictionary will override the default values.
+        cookies (`Dict[str, str]`, `optional`):
+            Additional cookies to send to the server.
     """

     def __init__(
-        self, model: Optional[str] = None, token: Optional[str] = None, timeout: Optional[float] = None
+        self,
+        model: Optional[str] = None,
+        token: Union[str, bool, None] = None,
+        timeout: Optional[float] = None,
+        headers: Optional[Dict[str, str]] = None,
+        cookies: Optional[Dict[str, str]] = None,
     ) -> None:
         self.model: Optional[str] = model
-        self.headers = build_hf_headers(token=token)
+        self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  # contains 'authorization' + 'user-agent'
+        if headers is not None:
+            self.headers.update(headers)
+        self.cookies = cookies
         self.timeout = timeout

     def __repr__(self):
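To make the new parameters concrete, a short usage sketch (header and cookie values are invented for illustration):

```py
from huggingface_hub import InferenceClient

client = InferenceClient(
    headers={"X-My-Header": "foo", "USER-AGENT": "my-app/1.0"},
    cookies={"session": "abc123"},
)

# Custom headers are merged with the defaults, and thanks to the
# CaseInsensitiveDict the user-agent override wins regardless of casing:
assert client.headers["X-My-Header"] == "foo"
assert client.headers["user-agent"] == "my-app/1.0"
```

Every request issued through `post` (and therefore every task method) then carries these headers and cookies, as wired up in the next hunk.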
@@ -157,7 +172,12 @@ def post(
         with _open_as_binary(data) as data_as_binary:
             try:
                 response = get_session().post(
-                    url, json=json, data=data_as_binary, headers=self.headers, timeout=self.timeout
+                    url,
+                    json=json,
+                    data=data_as_binary,
+                    headers=self.headers,
+                    cookies=self.cookies,
+                    timeout=self.timeout,
                 )
             except TimeoutError as error:
                 # Convert any `TimeoutError` to a `InferenceTimeoutError`
@@ -214,7 +234,7 @@ def audio_classification(
         ```py
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
-        >>> client.audio_classification("audio.wav")
+        >>> client.audio_classification("audio.flac")
         [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
         ```
         """
@@ -250,7 +270,7 @@ def automatic_speech_recognition(
         ```py
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
-        >>> client.automatic_speech_recognition("hello_world.wav")
+        >>> client.automatic_speech_recognition("hello_world.flac")
         "hello world"
         ```
         """
@@ -760,7 +780,7 @@ def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes:
         >>> client = InferenceClient()

         >>> audio = client.text_to_speech("Hello world")
-        >>> Path("hello_world.wav").write_bytes(audio)
+        >>> Path("hello_world.flac").write_bytes(audio)
         ```
         """
         response = self.post(json={"inputs": text}, model=model, task="text-to-speech")
36 changes: 36 additions & 0 deletions tests/test_inference_client.py
@@ -14,13 +14,15 @@
 import io
 import unittest
 from pathlib import Path
+from unittest.mock import MagicMock, patch

 import numpy as np
 import pytest
 from PIL import Image

 from huggingface_hub import InferenceClient, hf_hub_download
 from huggingface_hub._inference import _open_as_binary
+from huggingface_hub.utils import build_hf_headers

 from .testing_utils import with_production_testing

@@ -251,3 +253,37 @@ def test_recommended_model_from_supported_task(self) -> None:
     def test_unsupported_task(self) -> None:
         with self.assertRaises(NotImplementedError):
             InferenceClient()._resolve_url(task="unknown-task")
+
+
+class TestHeadersAndCookies(unittest.TestCase):
+    def test_headers_and_cookies(self) -> None:
+        client = InferenceClient(headers={"X-My-Header": "foo"}, cookies={"my-cookie": "bar"})
+        self.assertEqual(client.headers["X-My-Header"], "foo")
+        self.assertEqual(client.cookies["my-cookie"], "bar")
+
+    def test_headers_overwrite(self) -> None:
+        # Default user agent
+        self.assertTrue(InferenceClient().headers["user-agent"].startswith("unknown/None;"))
+
+        # Overwritten user-agent
+        self.assertEqual(InferenceClient(headers={"user-agent": "bar"}).headers["user-agent"], "bar")
+
+        # Case-insensitive overwrite
+        self.assertEqual(InferenceClient(headers={"USER-agent": "bar"}).headers["user-agent"], "bar")
+
+    @patch("huggingface_hub._inference.get_session")
+    def test_mocked_post(self, get_session_mock: MagicMock) -> None:
+        """Test that headers and cookies are correctly passed to the request."""
+        client = InferenceClient(headers={"X-My-Header": "foo"}, cookies={"my-cookie": "bar"})
+        response = client.post(data=b"content", model="username/repo_name")
+        self.assertEqual(response, get_session_mock().post.return_value)
+
+        expected_user_agent = build_hf_headers()["user-agent"]
+        get_session_mock().post.assert_called_once_with(
+            "https://api-inference.huggingface.co/models/username/repo_name",
+            json=None,
+            data=b"content",
+            headers={"user-agent": expected_user_agent, "X-My-Header": "foo"},
+            cookies={"my-cookie": "bar"},
+            timeout=None,
+        )