diff --git a/docs/source/guides/inference.mdx b/docs/source/guides/inference.mdx
index 5b7c9aea71..5cb59e30dd 100644
--- a/docs/source/guides/inference.mdx
+++ b/docs/source/guides/inference.mdx
@@ -198,7 +198,7 @@ You can catch it and handle it in your code:
 Some tasks require binary inputs, for example, when dealing with images or audio files. In this case,
 [`InferenceClient`] tries to be as permissive as possible and accept different types:
 - raw `bytes`
-- a file-like object, opened as binary (`with open("audio.wav", "rb") as f: ...`)
+- a file-like object, opened as binary (`with open("audio.flac", "rb") as f: ...`)
 - a path (`str` or `Path`) pointing to a local file
 - a URL (`str`) pointing to a remote file (e.g. `https://...`). In this case, the file will be downloaded locally before sending it to the Inference API.
 
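A quick illustration of the permissive binary-input handling described above — a minimal sketch, assuming a local `audio.flac` file exists and that the (hypothetical) URL actually serves an audio file:

```py
from pathlib import Path
from huggingface_hub import InferenceClient

client = InferenceClient()  # a recommended model is selected for the task by default

client.audio_classification(Path("audio.flac").read_bytes())   # raw bytes
with open("audio.flac", "rb") as f:
    client.audio_classification(f)                              # file-like object, opened as binary
client.audio_classification("audio.flac")                       # local path (str or Path)
client.audio_classification("https://example.com/audio.flac")   # URL, downloaded locally first (hypothetical)
```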
""" def __init__( - self, model: Optional[str] = None, token: Optional[str] = None, timeout: Optional[float] = None + self, + model: Optional[str] = None, + token: Union[str, bool, None] = None, + timeout: Optional[float] = None, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, ) -> None: self.model: Optional[str] = model - self.headers = build_hf_headers(token=token) + self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent' + if headers is not None: + self.headers.update(headers) + self.cookies = cookies self.timeout = timeout def __repr__(self): @@ -157,7 +172,12 @@ def post( with _open_as_binary(data) as data_as_binary: try: response = get_session().post( - url, json=json, data=data_as_binary, headers=self.headers, timeout=self.timeout + url, + json=json, + data=data_as_binary, + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, ) except TimeoutError as error: # Convert any `TimeoutError` to a `InferenceTimeoutError` @@ -214,7 +234,7 @@ def audio_classification( ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() - >>> client.audio_classification("audio.wav") + >>> client.audio_classification("audio.flac") [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...] ``` """ @@ -250,7 +270,7 @@ def automatic_speech_recognition( ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() - >>> client.automatic_speech_recognition("hello_world.wav") + >>> client.automatic_speech_recognition("hello_world.flac") "hello world" ``` """ @@ -760,7 +780,7 @@ def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: >>> client = InferenceClient() >>> audio = client.text_to_speech("Hello world") - >>> Path("hello_world.wav").write_bytes(audio) + >>> Path("hello_world.flac").write_bytes(audio) ``` """ response = self.post(json={"inputs": text}, model=model, task="text-to-speech") diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 7bf57d81e4..8b0b45b7ef 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -14,6 +14,7 @@ import io import unittest from pathlib import Path +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -21,6 +22,7 @@ from huggingface_hub import InferenceClient, hf_hub_download from huggingface_hub._inference import _open_as_binary +from huggingface_hub.utils import build_hf_headers from .testing_utils import with_production_testing @@ -251,3 +253,37 @@ def test_recommended_model_from_supported_task(self) -> None: def test_unsupported_task(self) -> None: with self.assertRaises(NotImplementedError): InferenceClient()._resolve_url(task="unknown-task") + + +class TestHeadersAndCookies(unittest.TestCase): + def test_headers_and_cookies(self) -> None: + client = InferenceClient(headers={"X-My-Header": "foo"}, cookies={"my-cookie": "bar"}) + self.assertEqual(client.headers["X-My-Header"], "foo") + self.assertEqual(client.cookies["my-cookie"], "bar") + + def test_headers_overwrite(self) -> None: + # Default user agent + self.assertTrue(InferenceClient().headers["user-agent"].startswith("unknown/None;")) + + # Overwritten user-agent + self.assertEqual(InferenceClient(headers={"user-agent": "bar"}).headers["user-agent"], "bar") + + # Case-insensitive overwrite + self.assertEqual(InferenceClient(headers={"USER-agent": "bar"}).headers["user-agent"], "bar") + + 
@patch("huggingface_hub._inference.get_session") + def test_mocked_post(self, get_session_mock: MagicMock) -> None: + """Test that headers and cookies are correctly passed to the request.""" + client = InferenceClient(headers={"X-My-Header": "foo"}, cookies={"my-cookie": "bar"}) + response = client.post(data=b"content", model="username/repo_name") + self.assertEqual(response, get_session_mock().post.return_value) + + expected_user_agent = build_hf_headers()["user-agent"] + get_session_mock().post.assert_called_once_with( + "https://api-inference.huggingface.co/models/username/repo_name", + json=None, + data=b"content", + headers={"user-agent": expected_user_agent, "X-My-Header": "foo"}, + cookies={"my-cookie": "bar"}, + timeout=None, + )